From a464f39a4ef99491f8def8befa9d0f72865e9941 Mon Sep 17 00:00:00 2001 From: Farzin Houshmand Date: Thu, 18 Jul 2024 21:03:26 -0700 Subject: [PATCH 001/376] [XLA:UNSTACKER] Add a function to detect effectively static dynamic-slice instructions inside unrollable loops. PiperOrigin-RevId: 653860014 --- xla/service/while_loop_unroller.cc | 79 ++++++++++++++++++++++ xla/service/while_loop_unroller.h | 9 +++ xla/service/while_loop_unroller_test.cc | 89 +++++++++++++++++++++++++ 3 files changed, 177 insertions(+) diff --git a/xla/service/while_loop_unroller.cc b/xla/service/while_loop_unroller.cc index 90473bae4e7036..534d2c604adb84 100644 --- a/xla/service/while_loop_unroller.cc +++ b/xla/service/while_loop_unroller.cc @@ -369,6 +369,84 @@ bool IsLoopInductionVar(const HloInstruction* instr, } } +// Recursively checks if the given instruction is effectively static by checking +// if it is a constant or a parameter that points to the induction var of the +// given loop config. +bool IsEffectivelyStatic(const HloInstruction* instr, + const WhileLoopConfig& config) { + switch (instr->opcode()) { + case HloOpcode::kConstant: + return true; + case HloOpcode::kParameter: { + if (instr->parent()->IsFusionComputation()) { + HloInstruction* caller_fusion = instr->parent()->FusionInstruction(); + return IsEffectivelyStatic( + caller_fusion->operand(instr->parameter_number()), config); + } + return false; + } + case HloOpcode::kGetTupleElement: { + if (instr->parent() != config.while_instr->while_body()) { + return false; + } + if (!Match(instr, match::GetTupleElement(match::Parameter(), + config.induction_var_idx))) { + return false; + } + return true; + } + default: { + for (int64_t i = 0; i < instr->operand_count(); ++i) { + if (!IsEffectivelyStatic(instr->operand(i), config)) { + return false; + } + } + return true; + } + } +} + +std::optional MatchEffectivelyStaticDynamicSliceInsideLoop( + const HloInstruction* instr, const HloInstruction* input, HloOpcode opcode, + const WhileLoopConfig& config) { + int64_t start_indices_offset = 1; + const HloInstruction* operand = instr->operand(0); + if (operand != input) { + VLOG(3) << "Input of dynamic index instruction is not the given operand."; + return std::nullopt; + } + + int64_t dynamic_index = -1; + for (int64_t start_index = start_indices_offset; + start_index < instr->operand_count(); ++start_index) { + const HloInstruction* index = instr->operand(start_index); + // All constants must be zero in order to slice the entire shape. + if (Match(index, match::ConstantScalar())) { + std::optional offset = + LiteralUtil::LiteralAsScalarInt64(index->literal()); + if (offset.has_value() && offset.value() != 0) { + VLOG(3) << "Constant index " << start_index << " is not zero."; + return std::nullopt; + } + continue; + } + if (IsEffectivelyStatic(index, config)) { + if (dynamic_index != -1) { + VLOG(3) << "Multiple non-constant indices."; + return std::nullopt; + } + dynamic_index = start_index - start_indices_offset; + } + } + + if (dynamic_index == -1) { + VLOG(3) << "No dynamic index found."; + return std::nullopt; + } + + return dynamic_index; +} + std::optional MatchShapeCoveringDynamicIndexInstruction( const HloInstruction* instr, const HloInstruction* input, HloOpcode opcode, const WhileLoopConfig& config) { @@ -530,6 +608,7 @@ std::optional MatchShapeCoveringDynamicIndexInstruction( VLOG(3) << "Loop trip count " << trip_count.value(); WhileLoopConfig config; + config.while_instr = while_op; config.init = LiteralUtil::LiteralAsScalarInt64(std::move(indvar_iter_val)).value(); config.trip_count = trip_count.value(); diff --git a/xla/service/while_loop_unroller.h b/xla/service/while_loop_unroller.h index 77f83a422e3186..1092dd791924cf 100644 --- a/xla/service/while_loop_unroller.h +++ b/xla/service/while_loop_unroller.h @@ -35,6 +35,7 @@ namespace xla { // Config for unrollable while loops. struct WhileLoopConfig { + const HloInstruction* while_instr; // The initial value of the induction variable of the while loop. int64_t init; // The number of iterations the loop executes. @@ -55,6 +56,14 @@ std::optional MatchShapeCoveringDynamicIndexInstruction( const HloInstruction* instr, const HloInstruction* input, HloOpcode opcode, const WhileLoopConfig& config); +// Check if `instr` is a dynamic-slice with the given input and a single dynamic +// start index that is effectively static, i.e., it is an expression that only +// involves the iteration variable of the surrounding loop and some constants, +// if we unroll the surrounding loop. If so, it returns the dynamic index. +std::optional MatchEffectivelyStaticDynamicSliceInsideLoop( + const HloInstruction* instr, const HloInstruction* input, HloOpcode opcode, + const WhileLoopConfig& config); + // This pass unrolls while loops with the given unrolling factor. The value of // unroll_factor = -1 will fully unroll the loop. // diff --git a/xla/service/while_loop_unroller_test.cc b/xla/service/while_loop_unroller_test.cc index 268f82ccf2b3a5..7029172e05f2df 100644 --- a/xla/service/while_loop_unroller_test.cc +++ b/xla/service/while_loop_unroller_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_replace.h" #include "absl/types/span.h" +#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal.h" @@ -1102,5 +1103,93 @@ TEST_F(WhileLoopUnrollerTest, UnrollLoopWithDynamicGte) { } } +TEST_F(WhileLoopUnrollerTest, IsEffectivelyStaticDynamicSlice) { + std::string hlo_string = R"( + HloModule SimpleLoop + %fused_computation.slice (param_0.51117: s8[6,128,128], p1: s32[]) -> s8[128,128] { + %param_0.51117 = s8[6,128,128] parameter(0) + static.p1 = s32[] parameter(1) + %constant.85694 = s32[] constant(0) + %dynamic-slice.static = s8[1,128,128] dynamic-slice(s8[6,128,128] %param_0.51117, static.p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128} + ROOT %bitcast.31250 = s8[128,128] bitcast(s8[1,128,128] %dynamic-slice.static) + } + + %fused_computation.slice.2 (param_0.51117: s8[6,128,128], p1: s32[]) -> s8[128,128] { + %param_0.51117 = s8[6,128,128] parameter(0) + dynamic.p1 = s32[] parameter(1) + %constant.85694 = s32[] constant(0) + %dynamic-slice.dynamic = s8[1,128,128] dynamic-slice(s8[6,128,128] %param_0.51117, dynamic.p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128} + ROOT %bitcast.31250 = s8[128,128] bitcast(s8[1,128,128] %dynamic-slice.dynamic) + } + + %fused_computation.inner (param_0.34523: bf16[8,128], param_1.30691: s8[6,128,128], p2: s32[], p3: s32[]) -> bf16[8,128] { + %param_0.34523 = bf16[8,128] parameter(0) + %param_1.30691 = s8[6,128,128] parameter(1) + static.p2 = s32[] parameter(2) + %fusion.1 = s8[128,128] fusion(s8[6,128,128] %param_1.30691, static.p2), kind=kLoop, calls=%fused_computation.slice + dynamic.p3 = s32[] parameter(3) + %fusion.2 = s8[128,128] fusion(s8[6,128,128] %param_1.30691, dynamic.p3), kind=kLoop, calls=%fused_computation.slice.2 + out = s8[128,128] add(%fusion.1, %fusion.2) + ROOT %convolution.3447 = bf16[8,128] convolution(bf16[8,128] %param_0.34523, s8[128,128] out), dim_labels=bf_io->bf + } + + %while.body (wide_param: (s32[], bf16[8,128], s8[6,128,128], s32[])) -> (s32[], bf16[8,128], s8[6,128,128], s32[]) { + wide_p = (s32[], bf16[8,128], s8[6,128,128], s32[]) parameter(0) + i = s32[] get-tuple-element(wide_p), index=0 + p0 = bf16[8,128] get-tuple-element(wide_p), index=1 + p1 = s8[6,128,128] get-tuple-element(wide_p), index=2 + dynamic.p2 = s32[] get-tuple-element(wide_p), index=3 + one = s32[] constant(1) + inc = s32[] add(i, one) + two = s32[] constant(2) + mult = s32[] multiply(i, two) + fusion.conv = bf16[8,128] fusion(p0, p1, mult, dynamic.p2), kind=kOutput, calls=%fused_computation.inner + ROOT out = (s32[], bf16[8,128], s8[6,128,128], s32[]) tuple(inc, fusion.conv, p1, dynamic.p2) + } + + %while.cond (wide_param: (s32[], bf16[8,128], s8[6,128,128], s32[])) -> pred[] { + wide_p = (s32[], bf16[8,128], s8[6,128,128], s32[]) parameter(0) + i = s32[] get-tuple-element(wide_p), index=0 + %constant.12857 = s32[] constant(3) + ROOT %compare.1921 = pred[]{:T(512)} compare(s32[] i, s32[] %constant.12857), direction=LT + } + + ENTRY main { + p0 = s8[6,128,128] parameter(0) + p1 = bf16[8,128] parameter(1) + p2 = s32[] parameter(2) + init = s32[] constant(0) + while.input = (s32[], bf16[8,128], s8[6,128,128], s32[]) tuple(init, p1, p0, p2) + while.out = (s32[], bf16[8,128], s8[6,128,128], s32[]) while(while.input), condition=%while.cond , body=%while.body + while_use = s8[6,128,128] get-tuple-element(while.out), index=2 + ROOT out = bf16[8,128] get-tuple-element(while.out), index=1 + } + )"; + auto module = ParseAndReturnVerifiedModule(hlo_string).value(); + HloInstruction* loop = + module->entry_computation()->root_instruction()->mutable_operand(0); + std::optional config = + WhileLoopUnroller::IsLoopUnrollable(loop); + EXPECT_TRUE(config.has_value()); + for (HloComputation* comp : module->MakeComputationPostOrder()) { + HloInstruction* static_slice = + comp->GetInstructionWithName("dynamic-slice.static"); + if (static_slice != nullptr) { + auto index = MatchEffectivelyStaticDynamicSliceInsideLoop( + static_slice, static_slice->operand(0), HloOpcode::kDynamicSlice, + *config); + EXPECT_TRUE(index.has_value()); + } + HloInstruction* dynamic_slice = + comp->GetInstructionWithName("dynamic-slice.dynamic"); + if (dynamic_slice != nullptr) { + auto index = MatchEffectivelyStaticDynamicSliceInsideLoop( + dynamic_slice, dynamic_slice->operand(0), HloOpcode::kDynamicSlice, + *config); + EXPECT_FALSE(index.has_value()); + } + } +} + } // namespace } // namespace xla From 15f346bd3b6d22d5f990dc7c0f4a02579a031487 Mon Sep 17 00:00:00 2001 From: Zixuan Jiang Date: Thu, 18 Jul 2024 21:53:33 -0700 Subject: [PATCH 002/376] [XLA:SPMD] Simplify the chain of sharding instructions. Rewrite the pattern ``` B = sharding-constraint(A, sharding=S1) C = sharding-constraint(B, sharding=S2), which is the only user of B ``` as ``` C = sharding-constraint(A, sharding=S2) ``` The intermediate sharding S1 and tensor B are redundant, which are removed as a pre-processing in sharding propagation. PiperOrigin-RevId: 653870013 --- xla/service/sharding_propagation.cc | 21 +++- xla/service/sharding_propagation_test.cc | 130 +++++++++++++++++++++++ 2 files changed, 150 insertions(+), 1 deletion(-) diff --git a/xla/service/sharding_propagation.cc b/xla/service/sharding_propagation.cc index 96306455462ea4..ca78f979acaaf6 100644 --- a/xla/service/sharding_propagation.cc +++ b/xla/service/sharding_propagation.cc @@ -1557,11 +1557,30 @@ absl::StatusOr ProcessShardingInstruction( for (auto it = instructions.rbegin(); it != instructions.rend(); ++it) { HloInstruction* instruction = *it; if (instruction->IsCustomCall("Sharding")) { - HloSharding original_sharding = instruction->sharding(); TF_RET_CHECK(instruction->has_sharding()) << "Sharding instruction must have a sharding attribute"; + HloSharding original_sharding = instruction->sharding(); VLOG(3) << "ProcessShardingInstruction: " << instruction->ToString(); + // Simplify consecutive Sharding custom-call instructions. If both + // shardings are tiled, we do not simplify the instruction since these + // two shardings can guide the partitioner. An example is + // https://github.com/google/jax/issues/21562. + HloInstruction* operand = instruction->mutable_operand(0); + if (!original_sharding.IsUnknown() && + operand->IsCustomCall("Sharding") && operand->user_count() == 1 && + !(original_sharding.IsTiled() && operand->sharding().IsTiled())) { + operand->set_sharding(original_sharding); + TF_ASSIGN_OR_RETURN( + std::ignore, + computation->ReplaceInstruction( + instruction, operand, /*preserve_sharding=*/false, + /*relay_control_dependency=*/false, + /*remove_unused_operands=*/false)); + changed = true; + continue; + } + std::vector unspec_dims; TF_RETURN_IF_ERROR(sharding_op_util::ParseAttributes( Cast(instruction)->opaque(), diff --git a/xla/service/sharding_propagation_test.cc b/xla/service/sharding_propagation_test.cc index 97599a13ce6132..76977b282c28b5 100644 --- a/xla/service/sharding_propagation_test.cc +++ b/xla/service/sharding_propagation_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/sharding_propagation.h" +#include #include #include #include @@ -12071,5 +12072,134 @@ ENTRY %elementwise { "last_tile_dim_replicate}}")); } +TEST_F(ShardingPropagationTest, RedundantShardingInstruction1) { + const char* const hlo_string = R"( +HloModule module + +ENTRY %main.6 { + %p0 = f32[32,96] parameter(0), sharding={replicated} + %add.0 = f32[32,96] add(%p0, %p0) + %custom-call.0 = f32[32,96] custom-call(%add.0), custom_call_target="Sharding", sharding={replicated} + %custom-call.1 = f32[32,96] custom-call(%custom-call.0), custom_call_target="Sharding", sharding={devices=[2,2]<=[4]} + ROOT %add.1 = f32[32,96] add(%custom-call.1, %custom-call.1) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation(/*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{true}) + .Run(module.get())); + EXPECT_TRUE(changed); + XLA_VLOG_LINES(1, module->ToString()); + + int64_t num_copy = 0; + for (const HloInstruction* instruction : + module->entry_computation()->instructions()) { + if (instruction->opcode() == HloOpcode::kCopy) { + EXPECT_THAT(instruction, op::Sharding("{devices=[2,2]<=[4]}")); + num_copy++; + } + } + EXPECT_EQ(num_copy, 1); +} + +TEST_F(ShardingPropagationTest, RedundantShardingInstruction2) { + const char* const hlo_string = R"( +HloModule module + +ENTRY %main.6 { + %p0 = f32[32,96] parameter(0), sharding={replicated} + %add.0 = f32[32,96] add(%p0, %p0) + %custom-call.0 = f32[32,96] custom-call(%add.0), custom_call_target="Sharding", sharding={maximal device=0} + %custom-call.1 = f32[32,96] custom-call(%custom-call.0), custom_call_target="Sharding", sharding={maximal device=1} + %custom-call.2 = f32[32,96] custom-call(%custom-call.1), custom_call_target="Sharding", sharding={maximal device=2} + ROOT %add.1 = f32[32,96] add(%custom-call.2, %custom-call.2) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation(/*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{true}) + .Run(module.get())); + EXPECT_TRUE(changed); + XLA_VLOG_LINES(1, module->ToString()); + + int64_t num_copy = 0; + for (const HloInstruction* instruction : + module->entry_computation()->instructions()) { + if (instruction->opcode() == HloOpcode::kCopy) { + EXPECT_THAT(instruction, op::Sharding("{maximal device=2}")); + num_copy++; + } + } + EXPECT_EQ(num_copy, 1); +} + +TEST_F(ShardingPropagationTest, RedundantShardingInstruction3) { + // This target is similar to RedundantShardingInstruction1, except that + // %custom-call.0 has two users. + const char* const hlo_string = R"( +HloModule module + +ENTRY %main.6 { + %p0 = f32[32,96] parameter(0), sharding={replicated} + %add.0 = f32[32,96] add(%p0, %p0) + %custom-call.0 = f32[32,96] custom-call(%add.0), custom_call_target="Sharding", sharding={replicated} + %custom-call.1 = f32[32,96] custom-call(%custom-call.0), custom_call_target="Sharding", sharding={devices=[2,2]<=[4]} + ROOT %add.1 = f32[32,96] add(%custom-call.0, %custom-call.1) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation(/*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{true}) + .Run(module.get())); + EXPECT_TRUE(changed); + XLA_VLOG_LINES(1, module->ToString()); + + int64_t num_copy = 0; + for (const HloInstruction* instruction : + module->entry_computation()->instructions()) { + if (instruction->opcode() == HloOpcode::kCopy) { + num_copy++; + } + } + EXPECT_EQ(num_copy, 2); +} + +TEST_F(ShardingPropagationTest, RedundantShardingInstruction4) { + const char* const hlo_string = R"( +HloModule module + +ENTRY %main.6 { + %p0 = f32[32,96] parameter(0), sharding={replicated} + %add.0 = f32[32,96] add(%p0, %p0) + %custom-call.0 = f32[32,96] custom-call(%add.0), custom_call_target="Sharding", sharding={devices=[2,2]<=[2,2]T(1,0)} + %custom-call.1 = f32[32,96] custom-call(%custom-call.0), custom_call_target="Sharding", sharding={devices=[2,2]<=[4]} + ROOT %add.1 = f32[32,96] add(%custom-call.1, %custom-call.1) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation(/*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{true}) + .Run(module.get())); + EXPECT_TRUE(changed); + XLA_VLOG_LINES(1, module->ToString()); + + int64_t num_copy = 0; + for (const HloInstruction* instruction : + module->entry_computation()->instructions()) { + if (instruction->opcode() == HloOpcode::kCopy) { + num_copy++; + } + } + EXPECT_EQ(num_copy, 2); +} + } // namespace } // namespace xla From 9e4f7c2f609199ed1019dded520e7a0400632be9 Mon Sep 17 00:00:00 2001 From: Jorge Gorbe Moya Date: Thu, 18 Jul 2024 23:41:54 -0700 Subject: [PATCH 003/376] Integrate LLVM at llvm/llvm-project@dd7d81ea49bf Updates LLVM usage to match [dd7d81ea49bf](https://github.com/llvm/llvm-project/commit/dd7d81ea49bf) PiperOrigin-RevId: 653892711 --- third_party/llvm/generated.patch | 3735 +---------------- third_party/llvm/workspace.bzl | 4 +- .../tsl/third_party/llvm/generated.patch | 3735 +---------------- .../tsl/third_party/llvm/workspace.bzl | 4 +- 4 files changed, 320 insertions(+), 7158 deletions(-) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index d6f26a04468fd2..ed3d58f027f90b 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -11,17 +11,6 @@ diff -ruN --strip-trailing-cr a/clang/docs/ReleaseNotes.rst b/clang/docs/Release C/C++ Language Potentially Breaking Changes ------------------------------------------- -@@ -313,10 +311,6 @@ - - Clang now considers ``noexcept(typeid(expr))`` more carefully, instead of always assuming that ``std::bad_typeid`` can be thrown. - (`CWG2191: Incorrect result for noexcept(typeid(v)) `_). - --- Clang now correctly implements lookup for the terminal name of a member-qualified nested-name-specifier. -- (`CWG1835: Dependent member lookup before < `_). -- The warning can be disabled via `-Wno-missing-dependent-template-keyword`. -- - C Language Changes - ------------------ - diff -ruN --strip-trailing-cr a/clang/docs/tools/clang-formatted-files.txt b/clang/docs/tools/clang-formatted-files.txt --- a/clang/docs/tools/clang-formatted-files.txt +++ b/clang/docs/tools/clang-formatted-files.txt @@ -33,642 +22,6 @@ diff -ruN --strip-trailing-cr a/clang/docs/tools/clang-formatted-files.txt b/cla clang/lib/Basic/Targets/M68k.h clang/lib/Basic/Targets/MSP430.h clang/lib/Basic/Targets/NVPTX.cpp -diff -ruN --strip-trailing-cr a/clang/include/clang/AST/ExprCXX.h b/clang/include/clang/AST/ExprCXX.h ---- a/clang/include/clang/AST/ExprCXX.h -+++ b/clang/include/clang/AST/ExprCXX.h -@@ -3676,9 +3676,9 @@ - /// an implicit access if a qualifier is provided. - class CXXDependentScopeMemberExpr final - : public Expr, -- private llvm::TrailingObjects< -- CXXDependentScopeMemberExpr, NestedNameSpecifierLoc, DeclAccessPair, -- ASTTemplateKWAndArgsInfo, TemplateArgumentLoc> { -+ private llvm::TrailingObjects { - friend class ASTStmtReader; - friend class ASTStmtWriter; - friend TrailingObjects; -@@ -3691,15 +3691,17 @@ - /// implicit accesses. - QualType BaseType; - -+ /// The nested-name-specifier that precedes the member name, if any. -+ /// FIXME: This could be in principle store as a trailing object. -+ /// However the performance impact of doing so should be investigated first. -+ NestedNameSpecifierLoc QualifierLoc; -+ - /// The member to which this member expression refers, which - /// can be name, overloaded operator, or destructor. - /// - /// FIXME: could also be a template-id - DeclarationNameInfo MemberNameInfo; - -- /// The location of the '->' or '.' operator. -- SourceLocation OperatorLoc; -- - // CXXDependentScopeMemberExpr is followed by several trailing objects, - // some of which optional. They are in order: - // -@@ -3719,16 +3721,8 @@ - return CXXDependentScopeMemberExprBits.HasTemplateKWAndArgsInfo; - } - -- unsigned getNumUnqualifiedLookups() const { -- return CXXDependentScopeMemberExprBits.NumUnqualifiedLookups; -- } -- -- unsigned numTrailingObjects(OverloadToken) const { -- return hasQualifier(); -- } -- -- unsigned numTrailingObjects(OverloadToken) const { -- return getNumUnqualifiedLookups(); -+ bool hasFirstQualifierFoundInScope() const { -+ return CXXDependentScopeMemberExprBits.HasFirstQualifierFoundInScope; - } - - unsigned numTrailingObjects(OverloadToken) const { -@@ -3739,32 +3733,33 @@ - return getNumTemplateArgs(); - } - -+ unsigned numTrailingObjects(OverloadToken) const { -+ return hasFirstQualifierFoundInScope(); -+ } -+ - CXXDependentScopeMemberExpr(const ASTContext &Ctx, Expr *Base, - QualType BaseType, bool IsArrow, - SourceLocation OperatorLoc, - NestedNameSpecifierLoc QualifierLoc, - SourceLocation TemplateKWLoc, -- ArrayRef UnqualifiedLookups, -+ NamedDecl *FirstQualifierFoundInScope, - DeclarationNameInfo MemberNameInfo, - const TemplateArgumentListInfo *TemplateArgs); - -- CXXDependentScopeMemberExpr(EmptyShell Empty, bool HasQualifier, -- unsigned NumUnqualifiedLookups, -- bool HasTemplateKWAndArgsInfo); -+ CXXDependentScopeMemberExpr(EmptyShell Empty, bool HasTemplateKWAndArgsInfo, -+ bool HasFirstQualifierFoundInScope); - - public: - static CXXDependentScopeMemberExpr * - Create(const ASTContext &Ctx, Expr *Base, QualType BaseType, bool IsArrow, - SourceLocation OperatorLoc, NestedNameSpecifierLoc QualifierLoc, -- SourceLocation TemplateKWLoc, -- ArrayRef UnqualifiedLookups, -+ SourceLocation TemplateKWLoc, NamedDecl *FirstQualifierFoundInScope, - DeclarationNameInfo MemberNameInfo, - const TemplateArgumentListInfo *TemplateArgs); - - static CXXDependentScopeMemberExpr * -- CreateEmpty(const ASTContext &Ctx, bool HasQualifier, -- unsigned NumUnqualifiedLookups, bool HasTemplateKWAndArgsInfo, -- unsigned NumTemplateArgs); -+ CreateEmpty(const ASTContext &Ctx, bool HasTemplateKWAndArgsInfo, -+ unsigned NumTemplateArgs, bool HasFirstQualifierFoundInScope); - - /// True if this is an implicit access, i.e. one in which the - /// member being accessed was not written in the source. The source -@@ -3789,35 +3784,34 @@ - bool isArrow() const { return CXXDependentScopeMemberExprBits.IsArrow; } - - /// Retrieve the location of the '->' or '.' operator. -- SourceLocation getOperatorLoc() const { return OperatorLoc; } -- -- /// Determines whether this member expression had a nested-name-specifier -- /// prior to the name of the member, e.g., x->Base::foo. -- bool hasQualifier() const { -- return CXXDependentScopeMemberExprBits.HasQualifier; -- } -- -- /// If the member name was qualified, retrieves the nested-name-specifier -- /// that precedes the member name, with source-location information. -- NestedNameSpecifierLoc getQualifierLoc() const { -- if (!hasQualifier()) -- return NestedNameSpecifierLoc(); -- return *getTrailingObjects(); -+ SourceLocation getOperatorLoc() const { -+ return CXXDependentScopeMemberExprBits.OperatorLoc; - } - -- /// If the member name was qualified, retrieves the -- /// nested-name-specifier that precedes the member name. Otherwise, returns -- /// NULL. -+ /// Retrieve the nested-name-specifier that qualifies the member name. - NestedNameSpecifier *getQualifier() const { -- return getQualifierLoc().getNestedNameSpecifier(); -+ return QualifierLoc.getNestedNameSpecifier(); - } - -- /// Retrieve the declarations found by unqualified lookup for the first -- /// component name of the nested-name-specifier, if any. -- ArrayRef unqualified_lookups() const { -- if (!getNumUnqualifiedLookups()) -- return std::nullopt; -- return {getTrailingObjects(), getNumUnqualifiedLookups()}; -+ /// Retrieve the nested-name-specifier that qualifies the member -+ /// name, with source location information. -+ NestedNameSpecifierLoc getQualifierLoc() const { return QualifierLoc; } -+ -+ /// Retrieve the first part of the nested-name-specifier that was -+ /// found in the scope of the member access expression when the member access -+ /// was initially parsed. -+ /// -+ /// This function only returns a useful result when member access expression -+ /// uses a qualified member name, e.g., "x.Base::f". Here, the declaration -+ /// returned by this function describes what was found by unqualified name -+ /// lookup for the identifier "Base" within the scope of the member access -+ /// expression itself. At template instantiation time, this information is -+ /// combined with the results of name lookup into the type of the object -+ /// expression itself (the class type of x). -+ NamedDecl *getFirstQualifierFoundInScope() const { -+ if (!hasFirstQualifierFoundInScope()) -+ return nullptr; -+ return *getTrailingObjects(); - } - - /// Retrieve the name of the member that this expression refers to. -diff -ruN --strip-trailing-cr a/clang/include/clang/AST/Stmt.h b/clang/include/clang/AST/Stmt.h ---- a/clang/include/clang/AST/Stmt.h -+++ b/clang/include/clang/AST/Stmt.h -@@ -1020,19 +1020,18 @@ - LLVM_PREFERRED_TYPE(bool) - unsigned IsArrow : 1; - -- /// True if this member expression used a nested-name-specifier to -- /// refer to the member, e.g., "x->Base::f". -- LLVM_PREFERRED_TYPE(bool) -- unsigned HasQualifier : 1; -- - /// Whether this member expression has info for explicit template - /// keyword and arguments. - LLVM_PREFERRED_TYPE(bool) - unsigned HasTemplateKWAndArgsInfo : 1; - -- /// Number of declarations found by unqualified lookup for the -- /// first component name of the nested-name-specifier. -- unsigned NumUnqualifiedLookups; -+ /// See getFirstQualifierFoundInScope() and the comment listing -+ /// the trailing objects. -+ LLVM_PREFERRED_TYPE(bool) -+ unsigned HasFirstQualifierFoundInScope : 1; -+ -+ /// The location of the '->' or '.' operator. -+ SourceLocation OperatorLoc; - }; - - class OverloadExprBitfields { -diff -ruN --strip-trailing-cr a/clang/include/clang/AST/UnresolvedSet.h b/clang/include/clang/AST/UnresolvedSet.h ---- a/clang/include/clang/AST/UnresolvedSet.h -+++ b/clang/include/clang/AST/UnresolvedSet.h -@@ -97,10 +97,6 @@ - decls().push_back(DeclAccessPair::make(D, AS)); - } - -- void addAllDecls(ArrayRef Other) { -- append(iterator(Other.begin()), iterator(Other.end())); -- } -- - /// Replaces the given declaration with the new one, once. - /// - /// \return true if the set changed -diff -ruN --strip-trailing-cr a/clang/include/clang/Basic/DiagnosticParseKinds.td b/clang/include/clang/Basic/DiagnosticParseKinds.td ---- a/clang/include/clang/Basic/DiagnosticParseKinds.td -+++ b/clang/include/clang/Basic/DiagnosticParseKinds.td -@@ -895,9 +895,10 @@ - "keyword">, InGroup>, - DefaultError; - --def ext_missing_dependent_template_keyword : ExtWarn< -- "use 'template' keyword to treat '%0' as a dependent template name">, -- InGroup>; -+def err_missing_dependent_template_keyword : Error< -+ "use 'template' keyword to treat '%0' as a dependent template name">; -+def warn_missing_dependent_template_keyword : ExtWarn< -+ "use 'template' keyword to treat '%0' as a dependent template name">; - - def ext_extern_template : Extension< - "extern templates are a C++11 extension">, InGroup; -diff -ruN --strip-trailing-cr a/clang/include/clang/Parse/Parser.h b/clang/include/clang/Parse/Parser.h ---- a/clang/include/clang/Parse/Parser.h -+++ b/clang/include/clang/Parse/Parser.h -@@ -3368,11 +3368,15 @@ - BaseResult ParseBaseSpecifier(Decl *ClassDecl); - AccessSpecifier getAccessSpecifierIfPresent() const; - -- bool ParseUnqualifiedIdTemplateId( -- CXXScopeSpec &SS, ParsedType ObjectType, bool ObjectHadErrors, -- SourceLocation TemplateKWLoc, SourceLocation TildeLoc, -- IdentifierInfo *Name, SourceLocation NameLoc, bool EnteringContext, -- UnqualifiedId &Id, bool AssumeTemplateId); -+ bool ParseUnqualifiedIdTemplateId(CXXScopeSpec &SS, -+ ParsedType ObjectType, -+ bool ObjectHadErrors, -+ SourceLocation TemplateKWLoc, -+ IdentifierInfo *Name, -+ SourceLocation NameLoc, -+ bool EnteringContext, -+ UnqualifiedId &Id, -+ bool AssumeTemplateId); - bool ParseUnqualifiedIdOperator(CXXScopeSpec &SS, bool EnteringContext, - ParsedType ObjectType, - UnqualifiedId &Result); -diff -ruN --strip-trailing-cr a/clang/include/clang/Sema/DeclSpec.h b/clang/include/clang/Sema/DeclSpec.h ---- a/clang/include/clang/Sema/DeclSpec.h -+++ b/clang/include/clang/Sema/DeclSpec.h -@@ -75,7 +75,6 @@ - SourceRange Range; - NestedNameSpecifierLocBuilder Builder; - ArrayRef TemplateParamLists; -- ArrayRef UnqualifiedLookups; - - public: - SourceRange getRange() const { return Range; } -@@ -92,13 +91,6 @@ - return TemplateParamLists; - } - -- void setUnqualifiedLookups(ArrayRef Found) { -- UnqualifiedLookups = Found; -- } -- ArrayRef getUnqualifiedLookups() const { -- return UnqualifiedLookups; -- } -- - /// Retrieve the representation of the nested-name-specifier. - NestedNameSpecifier *getScopeRep() const { - return Builder.getRepresentation(); -diff -ruN --strip-trailing-cr a/clang/include/clang/Sema/Lookup.h b/clang/include/clang/Sema/Lookup.h ---- a/clang/include/clang/Sema/Lookup.h -+++ b/clang/include/clang/Sema/Lookup.h -@@ -483,15 +483,11 @@ - ResultKind = Found; - } - -- void addAllDecls(ArrayRef Other) { -- Decls.addAllDecls(Other); -- ResultKind = Found; -- } -- - /// Add all the declarations from another set of lookup - /// results. - void addAllDecls(const LookupResult &Other) { -- addAllDecls(Other.Decls.pairs()); -+ Decls.append(Other.Decls.begin(), Other.Decls.end()); -+ ResultKind = Found; - } - - /// Determine whether no result was found because we could not -diff -ruN --strip-trailing-cr a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h ---- a/clang/include/clang/Sema/Sema.h -+++ b/clang/include/clang/Sema/Sema.h -@@ -2802,8 +2802,7 @@ - /// (e.g., Base::), perform name lookup for that identifier as a - /// nested-name-specifier within the given scope, and return the result of - /// that name lookup. -- bool LookupFirstQualifierInScope(Scope *S, NestedNameSpecifier *NNS, -- UnresolvedSetImpl &R); -+ NamedDecl *FindFirstQualifierInScope(Scope *S, NestedNameSpecifier *NNS); - - /// Keeps information about an identifier in a nested-name-spec. - /// -@@ -2843,6 +2842,9 @@ - /// \param EnteringContext If true, enter the context specified by the - /// nested-name-specifier. - /// \param SS Optional nested name specifier preceding the identifier. -+ /// \param ScopeLookupResult Provides the result of name lookup within the -+ /// scope of the nested-name-specifier that was computed at template -+ /// definition time. - /// \param ErrorRecoveryLookup Specifies if the method is called to improve - /// error recovery and what kind of recovery is performed. - /// \param IsCorrectedToColon If not null, suggestion of replace '::' -> ':' -@@ -2851,6 +2853,11 @@ - /// not '::'. - /// \param OnlyNamespace If true, only considers namespaces in lookup. - /// -+ /// This routine differs only slightly from ActOnCXXNestedNameSpecifier, in -+ /// that it contains an extra parameter \p ScopeLookupResult, which provides -+ /// the result of name lookup within the scope of the nested-name-specifier -+ /// that was computed at template definition time. -+ /// - /// If ErrorRecoveryLookup is true, then this call is used to improve error - /// recovery. This means that it should not emit diagnostics, it should - /// just return true on failure. It also means it should only return a valid -@@ -2859,6 +2866,7 @@ - /// specifier. - bool BuildCXXNestedNameSpecifier(Scope *S, NestedNameSpecInfo &IdInfo, - bool EnteringContext, CXXScopeSpec &SS, -+ NamedDecl *ScopeLookupResult, - bool ErrorRecoveryLookup, - bool *IsCorrectedToColon = nullptr, - bool OnlyNamespace = false); -@@ -8558,12 +8566,11 @@ - const TemplateArgumentListInfo *TemplateArgs, - bool IsDefiniteInstance, const Scope *S); - -- ExprResult -- ActOnDependentMemberExpr(Expr *Base, QualType BaseType, bool IsArrow, -- SourceLocation OpLoc, const CXXScopeSpec &SS, -- SourceLocation TemplateKWLoc, -- const DeclarationNameInfo &NameInfo, -- const TemplateArgumentListInfo *TemplateArgs); -+ ExprResult ActOnDependentMemberExpr( -+ Expr *Base, QualType BaseType, bool IsArrow, SourceLocation OpLoc, -+ const CXXScopeSpec &SS, SourceLocation TemplateKWLoc, -+ NamedDecl *FirstQualifierInScope, const DeclarationNameInfo &NameInfo, -+ const TemplateArgumentListInfo *TemplateArgs); - - /// The main callback when the parser finds something like - /// expression . [nested-name-specifier] identifier -@@ -8619,14 +8626,15 @@ - ExprResult BuildMemberReferenceExpr( - Expr *Base, QualType BaseType, SourceLocation OpLoc, bool IsArrow, - CXXScopeSpec &SS, SourceLocation TemplateKWLoc, -- const DeclarationNameInfo &NameInfo, -+ NamedDecl *FirstQualifierInScope, const DeclarationNameInfo &NameInfo, - const TemplateArgumentListInfo *TemplateArgs, const Scope *S, - ActOnMemberAccessExtraArgs *ExtraArgs = nullptr); - - ExprResult - BuildMemberReferenceExpr(Expr *Base, QualType BaseType, SourceLocation OpLoc, - bool IsArrow, const CXXScopeSpec &SS, -- SourceLocation TemplateKWLoc, LookupResult &R, -+ SourceLocation TemplateKWLoc, -+ NamedDecl *FirstQualifierInScope, LookupResult &R, - const TemplateArgumentListInfo *TemplateArgs, - const Scope *S, bool SuppressQualifierCheck = false, - ActOnMemberAccessExtraArgs *ExtraArgs = nullptr); -@@ -11114,14 +11122,15 @@ - QualType ObjectType, bool EnteringContext, - RequiredTemplateKind RequiredTemplate = SourceLocation(), - AssumedTemplateKind *ATK = nullptr, -- bool AllowTypoCorrection = true, bool MayBeNNS = false); -+ bool AllowTypoCorrection = true); - -- TemplateNameKind -- isTemplateName(Scope *S, CXXScopeSpec &SS, bool hasTemplateKeyword, -- const UnqualifiedId &Name, ParsedType ObjectType, -- bool EnteringContext, TemplateTy &Template, -- bool &MemberOfUnknownSpecialization, -- bool Disambiguation = false, bool MayBeNNS = false); -+ TemplateNameKind isTemplateName(Scope *S, CXXScopeSpec &SS, -+ bool hasTemplateKeyword, -+ const UnqualifiedId &Name, -+ ParsedType ObjectType, bool EnteringContext, -+ TemplateTy &Template, -+ bool &MemberOfUnknownSpecialization, -+ bool Disambiguation = false); - - /// Try to resolve an undeclared template name as a type template. - /// -@@ -11450,11 +11459,12 @@ - /// For example, given "x.MetaFun::template apply", the scope specifier - /// \p SS will be "MetaFun::", \p TemplateKWLoc contains the location - /// of the "template" keyword, and "apply" is the \p Name. -- TemplateNameKind -- ActOnTemplateName(Scope *S, CXXScopeSpec &SS, SourceLocation TemplateKWLoc, -- const UnqualifiedId &Name, ParsedType ObjectType, -- bool EnteringContext, TemplateTy &Template, -- bool AllowInjectedClassName = false, bool MayBeNNS = false); -+ TemplateNameKind ActOnTemplateName(Scope *S, CXXScopeSpec &SS, -+ SourceLocation TemplateKWLoc, -+ const UnqualifiedId &Name, -+ ParsedType ObjectType, -+ bool EnteringContext, TemplateTy &Template, -+ bool AllowInjectedClassName = false); - - DeclResult ActOnClassTemplateSpecialization( - Scope *S, unsigned TagSpec, TagUseKind TUK, SourceLocation KWLoc, -diff -ruN --strip-trailing-cr a/clang/lib/AST/ASTImporter.cpp b/clang/lib/AST/ASTImporter.cpp ---- a/clang/lib/AST/ASTImporter.cpp -+++ b/clang/lib/AST/ASTImporter.cpp -@@ -8439,14 +8439,8 @@ - auto ToOperatorLoc = importChecked(Err, E->getOperatorLoc()); - auto ToQualifierLoc = importChecked(Err, E->getQualifierLoc()); - auto ToTemplateKeywordLoc = importChecked(Err, E->getTemplateKeywordLoc()); -- -- UnresolvedSet<8> ToUnqualifiedLookups; -- for (auto D : E->unqualified_lookups()) -- if (auto ToDOrErr = import(D.getDecl())) -- ToUnqualifiedLookups.addDecl(*ToDOrErr); -- else -- return ToDOrErr.takeError(); -- -+ auto ToFirstQualifierFoundInScope = -+ importChecked(Err, E->getFirstQualifierFoundInScope()); - if (Err) - return std::move(Err); - -@@ -8480,7 +8474,7 @@ - - return CXXDependentScopeMemberExpr::Create( - Importer.getToContext(), ToBase, ToType, E->isArrow(), ToOperatorLoc, -- ToQualifierLoc, ToTemplateKeywordLoc, ToUnqualifiedLookups.pairs(), -+ ToQualifierLoc, ToTemplateKeywordLoc, ToFirstQualifierFoundInScope, - ToMemberNameInfo, ResInfo); - } - -diff -ruN --strip-trailing-cr a/clang/lib/AST/ExprCXX.cpp b/clang/lib/AST/ExprCXX.cpp ---- a/clang/lib/AST/ExprCXX.cpp -+++ b/clang/lib/AST/ExprCXX.cpp -@@ -1489,27 +1489,19 @@ - CXXDependentScopeMemberExpr::CXXDependentScopeMemberExpr( - const ASTContext &Ctx, Expr *Base, QualType BaseType, bool IsArrow, - SourceLocation OperatorLoc, NestedNameSpecifierLoc QualifierLoc, -- SourceLocation TemplateKWLoc, ArrayRef UnqualifiedLookups, -+ SourceLocation TemplateKWLoc, NamedDecl *FirstQualifierFoundInScope, - DeclarationNameInfo MemberNameInfo, - const TemplateArgumentListInfo *TemplateArgs) - : Expr(CXXDependentScopeMemberExprClass, Ctx.DependentTy, VK_LValue, - OK_Ordinary), -- Base(Base), BaseType(BaseType), MemberNameInfo(MemberNameInfo), -- OperatorLoc(OperatorLoc) { -+ Base(Base), BaseType(BaseType), QualifierLoc(QualifierLoc), -+ MemberNameInfo(MemberNameInfo) { - CXXDependentScopeMemberExprBits.IsArrow = IsArrow; -- CXXDependentScopeMemberExprBits.HasQualifier = QualifierLoc.hasQualifier(); -- CXXDependentScopeMemberExprBits.NumUnqualifiedLookups = -- UnqualifiedLookups.size(); - CXXDependentScopeMemberExprBits.HasTemplateKWAndArgsInfo = - (TemplateArgs != nullptr) || TemplateKWLoc.isValid(); -- -- if (hasQualifier()) -- new (getTrailingObjects()) -- NestedNameSpecifierLoc(QualifierLoc); -- -- std::uninitialized_copy_n(UnqualifiedLookups.data(), -- UnqualifiedLookups.size(), -- getTrailingObjects()); -+ CXXDependentScopeMemberExprBits.HasFirstQualifierFoundInScope = -+ FirstQualifierFoundInScope != nullptr; -+ CXXDependentScopeMemberExprBits.OperatorLoc = OperatorLoc; - - if (TemplateArgs) { - auto Deps = TemplateArgumentDependence::None; -@@ -1521,59 +1513,54 @@ - TemplateKWLoc); - } - -+ if (hasFirstQualifierFoundInScope()) -+ *getTrailingObjects() = FirstQualifierFoundInScope; - setDependence(computeDependence(this)); - } - - CXXDependentScopeMemberExpr::CXXDependentScopeMemberExpr( -- EmptyShell Empty, bool HasQualifier, unsigned NumUnqualifiedLookups, -- bool HasTemplateKWAndArgsInfo) -+ EmptyShell Empty, bool HasTemplateKWAndArgsInfo, -+ bool HasFirstQualifierFoundInScope) - : Expr(CXXDependentScopeMemberExprClass, Empty) { -- CXXDependentScopeMemberExprBits.HasQualifier = HasQualifier; -- CXXDependentScopeMemberExprBits.NumUnqualifiedLookups = NumUnqualifiedLookups; - CXXDependentScopeMemberExprBits.HasTemplateKWAndArgsInfo = - HasTemplateKWAndArgsInfo; -+ CXXDependentScopeMemberExprBits.HasFirstQualifierFoundInScope = -+ HasFirstQualifierFoundInScope; - } - - CXXDependentScopeMemberExpr *CXXDependentScopeMemberExpr::Create( - const ASTContext &Ctx, Expr *Base, QualType BaseType, bool IsArrow, - SourceLocation OperatorLoc, NestedNameSpecifierLoc QualifierLoc, -- SourceLocation TemplateKWLoc, ArrayRef UnqualifiedLookups, -+ SourceLocation TemplateKWLoc, NamedDecl *FirstQualifierFoundInScope, - DeclarationNameInfo MemberNameInfo, - const TemplateArgumentListInfo *TemplateArgs) { -- bool HasQualifier = QualifierLoc.hasQualifier(); -- unsigned NumUnqualifiedLookups = UnqualifiedLookups.size(); -- assert(!NumUnqualifiedLookups || HasQualifier); - bool HasTemplateKWAndArgsInfo = - (TemplateArgs != nullptr) || TemplateKWLoc.isValid(); - unsigned NumTemplateArgs = TemplateArgs ? TemplateArgs->size() : 0; -- unsigned Size = -- totalSizeToAlloc( -- HasQualifier, NumUnqualifiedLookups, HasTemplateKWAndArgsInfo, -- NumTemplateArgs); -+ bool HasFirstQualifierFoundInScope = FirstQualifierFoundInScope != nullptr; -+ -+ unsigned Size = totalSizeToAlloc( -+ HasTemplateKWAndArgsInfo, NumTemplateArgs, HasFirstQualifierFoundInScope); - - void *Mem = Ctx.Allocate(Size, alignof(CXXDependentScopeMemberExpr)); - return new (Mem) CXXDependentScopeMemberExpr( - Ctx, Base, BaseType, IsArrow, OperatorLoc, QualifierLoc, TemplateKWLoc, -- UnqualifiedLookups, MemberNameInfo, TemplateArgs); -+ FirstQualifierFoundInScope, MemberNameInfo, TemplateArgs); - } - - CXXDependentScopeMemberExpr *CXXDependentScopeMemberExpr::CreateEmpty( -- const ASTContext &Ctx, bool HasQualifier, unsigned NumUnqualifiedLookups, -- bool HasTemplateKWAndArgsInfo, unsigned NumTemplateArgs) { -- assert(!NumTemplateArgs || HasTemplateKWAndArgsInfo); -- assert(!NumUnqualifiedLookups || HasQualifier); -- -- unsigned Size = -- totalSizeToAlloc( -- HasQualifier, NumUnqualifiedLookups, HasTemplateKWAndArgsInfo, -- NumTemplateArgs); -+ const ASTContext &Ctx, bool HasTemplateKWAndArgsInfo, -+ unsigned NumTemplateArgs, bool HasFirstQualifierFoundInScope) { -+ assert(NumTemplateArgs == 0 || HasTemplateKWAndArgsInfo); -+ -+ unsigned Size = totalSizeToAlloc( -+ HasTemplateKWAndArgsInfo, NumTemplateArgs, HasFirstQualifierFoundInScope); - - void *Mem = Ctx.Allocate(Size, alignof(CXXDependentScopeMemberExpr)); -- return new (Mem) CXXDependentScopeMemberExpr(EmptyShell(), HasQualifier, -- NumUnqualifiedLookups, -- HasTemplateKWAndArgsInfo); -+ return new (Mem) CXXDependentScopeMemberExpr( -+ EmptyShell(), HasTemplateKWAndArgsInfo, HasFirstQualifierFoundInScope); - } - - CXXThisExpr *CXXThisExpr::Create(const ASTContext &Ctx, SourceLocation L, -diff -ruN --strip-trailing-cr a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp ---- a/clang/lib/AST/ItaniumMangle.cpp -+++ b/clang/lib/AST/ItaniumMangle.cpp -@@ -594,10 +594,11 @@ - void mangleMemberExprBase(const Expr *base, bool isArrow); - void mangleMemberExpr(const Expr *base, bool isArrow, - NestedNameSpecifier *qualifier, -- ArrayRef UnqualifiedLookups, -+ NamedDecl *firstQualifierLookup, - DeclarationName name, - const TemplateArgumentLoc *TemplateArgs, -- unsigned NumTemplateArgs, unsigned knownArity); -+ unsigned NumTemplateArgs, -+ unsigned knownArity); - void mangleCastExpression(const Expr *E, StringRef CastEncoding); - void mangleInitListElements(const InitListExpr *InitList); - void mangleRequirement(SourceLocation RequiresExprLoc, -@@ -4495,11 +4496,14 @@ - } - - /// Mangles a member expression. --void CXXNameMangler::mangleMemberExpr( -- const Expr *base, bool isArrow, NestedNameSpecifier *qualifier, -- ArrayRef UnqualifiedLookups, DeclarationName member, -- const TemplateArgumentLoc *TemplateArgs, unsigned NumTemplateArgs, -- unsigned arity) { -+void CXXNameMangler::mangleMemberExpr(const Expr *base, -+ bool isArrow, -+ NestedNameSpecifier *qualifier, -+ NamedDecl *firstQualifierLookup, -+ DeclarationName member, -+ const TemplateArgumentLoc *TemplateArgs, -+ unsigned NumTemplateArgs, -+ unsigned arity) { - // ::= dt - // ::= pt - if (base) -@@ -4981,9 +4985,11 @@ - case Expr::MemberExprClass: { - NotPrimaryExpr(); - const MemberExpr *ME = cast(E); -- mangleMemberExpr(ME->getBase(), ME->isArrow(), ME->getQualifier(), -- std::nullopt, ME->getMemberDecl()->getDeclName(), -- ME->getTemplateArgs(), ME->getNumTemplateArgs(), Arity); -+ mangleMemberExpr(ME->getBase(), ME->isArrow(), -+ ME->getQualifier(), nullptr, -+ ME->getMemberDecl()->getDeclName(), -+ ME->getTemplateArgs(), ME->getNumTemplateArgs(), -+ Arity); - break; - } - -@@ -4991,9 +4997,10 @@ - NotPrimaryExpr(); - const UnresolvedMemberExpr *ME = cast(E); - mangleMemberExpr(ME->isImplicitAccess() ? nullptr : ME->getBase(), -- ME->isArrow(), ME->getQualifier(), std::nullopt, -- ME->getMemberName(), ME->getTemplateArgs(), -- ME->getNumTemplateArgs(), Arity); -+ ME->isArrow(), ME->getQualifier(), nullptr, -+ ME->getMemberName(), -+ ME->getTemplateArgs(), ME->getNumTemplateArgs(), -+ Arity); - break; - } - -@@ -5003,8 +5010,10 @@ - = cast(E); - mangleMemberExpr(ME->isImplicitAccess() ? nullptr : ME->getBase(), - ME->isArrow(), ME->getQualifier(), -- ME->unqualified_lookups(), ME->getMember(), -- ME->getTemplateArgs(), ME->getNumTemplateArgs(), Arity); -+ ME->getFirstQualifierFoundInScope(), -+ ME->getMember(), -+ ME->getTemplateArgs(), ME->getNumTemplateArgs(), -+ Arity); - break; - } - diff -ruN --strip-trailing-cr a/clang/lib/Basic/CMakeLists.txt b/clang/lib/Basic/CMakeLists.txt --- a/clang/lib/Basic/CMakeLists.txt +++ b/clang/lib/Basic/CMakeLists.txt @@ -779,2042 +132,96 @@ diff -ruN --strip-trailing-cr a/clang/lib/Basic/Targets/Le64.h b/clang/lib/Basic + bool hasProtectedVisibility() const override { return false; } +}; + -+} // namespace targets -+} // namespace clang -+#endif // LLVM_CLANG_LIB_BASIC_TARGETS_LE64_H -diff -ruN --strip-trailing-cr a/clang/lib/Basic/Targets/OSTargets.h b/clang/lib/Basic/Targets/OSTargets.h ---- a/clang/lib/Basic/Targets/OSTargets.h -+++ b/clang/lib/Basic/Targets/OSTargets.h -@@ -841,6 +841,9 @@ - "i64:64-i128:128-n8:16:32:64-S128"); - } else if (Triple.getArch() == llvm::Triple::mipsel) { - // Handled on mips' setDataLayout. -+ } else { -+ assert(Triple.getArch() == llvm::Triple::le32); -+ this->resetDataLayout("e-p:32:32-i64:64"); - } - } - }; -diff -ruN --strip-trailing-cr a/clang/lib/Basic/Targets.cpp b/clang/lib/Basic/Targets.cpp ---- a/clang/lib/Basic/Targets.cpp -+++ b/clang/lib/Basic/Targets.cpp -@@ -23,6 +23,7 @@ - #include "Targets/DirectX.h" - #include "Targets/Hexagon.h" - #include "Targets/Lanai.h" -+#include "Targets/Le64.h" - #include "Targets/LoongArch.h" - #include "Targets/M68k.h" - #include "Targets/MSP430.h" -@@ -343,6 +344,17 @@ - return std::make_unique(Triple, Opts); - } - -+ case llvm::Triple::le32: -+ switch (os) { -+ case llvm::Triple::NaCl: -+ return std::make_unique>(Triple, Opts); -+ default: -+ return nullptr; -+ } -+ -+ case llvm::Triple::le64: -+ return std::make_unique(Triple, Opts); -+ - case llvm::Triple::ppc: - switch (os) { - case llvm::Triple::Linux: -diff -ruN --strip-trailing-cr a/clang/lib/CodeGen/CodeGenModule.cpp b/clang/lib/CodeGen/CodeGenModule.cpp ---- a/clang/lib/CodeGen/CodeGenModule.cpp -+++ b/clang/lib/CodeGen/CodeGenModule.cpp -@@ -116,6 +116,8 @@ - default: - return createDefaultTargetCodeGenInfo(CGM); - -+ case llvm::Triple::le32: -+ return createPNaClTargetCodeGenInfo(CGM); - case llvm::Triple::m68k: - return createM68kTargetCodeGenInfo(CGM); - case llvm::Triple::mips: -diff -ruN --strip-trailing-cr a/clang/lib/CodeGen/ItaniumCXXABI.cpp b/clang/lib/CodeGen/ItaniumCXXABI.cpp ---- a/clang/lib/CodeGen/ItaniumCXXABI.cpp -+++ b/clang/lib/CodeGen/ItaniumCXXABI.cpp -@@ -576,6 +576,13 @@ - return new XLCXXABI(CGM); - - case TargetCXXABI::GenericItanium: -+ if (CGM.getContext().getTargetInfo().getTriple().getArch() -+ == llvm::Triple::le32) { -+ // For PNaCl, use ARM-style method pointers so that PNaCl code -+ // does not assume anything about the alignment of function -+ // pointers. -+ return new ItaniumCXXABI(CGM, /*UseARMMethodPtrABI=*/true); -+ } - return new ItaniumCXXABI(CGM); - - case TargetCXXABI::Microsoft: -diff -ruN --strip-trailing-cr a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp ---- a/clang/lib/Driver/ToolChains/Clang.cpp -+++ b/clang/lib/Driver/ToolChains/Clang.cpp -@@ -3815,6 +3815,12 @@ - if (UseBuiltins) - A->render(Args, CmdArgs); - } -+ -+ // le32-specific flags: -+ // -fno-math-builtin: clang should not convert math builtins to intrinsics -+ // by default. -+ if (TC.getArch() == llvm::Triple::le32) -+ CmdArgs.push_back("-fno-math-builtin"); - } - - bool Driver::getDefaultModuleCachePath(SmallVectorImpl &Result) { -diff -ruN --strip-trailing-cr a/clang/lib/Parse/ParseExpr.cpp b/clang/lib/Parse/ParseExpr.cpp ---- a/clang/lib/Parse/ParseExpr.cpp -+++ b/clang/lib/Parse/ParseExpr.cpp -@@ -2343,9 +2343,10 @@ - } - - if (!LHS.isInvalid()) -- LHS = Actions.ActOnMemberAccessExpr( -- getCurScope(), LHS.get(), OpLoc, OpKind, SS, TemplateKWLoc, Name, -- CurParsedObjCImpl ? CurParsedObjCImpl->Dcl : nullptr); -+ LHS = Actions.ActOnMemberAccessExpr(getCurScope(), LHS.get(), OpLoc, -+ OpKind, SS, TemplateKWLoc, Name, -+ CurParsedObjCImpl ? CurParsedObjCImpl->Dcl -+ : nullptr); - if (!LHS.isInvalid()) { - if (Tok.is(tok::less)) - checkPotentialAngleBracket(LHS); -diff -ruN --strip-trailing-cr a/clang/lib/Parse/ParseExprCXX.cpp b/clang/lib/Parse/ParseExprCXX.cpp ---- a/clang/lib/Parse/ParseExprCXX.cpp -+++ b/clang/lib/Parse/ParseExprCXX.cpp -@@ -100,8 +100,7 @@ - bool MemberOfUnknownSpecialization; - if (!Actions.isTemplateName(getCurScope(), SS, /*hasTemplateKeyword=*/false, - TemplateName, ObjectType, EnteringContext, -- Template, MemberOfUnknownSpecialization, -- /*Disambiguation=*/false, /*MayBeNNS=*/true)) -+ Template, MemberOfUnknownSpecialization)) - return; - - FixDigraph(*this, PP, Next, SecondToken, tok::unknown, -@@ -354,8 +353,7 @@ - TemplateTy Template; - TemplateNameKind TNK = Actions.ActOnTemplateName( - getCurScope(), SS, TemplateKWLoc, TemplateName, ObjectType, -- EnteringContext, Template, /*AllowInjectedClassName*/ true, -- /*MayBeNNS=*/true); -+ EnteringContext, Template, /*AllowInjectedClassName*/ true); - if (AnnotateTemplateIdToken(Template, TNK, SS, TemplateKWLoc, - TemplateName, false)) - return true; -@@ -407,6 +405,7 @@ - : TemplateId->TemplateNameLoc; - SS.SetInvalid(SourceRange(StartLoc, CCLoc)); - } -+ - continue; - } - -@@ -529,19 +528,18 @@ - UnqualifiedId TemplateName; - TemplateName.setIdentifier(&II, Tok.getLocation()); - bool MemberOfUnknownSpecialization; -- if (TemplateNameKind TNK = Actions.isTemplateName( -- getCurScope(), SS, -- /*hasTemplateKeyword=*/false, TemplateName, ObjectType, -- EnteringContext, Template, MemberOfUnknownSpecialization, -- /*Disambiguation=*/false, -- /*MayBeNNS=*/true)) { -+ if (TemplateNameKind TNK = Actions.isTemplateName(getCurScope(), SS, -+ /*hasTemplateKeyword=*/false, -+ TemplateName, -+ ObjectType, -+ EnteringContext, -+ Template, -+ MemberOfUnknownSpecialization)) { - // If lookup didn't find anything, we treat the name as a template-name - // anyway. C++20 requires this, and in prior language modes it improves - // error recovery. But before we commit to this, check that we actually - // have something that looks like a template-argument-list next. -- if (!IsTypename && -- (TNK == TNK_Undeclared_template || -- (!HasScopeSpecifier && ObjectType)) && -+ if (!IsTypename && TNK == TNK_Undeclared_template && - isTemplateArgumentList(1) == TPResult::False) - break; - -@@ -568,7 +566,11 @@ - // member of an unknown specialization. However, this will only - // parse correctly as a template, so suggest the keyword 'template' - // before 'getAs' and treat this as a dependent template name. -- Diag(Tok.getLocation(), diag::ext_missing_dependent_template_keyword) -+ unsigned DiagID = diag::err_missing_dependent_template_keyword; -+ if (getLangOpts().MicrosoftExt) -+ DiagID = diag::warn_missing_dependent_template_keyword; -+ -+ Diag(Tok.getLocation(), DiagID) - << II.getName() - << FixItHint::CreateInsertion(Tok.getLocation(), "template "); - } -@@ -1918,12 +1920,12 @@ - // argument list. This affects examples such as - // void f(auto *p) { p->~X(); } - // ... but there's no ambiguity, and nowhere to write 'template' in such an -- // example, so we accept it anyway -- if (Tok.is(tok::less) && ParseUnqualifiedIdTemplateId( -- SS, ObjectType, Base && Base->containsErrors(), -- /*TemplateKWLoc=*/SourceLocation(), TildeLoc, -- Name, NameLoc, false, SecondTypeName, -- /*AssumeTemplateId=*/true)) -+ // example, so we accept it anyway. -+ if (Tok.is(tok::less) && -+ ParseUnqualifiedIdTemplateId( -+ SS, ObjectType, Base && Base->containsErrors(), SourceLocation(), -+ Name, NameLoc, false, SecondTypeName, -+ /*AssumeTemplateId=*/true)) - return ExprError(); - - return Actions.ActOnPseudoDestructorExpr(getCurScope(), Base, OpLoc, OpKind, -@@ -2530,9 +2532,8 @@ - /// \returns true if a parse error occurred, false otherwise. - bool Parser::ParseUnqualifiedIdTemplateId( - CXXScopeSpec &SS, ParsedType ObjectType, bool ObjectHadErrors, -- SourceLocation TemplateKWLoc, SourceLocation TildeLoc, IdentifierInfo *Name, -- SourceLocation NameLoc, bool EnteringContext, UnqualifiedId &Id, -- bool AssumeTemplateId) { -+ SourceLocation TemplateKWLoc, IdentifierInfo *Name, SourceLocation NameLoc, -+ bool EnteringContext, UnqualifiedId &Id, bool AssumeTemplateId) { - assert(Tok.is(tok::less) && "Expected '<' to finish parsing a template-id"); - - TemplateTy Template; -@@ -2546,14 +2547,13 @@ - // this template-id is used to form a nested-name-specifier or not. - TNK = Actions.ActOnTemplateName(getCurScope(), SS, TemplateKWLoc, Id, - ObjectType, EnteringContext, Template, -- /*AllowInjectedClassName=*/true, -- TildeLoc.isValid()); -+ /*AllowInjectedClassName*/ true); - } else { - bool MemberOfUnknownSpecialization; -- TNK = Actions.isTemplateName( -- getCurScope(), SS, TemplateKWLoc.isValid(), Id, ObjectType, -- EnteringContext, Template, MemberOfUnknownSpecialization, -- /*Disambiguation=*/false, TildeLoc.isValid()); -+ TNK = Actions.isTemplateName(getCurScope(), SS, -+ TemplateKWLoc.isValid(), Id, -+ ObjectType, EnteringContext, Template, -+ MemberOfUnknownSpecialization); - // If lookup found nothing but we're assuming that this is a template - // name, double-check that makes sense syntactically before committing - // to it. -@@ -2580,13 +2580,13 @@ - else - Name += Id.Identifier->getName(); - } -- Diag(Id.StartLocation, diag::ext_missing_dependent_template_keyword) -+ Diag(Id.StartLocation, diag::err_missing_dependent_template_keyword) - << Name - << FixItHint::CreateInsertion(Id.StartLocation, "template "); - } - TNK = Actions.ActOnTemplateName( - getCurScope(), SS, TemplateKWLoc, Id, ObjectType, EnteringContext, -- Template, /*AllowInjectedClassName=*/true, TildeLoc.isValid()); -+ Template, /*AllowInjectedClassName*/ true); - } else if (TNK == TNK_Non_template) { - return false; - } -@@ -2611,16 +2611,14 @@ - bool MemberOfUnknownSpecialization; - TemplateName.setIdentifier(Name, NameLoc); - if (ObjectType) { -- TNK = Actions.ActOnTemplateName(getCurScope(), SS, TemplateKWLoc, -- TemplateName, ObjectType, EnteringContext, -- Template, /*AllowInjectedClassName=*/true, -- /*MayBeNNS=*/true); -+ TNK = Actions.ActOnTemplateName( -+ getCurScope(), SS, TemplateKWLoc, TemplateName, ObjectType, -+ EnteringContext, Template, /*AllowInjectedClassName*/ true); - } else { - TNK = Actions.isTemplateName(getCurScope(), SS, TemplateKWLoc.isValid(), -- TemplateName, ObjectType, EnteringContext, -- Template, MemberOfUnknownSpecialization, -- /*Disambiguation=*/false, -- /*MayBeNNS=*/true); -+ TemplateName, ObjectType, -+ EnteringContext, Template, -+ MemberOfUnknownSpecialization); - - if (TNK == TNK_Non_template && !Id.DestructorName.get()) { - Diag(NameLoc, diag::err_destructor_template_id) -@@ -2682,7 +2680,7 @@ - if (Id.getKind() == UnqualifiedIdKind::IK_ConstructorName) - Id.setConstructorName(Type.get(), NameLoc, RAngleLoc); - else -- Id.setDestructorName(TildeLoc, Type.get(), RAngleLoc); -+ Id.setDestructorName(Id.StartLocation, Type.get(), RAngleLoc); - - return false; - } -@@ -3030,9 +3028,8 @@ - if (Tok.is(tok::less)) - return ParseUnqualifiedIdTemplateId( - SS, ObjectType, ObjectHadErrors, -- TemplateKWLoc ? *TemplateKWLoc : SourceLocation(), -- /*TildeLoc=*/SourceLocation(), Id, IdLoc, EnteringContext, Result, -- TemplateSpecified); -+ TemplateKWLoc ? *TemplateKWLoc : SourceLocation(), Id, IdLoc, -+ EnteringContext, Result, TemplateSpecified); - - if (TemplateSpecified) { - TemplateNameKind TNK = -@@ -3127,15 +3124,13 @@ - Tok.is(tok::less)) - return ParseUnqualifiedIdTemplateId( - SS, ObjectType, ObjectHadErrors, -- TemplateKWLoc ? *TemplateKWLoc : SourceLocation(), -- /*TildeLoc=*/SourceLocation(), /*Name=*/nullptr, -- /*NameLoc=*/SourceLocation(), EnteringContext, Result, -- TemplateSpecified); -+ TemplateKWLoc ? *TemplateKWLoc : SourceLocation(), nullptr, -+ SourceLocation(), EnteringContext, Result, TemplateSpecified); - else if (TemplateSpecified && - Actions.ActOnTemplateName( - getCurScope(), SS, *TemplateKWLoc, Result, ObjectType, - EnteringContext, Template, -- /*AllowInjectedClassName=*/true) == TNK_Non_template) -+ /*AllowInjectedClassName*/ true) == TNK_Non_template) - return true; - - return false; -@@ -3225,8 +3220,8 @@ - Result.setDestructorName(TildeLoc, nullptr, ClassNameLoc); - return ParseUnqualifiedIdTemplateId( - SS, ObjectType, ObjectHadErrors, -- TemplateKWLoc ? *TemplateKWLoc : SourceLocation(), TildeLoc, -- ClassName, ClassNameLoc, EnteringContext, Result, TemplateSpecified); -+ TemplateKWLoc ? *TemplateKWLoc : SourceLocation(), ClassName, -+ ClassNameLoc, EnteringContext, Result, TemplateSpecified); - } - - // Note that this is a destructor name. -diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp ---- a/clang/lib/Sema/SemaCoroutine.cpp -+++ b/clang/lib/Sema/SemaCoroutine.cpp -@@ -306,8 +306,8 @@ - // FIXME: Fix BuildMemberReferenceExpr to take a const CXXScopeSpec&. - CXXScopeSpec SS; - ExprResult Result = S.BuildMemberReferenceExpr( -- Base, Base->getType(), Loc, /*IsPtr=*/false, SS, SourceLocation(), -- NameInfo, /*TemplateArgs=*/nullptr, -+ Base, Base->getType(), Loc, /*IsPtr=*/false, SS, -+ SourceLocation(), nullptr, NameInfo, /*TemplateArgs=*/nullptr, - /*Scope=*/nullptr); - if (Result.isInvalid()) - return ExprError(); -diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaCXXScopeSpec.cpp b/clang/lib/Sema/SemaCXXScopeSpec.cpp ---- a/clang/lib/Sema/SemaCXXScopeSpec.cpp -+++ b/clang/lib/Sema/SemaCXXScopeSpec.cpp -@@ -356,41 +356,29 @@ - return false; - } - --/// If the given nested-name-specifier begins with a bare identifier --/// (e.g., Base::), perform name lookup for that identifier as a --/// nested-name-specifier within the given scope, and return the result of that --/// name lookup. --bool Sema::LookupFirstQualifierInScope(Scope *S, NestedNameSpecifier *NNS, -- UnresolvedSetImpl &R) { -- if (!S) -- return false; -+NamedDecl *Sema::FindFirstQualifierInScope(Scope *S, NestedNameSpecifier *NNS) { -+ if (!S || !NNS) -+ return nullptr; - - while (NNS->getPrefix()) - NNS = NNS->getPrefix(); - -- // FIXME: This is a rather nasty hack! Ideally we should get the results -- // from LookupTemplateName/BuildCXXNestedNameSpecifier. -- const IdentifierInfo *II = NNS->getAsIdentifier(); -- if (!II) { -- if (const auto *DTST = -- dyn_cast_if_present( -- NNS->getAsType())) -- II = DTST->getIdentifier(); -- else -- return false; -- } -- assert(II && "Missing first qualifier in scope"); -- LookupResult Found(*this, II, SourceLocation(), -- NNS->getAsIdentifier() ? LookupNestedNameSpecifierName -- : LookupOrdinaryName); -+ if (NNS->getKind() != NestedNameSpecifier::Identifier) -+ return nullptr; -+ -+ LookupResult Found(*this, NNS->getAsIdentifier(), SourceLocation(), -+ LookupNestedNameSpecifierName); - LookupName(Found, S); -+ assert(!Found.isAmbiguous() && "Cannot handle ambiguities here yet"); - -- if (Found.empty()) -- return false; -+ if (!Found.isSingleResult()) -+ return nullptr; - -- R.addAllDecls(Found.asUnresolvedSet().pairs()); -- Found.suppressDiagnostics(); -- return true; -+ NamedDecl *Result = Found.getFoundDecl(); -+ if (isAcceptableNestedNameSpecifier(Result)) -+ return Result; -+ -+ return nullptr; - } - - namespace { -@@ -419,82 +407,112 @@ - - bool Sema::BuildCXXNestedNameSpecifier(Scope *S, NestedNameSpecInfo &IdInfo, - bool EnteringContext, CXXScopeSpec &SS, -+ NamedDecl *ScopeLookupResult, - bool ErrorRecoveryLookup, - bool *IsCorrectedToColon, - bool OnlyNamespace) { - if (IdInfo.Identifier->isEditorPlaceholder()) - return true; -- if (IsCorrectedToColon) -- *IsCorrectedToColon = false; -- -- QualType ObjectType = GetTypeFromParser(IdInfo.ObjectType); - LookupResult Found(*this, IdInfo.Identifier, IdInfo.IdentifierLoc, - OnlyNamespace ? LookupNamespaceName - : LookupNestedNameSpecifierName); -+ QualType ObjectType = GetTypeFromParser(IdInfo.ObjectType); - -- // C++ [basic.lookup.qual.general]p3: -- // Qualified name lookup in a class, namespace, or enumeration performs a -- // search of the scope associated with it except as specified below. -- LookupParsedName(Found, S, &SS, ObjectType, -- /*AllowBuiltinCreation=*/false, EnteringContext); -- -- // C++ [basic.lookup.qual.general]p3: -- // [...] Unless otherwise specified, a qualified name undergoes qualified -- // name lookup in its lookup context from the point where it appears unless -- // the lookup context either is dependent and is not the current -- // instantiation or is not a class or class template. -- if (Found.wasNotFoundInCurrentInstantiation()) { -- // Don't speculate if we're just trying to improve error recovery. -- if (ErrorRecoveryLookup) -- return true; -- -- // The lookup context is dependent and either: -- // - it is not the current instantiation, or -- // - it is the current instantiation, it has at least one dependent base -- // class, and qualified lookup found nothing. -- // Build a dependent nested-name-specifier. We will lookup the name again -- // during instantiation. -- SS.Extend(Context, IdInfo.Identifier, IdInfo.IdentifierLoc, IdInfo.CCLoc); -- return false; -+ // Determine where to perform name lookup -+ DeclContext *LookupCtx = nullptr; -+ bool isDependent = false; -+ if (IsCorrectedToColon) -+ *IsCorrectedToColon = false; -+ if (!ObjectType.isNull()) { -+ // This nested-name-specifier occurs in a member access expression, e.g., -+ // x->B::f, and we are looking into the type of the object. -+ assert(!SS.isSet() && "ObjectType and scope specifier cannot coexist"); -+ LookupCtx = computeDeclContext(ObjectType); -+ isDependent = ObjectType->isDependentType(); -+ } else if (SS.isSet()) { -+ // This nested-name-specifier occurs after another nested-name-specifier, -+ // so look into the context associated with the prior nested-name-specifier. -+ LookupCtx = computeDeclContext(SS, EnteringContext); -+ isDependent = isDependentScopeSpecifier(SS); -+ Found.setContextRange(SS.getRange()); - } - - bool ObjectTypeSearchedInScope = false; -+ if (LookupCtx) { -+ // Perform "qualified" name lookup into the declaration context we -+ // computed, which is either the type of the base of a member access -+ // expression or the declaration context associated with a prior -+ // nested-name-specifier. -+ -+ // The declaration context must be complete. -+ if (!LookupCtx->isDependentContext() && -+ RequireCompleteDeclContext(SS, LookupCtx)) -+ return true; - -- // C++ [basic.lookup.qual.general]p2: -- // A member-qualified name is the (unique) component name, if any, of -- // - an unqualified-id or -- // - a nested-name-specifier of the form type-name :: or namespace-name :: -- // in the id-expression of a class member access expression. -- // -- // C++ [basic.lookup.qual.general]p3: -- // [...] If nothing is found by qualified lookup for a member-qualified -- // name that is the terminal name of a nested-name-specifier and is not -- // dependent, it undergoes unqualified lookup. -- // -- // In 'x.A::B::y', 'A' will undergo unqualified lookup if qualified lookup -- // in the type of 'x' finds nothing. If the lookup context is dependent, -- // we perform the unqualified lookup in the template definition context -- // and store the results so we can replicate the lookup during instantiation. -- if (Found.empty() && !ObjectType.isNull()) { -- if (S) { -- LookupName(Found, S); -- } else if (!SS.getUnqualifiedLookups().empty()) { -- Found.addAllDecls(SS.getUnqualifiedLookups()); -- Found.resolveKind(); -+ LookupQualifiedName(Found, LookupCtx); -+ -+ if (!ObjectType.isNull() && Found.empty()) { -+ // C++ [basic.lookup.classref]p4: -+ // If the id-expression in a class member access is a qualified-id of -+ // the form -+ // -+ // class-name-or-namespace-name::... -+ // -+ // the class-name-or-namespace-name following the . or -> operator is -+ // looked up both in the context of the entire postfix-expression and in -+ // the scope of the class of the object expression. If the name is found -+ // only in the scope of the class of the object expression, the name -+ // shall refer to a class-name. If the name is found only in the -+ // context of the entire postfix-expression, the name shall refer to a -+ // class-name or namespace-name. [...] -+ // -+ // Qualified name lookup into a class will not find a namespace-name, -+ // so we do not need to diagnose that case specifically. However, -+ // this qualified name lookup may find nothing. In that case, perform -+ // unqualified name lookup in the given scope (if available) or -+ // reconstruct the result from when name lookup was performed at template -+ // definition time. -+ if (S) -+ LookupName(Found, S); -+ else if (ScopeLookupResult) -+ Found.addDecl(ScopeLookupResult); -+ -+ ObjectTypeSearchedInScope = true; - } -- ObjectTypeSearchedInScope = true; -+ } else if (!isDependent) { -+ // Perform unqualified name lookup in the current scope. -+ LookupName(Found, S); - } - - if (Found.isAmbiguous()) - return true; - -+ // If we performed lookup into a dependent context and did not find anything, -+ // that's fine: just build a dependent nested-name-specifier. -+ if (Found.empty() && isDependent && -+ !(LookupCtx && LookupCtx->isRecord() && -+ (!cast(LookupCtx)->hasDefinition() || -+ !cast(LookupCtx)->hasAnyDependentBases()))) { -+ // Don't speculate if we're just trying to improve error recovery. -+ if (ErrorRecoveryLookup) -+ return true; -+ -+ // We were not able to compute the declaration context for a dependent -+ // base object type or prior nested-name-specifier, so this -+ // nested-name-specifier refers to an unknown specialization. Just build -+ // a dependent nested-name-specifier. -+ SS.Extend(Context, IdInfo.Identifier, IdInfo.IdentifierLoc, IdInfo.CCLoc); -+ return false; -+ } -+ - if (Found.empty() && !ErrorRecoveryLookup) { - // If identifier is not found as class-name-or-namespace-name, but is found - // as other entity, don't look for typos. - LookupResult R(*this, Found.getLookupNameInfo(), LookupOrdinaryName); -- LookupParsedName(R, S, &SS, ObjectType, -- /*AllowBuiltinCreation=*/false, EnteringContext); -- -+ if (LookupCtx) -+ LookupQualifiedName(R, LookupCtx); -+ else if (S && !isDependent) -+ LookupName(R, S); - if (!R.empty()) { - // Don't diagnose problems with this speculative lookup. - R.suppressDiagnostics(); -@@ -521,11 +539,6 @@ - } - } - -- DeclContext *LookupCtx = -- SS.isSet() -- ? computeDeclContext(SS, EnteringContext) -- : (!ObjectType.isNull() ? computeDeclContext(ObjectType) : nullptr); -- - if (Found.empty() && !ErrorRecoveryLookup && !getLangOpts().MSVCCompat) { - // We haven't found anything, and we're not recovering from a - // different kind of error, so look for typos. -@@ -581,14 +594,14 @@ - // scope, reconstruct the result from the template instantiation itself. - // - // Note that C++11 does *not* perform this redundant lookup. -- NamedDecl *OuterDecl = nullptr; -+ NamedDecl *OuterDecl; - if (S) { - LookupResult FoundOuter(*this, IdInfo.Identifier, IdInfo.IdentifierLoc, - LookupNestedNameSpecifierName); - LookupName(FoundOuter, S); - OuterDecl = FoundOuter.getAsSingle(); -- } else if (!SS.getUnqualifiedLookups().empty()) -- OuterDecl = SS.getUnqualifiedLookups().front().getDecl(); -+ } else -+ OuterDecl = ScopeLookupResult; - - if (isAcceptableNestedNameSpecifier(OuterDecl) && - OuterDecl->getCanonicalDecl() != SD->getCanonicalDecl() && -@@ -766,7 +779,7 @@ - return true; - - return BuildCXXNestedNameSpecifier(S, IdInfo, EnteringContext, SS, -- /*ErrorRecoveryLookup=*/false, -+ /*ScopeLookupResult=*/nullptr, false, - IsCorrectedToColon, OnlyNamespace); - } - -@@ -827,7 +840,7 @@ - return false; - - return !BuildCXXNestedNameSpecifier(S, IdInfo, EnteringContext, SS, -- /*ErrorRecoveryLookup=*/true); -+ /*ScopeLookupResult=*/nullptr, true); - } - - bool Sema::ActOnCXXNestedNameSpecifier(Scope *S, -diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaDeclCXX.cpp b/clang/lib/Sema/SemaDeclCXX.cpp ---- a/clang/lib/Sema/SemaDeclCXX.cpp -+++ b/clang/lib/Sema/SemaDeclCXX.cpp -@@ -1275,11 +1275,9 @@ - if (UseMemberGet) { - // if [lookup of member get] finds at least one declaration, the - // initializer is e.get(). -- E = S.BuildMemberReferenceExpr(E.get(), DecompType, Loc, -- /*IsArrow=*/false, -- /*SS=*/CXXScopeSpec(), -- /*TemplateKWLoc=*/SourceLocation(), -- MemberGet, &Args, /*S=*/nullptr); -+ E = S.BuildMemberReferenceExpr(E.get(), DecompType, Loc, false, -+ CXXScopeSpec(), SourceLocation(), nullptr, -+ MemberGet, &Args, nullptr); - if (E.isInvalid()) - return true; - -@@ -4903,12 +4901,16 @@ - MemberLookup.addDecl(Indirect ? cast(Indirect) - : cast(Field), AS_public); - MemberLookup.resolveKind(); -- ExprResult CtorArg = SemaRef.BuildMemberReferenceExpr( -- MemberExprBase, ParamType, Loc, -- /*IsArrow=*/false, SS, -- /*TemplateKWLoc=*/SourceLocation(), MemberLookup, -- /*TemplateArgs=*/nullptr, -- /*S=*/nullptr); -+ ExprResult CtorArg -+ = SemaRef.BuildMemberReferenceExpr(MemberExprBase, -+ ParamType, Loc, -+ /*IsArrow=*/false, -+ SS, -+ /*TemplateKWLoc=*/SourceLocation(), -+ /*FirstQualifierInScope=*/nullptr, -+ MemberLookup, -+ /*TemplateArgs=*/nullptr, -+ /*S*/nullptr); - if (CtorArg.isInvalid()) - return true; - -@@ -14334,10 +14336,8 @@ - public: - Expr *build(Sema &S, SourceLocation Loc) const override { - return assertNotNull(S.BuildMemberReferenceExpr( -- Builder.build(S, Loc), Type, Loc, IsArrow, SS, -- /*TemplateKwLoc=*/SourceLocation(), MemberLookup, -- /*TemplateArgs=*/nullptr, /*S=*/nullptr) -- .get()); -+ Builder.build(S, Loc), Type, Loc, IsArrow, SS, SourceLocation(), -+ nullptr, MemberLookup, nullptr, nullptr).get()); - } - - MemberBuilder(const ExprBuilder &Builder, QualType Type, bool IsArrow, -@@ -14543,11 +14543,13 @@ - Loc); - - // Create the reference to operator=. -- ExprResult OpEqualRef = S.BuildMemberReferenceExpr( -- To.build(S, Loc), T, Loc, /*IsArrow=*/false, SS, -- /*TemplateKWLoc=*/SourceLocation(), OpLookup, -- /*TemplateArgs=*/nullptr, /*S*/ nullptr, -- /*SuppressQualifierCheck=*/true); -+ ExprResult OpEqualRef -+ = S.BuildMemberReferenceExpr(To.build(S, Loc), T, Loc, /*IsArrow=*/false, -+ SS, /*TemplateKWLoc=*/SourceLocation(), -+ /*FirstQualifierInScope=*/nullptr, -+ OpLookup, -+ /*TemplateArgs=*/nullptr, /*S*/nullptr, -+ /*SuppressQualifierCheck=*/true); - if (OpEqualRef.isInvalid()) - return StmtError(); - -@@ -17153,9 +17155,8 @@ - - auto BuildExpr = [&](LookupResult &LR) { - ExprResult Res = BuildMemberReferenceExpr( -- Message, Message->getType(), Message->getBeginLoc(), /*IsArrow=*/false, -- /*SS=*/CXXScopeSpec(), /*TemplateKWLoc=*/SourceLocation(), LR, -- /*TemplateArgs=*/nullptr, /*S=*/nullptr); -+ Message, Message->getType(), Message->getBeginLoc(), false, -+ CXXScopeSpec(), SourceLocation(), nullptr, LR, nullptr, nullptr); - if (Res.isInvalid()) - return ExprError(); - Res = BuildCallExpr(nullptr, Res.get(), Loc, std::nullopt, Loc, nullptr, -diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp ---- a/clang/lib/Sema/SemaExpr.cpp -+++ b/clang/lib/Sema/SemaExpr.cpp -@@ -2624,7 +2624,7 @@ - return CXXDependentScopeMemberExpr::Create( - Context, /*This=*/nullptr, ThisType, /*IsArrow=*/true, - /*Op=*/SourceLocation(), NestedNameSpecifierLoc(), TemplateKWLoc, -- /*UnqualifiedLookups=*/std::nullopt, NameInfo, TemplateArgs); -+ /*FirstQualifierFoundInScope=*/nullptr, NameInfo, TemplateArgs); - } - - // Synthesize a fake NNS that points to the derived class. This will -diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaExprMember.cpp b/clang/lib/Sema/SemaExprMember.cpp ---- a/clang/lib/Sema/SemaExprMember.cpp -+++ b/clang/lib/Sema/SemaExprMember.cpp -@@ -552,9 +552,11 @@ - } - - ExprResult --Sema::ActOnDependentMemberExpr(Expr *BaseExpr, QualType BaseType, bool IsArrow, -- SourceLocation OpLoc, const CXXScopeSpec &SS, -+Sema::ActOnDependentMemberExpr(Expr *BaseExpr, QualType BaseType, -+ bool IsArrow, SourceLocation OpLoc, -+ const CXXScopeSpec &SS, - SourceLocation TemplateKWLoc, -+ NamedDecl *FirstQualifierInScope, - const DeclarationNameInfo &NameInfo, - const TemplateArgumentListInfo *TemplateArgs) { - // Even in dependent contexts, try to diagnose base expressions with -@@ -588,8 +590,8 @@ - // must have pointer type, and the accessed type is the pointee. - return CXXDependentScopeMemberExpr::Create( - Context, BaseExpr, BaseType, IsArrow, OpLoc, -- SS.getWithLocInContext(Context), TemplateKWLoc, -- SS.getUnqualifiedLookups(), NameInfo, TemplateArgs); -+ SS.getWithLocInContext(Context), TemplateKWLoc, FirstQualifierInScope, -+ NameInfo, TemplateArgs); - } - - /// We know that the given qualified member reference points only to -@@ -765,9 +767,8 @@ - R.addDecl(ND); - R.resolveKind(); - return SemaRef.BuildMemberReferenceExpr( -- BaseExpr, BaseExpr->getType(), OpLoc, IsArrow, SS, -- /*TemplateKWLoc=*/SourceLocation(), R, /*TemplateArgs=*/nullptr, -- /*S=*/nullptr); -+ BaseExpr, BaseExpr->getType(), OpLoc, IsArrow, SS, SourceLocation(), -+ nullptr, R, nullptr, nullptr); - }, - Sema::CTK_ErrorRecovery, DC); - -@@ -783,7 +784,7 @@ - ExprResult Sema::BuildMemberReferenceExpr( - Expr *Base, QualType BaseType, SourceLocation OpLoc, bool IsArrow, - CXXScopeSpec &SS, SourceLocation TemplateKWLoc, -- const DeclarationNameInfo &NameInfo, -+ NamedDecl *FirstQualifierInScope, const DeclarationNameInfo &NameInfo, - const TemplateArgumentListInfo *TemplateArgs, const Scope *S, - ActOnMemberAccessExtraArgs *ExtraArgs) { - LookupResult R(*this, NameInfo, LookupMemberName); -@@ -827,9 +828,10 @@ - if (SS.isInvalid()) - return ExprError(); - -- return BuildMemberReferenceExpr(Base, BaseType, OpLoc, IsArrow, SS, -- TemplateKWLoc, R, TemplateArgs, S, -- /*SuppressQualifierCheck=*/false, ExtraArgs); -+ return BuildMemberReferenceExpr(Base, BaseType, -+ OpLoc, IsArrow, SS, TemplateKWLoc, -+ FirstQualifierInScope, R, TemplateArgs, S, -+ false, ExtraArgs); - } - - ExprResult -@@ -967,11 +969,17 @@ - return false; - } - --ExprResult Sema::BuildMemberReferenceExpr( -- Expr *BaseExpr, QualType BaseExprType, SourceLocation OpLoc, bool IsArrow, -- const CXXScopeSpec &SS, SourceLocation TemplateKWLoc, LookupResult &R, -- const TemplateArgumentListInfo *TemplateArgs, const Scope *S, -- bool SuppressQualifierCheck, ActOnMemberAccessExtraArgs *ExtraArgs) { -+ExprResult -+Sema::BuildMemberReferenceExpr(Expr *BaseExpr, QualType BaseExprType, -+ SourceLocation OpLoc, bool IsArrow, -+ const CXXScopeSpec &SS, -+ SourceLocation TemplateKWLoc, -+ NamedDecl *FirstQualifierInScope, -+ LookupResult &R, -+ const TemplateArgumentListInfo *TemplateArgs, -+ const Scope *S, -+ bool SuppressQualifierCheck, -+ ActOnMemberAccessExtraArgs *ExtraArgs) { - assert(!SS.isInvalid() && "nested-name-specifier cannot be invalid"); - // If the member wasn't found in the current instantiation, or if the - // arrow operator was used with a dependent non-pointer object expression, -@@ -981,8 +989,8 @@ - (SS.isSet() ? SS.getScopeRep()->isDependent() - : BaseExprType->isDependentType()))) - return ActOnDependentMemberExpr(BaseExpr, BaseExprType, IsArrow, OpLoc, SS, -- TemplateKWLoc, R.getLookupNameInfo(), -- TemplateArgs); -+ TemplateKWLoc, FirstQualifierInScope, -+ R.getLookupNameInfo(), TemplateArgs); - - QualType BaseType = BaseExprType; - if (IsArrow) { -@@ -1187,9 +1195,9 @@ - - // Non-dependent member, but dependent template arguments. - if (!VDecl.get()) -- return ActOnDependentMemberExpr(BaseExpr, BaseExpr->getType(), IsArrow, -- OpLoc, SS, TemplateKWLoc, MemberNameInfo, -- TemplateArgs); -+ return ActOnDependentMemberExpr( -+ BaseExpr, BaseExpr->getType(), IsArrow, OpLoc, SS, TemplateKWLoc, -+ FirstQualifierInScope, MemberNameInfo, TemplateArgs); - - VarDecl *Var = cast(VDecl.get()); - if (!Var->getTemplateSpecializationKind()) -@@ -1755,16 +1763,15 @@ - const TemplateArgumentListInfo *TemplateArgs; - DecomposeUnqualifiedId(Id, TemplateArgsBuffer, - NameInfo, TemplateArgs); -- bool IsArrow = OpKind == tok::arrow; -+ -+ bool IsArrow = (OpKind == tok::arrow); - - if (getLangOpts().HLSL && IsArrow) - return ExprError(Diag(OpLoc, diag::err_hlsl_operator_unsupported) << 2); - -- UnresolvedSet<4> UnqualifiedLookups; -- if (SS.isValid() && -- LookupFirstQualifierInScope(S, SS.getScopeRep(), UnqualifiedLookups)) { -- SS.setUnqualifiedLookups(UnqualifiedLookups.pairs()); -- } -+ NamedDecl *FirstQualifierInScope -+ = (!SS.isSet() ? nullptr : FindFirstQualifierInScope(S, SS.getScopeRep())); -+ - // This is a postfix expression, so get rid of ParenListExprs. - ExprResult Result = MaybeConvertParenListExprToParenExpr(S, Base); - if (Result.isInvalid()) return ExprError(); -@@ -1772,8 +1779,8 @@ - - ActOnMemberAccessExtraArgs ExtraArgs = {S, Id, ObjCImpDecl}; - ExprResult Res = BuildMemberReferenceExpr( -- Base, Base->getType(), OpLoc, IsArrow, SS, TemplateKWLoc, NameInfo, -- TemplateArgs, S, &ExtraArgs); -+ Base, Base->getType(), OpLoc, IsArrow, SS, TemplateKWLoc, -+ FirstQualifierInScope, NameInfo, TemplateArgs, S, &ExtraArgs); - - if (!Res.isInvalid() && isa(Res.get())) - CheckMemberAccessOfNoDeref(cast(Res.get())); -@@ -1917,8 +1924,9 @@ - baseExpr = BuildCXXThisExpr(loc, ThisTy, /*IsImplicit=*/true); - } - -- return BuildMemberReferenceExpr(baseExpr, ThisTy, -- /*OpLoc=*/SourceLocation(), -- /*IsArrow=*/!getLangOpts().HLSL, SS, -- TemplateKWLoc, R, TemplateArgs, S); -+ return BuildMemberReferenceExpr( -+ baseExpr, ThisTy, -+ /*OpLoc=*/SourceLocation(), -+ /*IsArrow=*/!getLangOpts().HLSL, SS, TemplateKWLoc, -+ /*FirstQualifierInScope=*/nullptr, R, TemplateArgs, S); - } -diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp ---- a/clang/lib/Sema/SemaOverload.cpp -+++ b/clang/lib/Sema/SemaOverload.cpp -@@ -16043,11 +16043,13 @@ - - CandidateSet->clear(OverloadCandidateSet::CSK_Normal); - if (!MemberLookup.empty()) { -- ExprResult MemberRef = BuildMemberReferenceExpr( -- Range, Range->getType(), Loc, -- /*IsPtr=*/false, /*SS=*/CXXScopeSpec(), -- /*TemplateKWLoc=*/SourceLocation(), MemberLookup, -- /*TemplateArgs=*/nullptr, S); -+ ExprResult MemberRef = -+ BuildMemberReferenceExpr(Range, Range->getType(), Loc, -+ /*IsPtr=*/false, CXXScopeSpec(), -+ /*TemplateKWLoc=*/SourceLocation(), -+ /*FirstQualifierInScope=*/nullptr, -+ MemberLookup, -+ /*TemplateArgs=*/nullptr, S); - if (MemberRef.isInvalid()) { - *CallExpr = ExprError(); - return FRS_DiagnosticIssued; -diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaStmtAsm.cpp b/clang/lib/Sema/SemaStmtAsm.cpp ---- a/clang/lib/Sema/SemaStmtAsm.cpp -+++ b/clang/lib/Sema/SemaStmtAsm.cpp -@@ -900,8 +900,7 @@ - return CXXDependentScopeMemberExpr::Create( - Context, E, T, /*IsArrow=*/false, AsmLoc, NestedNameSpecifierLoc(), - SourceLocation(), -- /*UnqualifiedLookups=*/std::nullopt, NameInfo, -- /*TemplateArgs=*/nullptr); -+ /*FirstQualifierFoundInScope=*/nullptr, NameInfo, /*TemplateArgs=*/nullptr); - } - - const RecordType *RT = T->getAs(); -@@ -924,9 +923,8 @@ - - // Make an Expr to thread through OpDecl. - ExprResult Result = BuildMemberReferenceExpr( -- E, E->getType(), AsmLoc, /*IsArrow=*/false, /*SS=*/CXXScopeSpec(), -- /*TemplateKWLoc*/ SourceLocation(), FieldResult, -- /*TemplateArgs=*/nullptr, /*S=*/nullptr); -+ E, E->getType(), AsmLoc, /*IsArrow=*/false, CXXScopeSpec(), -+ SourceLocation(), nullptr, FieldResult, nullptr, nullptr); - - return Result; - } -diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaTemplate.cpp b/clang/lib/Sema/SemaTemplate.cpp ---- a/clang/lib/Sema/SemaTemplate.cpp -+++ b/clang/lib/Sema/SemaTemplate.cpp -@@ -174,12 +174,15 @@ - return false; - } - --TemplateNameKind --Sema::isTemplateName(Scope *S, CXXScopeSpec &SS, bool hasTemplateKeyword, -- const UnqualifiedId &Name, ParsedType ObjectTypePtr, -- bool EnteringContext, TemplateTy &TemplateResult, -- bool &MemberOfUnknownSpecialization, bool Disambiguation, -- bool MayBeNNS) { -+TemplateNameKind Sema::isTemplateName(Scope *S, -+ CXXScopeSpec &SS, -+ bool hasTemplateKeyword, -+ const UnqualifiedId &Name, -+ ParsedType ObjectTypePtr, -+ bool EnteringContext, -+ TemplateTy &TemplateResult, -+ bool &MemberOfUnknownSpecialization, -+ bool Disambiguation) { - assert(getLangOpts().CPlusPlus && "No template names in C!"); - - DeclarationName TName; -@@ -210,9 +213,8 @@ - if (LookupTemplateName(R, S, SS, ObjectType, EnteringContext, - /*RequiredTemplate=*/SourceLocation(), - &AssumedTemplate, -- /*AllowTypoCorrection=*/!Disambiguation, MayBeNNS)) -+ /*AllowTypoCorrection=*/!Disambiguation)) - return TNK_Non_template; -- - MemberOfUnknownSpecialization = R.wasNotFoundInCurrentInstantiation(); - - if (AssumedTemplate != AssumedTemplateKind::None) { -@@ -378,7 +380,7 @@ - QualType ObjectType, bool EnteringContext, - RequiredTemplateKind RequiredTemplate, - AssumedTemplateKind *ATK, -- bool AllowTypoCorrection, bool MayBeNNS) { -+ bool AllowTypoCorrection) { - if (ATK) - *ATK = AssumedTemplateKind::None; - -@@ -387,89 +389,92 @@ - - Found.setTemplateNameLookup(true); - -- // Template names cannot appear inside an Objective-C class or object type -- // or a vector type. -- // -- // FIXME: This is wrong. For example: -- // -- // template using Vec = T __attribute__((ext_vector_type(4))); -- // Vec vi; -- // vi.Vec::~Vec(); -- // -- // ... should be accepted but we will not treat 'Vec' as a template name -- // here. The right thing to do would be to check if the name is a valid -- // vector component name, and look up a template name if not. And similarly -- // for lookups into Objective-C class and object types, where the same -- // problem can arise. -- if (!ObjectType.isNull() && (ObjectType->isVectorType() || -- ObjectType->isObjCObjectOrInterfaceType())) { -- Found.clear(); -- return false; -- } -+ // Determine where to perform name lookup -+ DeclContext *LookupCtx = nullptr; -+ bool IsDependent = false; -+ if (!ObjectType.isNull()) { -+ // This nested-name-specifier occurs in a member access expression, e.g., -+ // x->B::f, and we are looking into the type of the object. -+ assert(SS.isEmpty() && "ObjectType and scope specifier cannot coexist"); -+ LookupCtx = computeDeclContext(ObjectType); -+ IsDependent = !LookupCtx && ObjectType->isDependentType(); -+ assert((IsDependent || !ObjectType->isIncompleteType() || -+ !ObjectType->getAs() || -+ ObjectType->castAs()->isBeingDefined()) && -+ "Caller should have completed object type"); - -- LookupParsedName(Found, S, &SS, ObjectType, -- /*AllowBuiltinCreation=*/false, EnteringContext); -+ // Template names cannot appear inside an Objective-C class or object type -+ // or a vector type. -+ // -+ // FIXME: This is wrong. For example: -+ // -+ // template using Vec = T __attribute__((ext_vector_type(4))); -+ // Vec vi; -+ // vi.Vec::~Vec(); -+ // -+ // ... should be accepted but we will not treat 'Vec' as a template name -+ // here. The right thing to do would be to check if the name is a valid -+ // vector component name, and look up a template name if not. And similarly -+ // for lookups into Objective-C class and object types, where the same -+ // problem can arise. -+ if (ObjectType->isObjCObjectOrInterfaceType() || -+ ObjectType->isVectorType()) { -+ Found.clear(); -+ return false; -+ } -+ } else if (SS.isNotEmpty()) { -+ // This nested-name-specifier occurs after another nested-name-specifier, -+ // so long into the context associated with the prior nested-name-specifier. -+ LookupCtx = computeDeclContext(SS, EnteringContext); -+ IsDependent = !LookupCtx && isDependentScopeSpecifier(SS); - -- // C++ [basic.lookup.qual.general]p3: -- // [...] Unless otherwise specified, a qualified name undergoes qualified -- // name lookup in its lookup context from the point where it appears unless -- // the lookup context either is dependent and is not the current -- // instantiation or is not a class or class template. -- // -- // The lookup context is dependent and either: -- // - it is not the current instantiation, or -- // - it is the current instantiation, it has at least one dependent base -- // class, and qualified lookup found nothing. -- // -- // If this is a member-qualified name that is the terminal name of a -- // nested-name-specifier, we perform unqualified lookup and store the results -- // so we can replicate the lookup during instantiation. The results of the -- // unqualified loookup are *not* used to determine whether '<' is interpreted -- // as the delimiter of a template-argument-list. -- // -- // For example: -- // -- // template -- // struct A { -- // int x; -- // }; -- // -- // template -- // using B = A; -- // -- // template -- // void f(A a, A b) { -- // a.B::x; // error: missing 'template' before 'B' -- // b.B::x; // ok, lookup context is not dependent -- // } -- if (Found.wasNotFoundInCurrentInstantiation()) -- return false; -+ // The declaration context must be complete. -+ if (LookupCtx && RequireCompleteDeclContext(SS, LookupCtx)) -+ return true; -+ } - - bool ObjectTypeSearchedInScope = false; -- -- // C++ [basic.lookup.qual.general]p2: -- // A member-qualified name is the (unique) component name, if any, of -- // - an unqualified-id or -- // - a nested-name-specifier of the form type-name :: or namespace-name :: -- // in the id-expression of a class member access expression. -- // -- // C++ [basic.lookup.qual.general]p3: -- // [...] If nothing is found by qualified lookup for a member-qualified -- // name that is the terminal name of a nested-name-specifier and is not -- // dependent, it undergoes unqualified lookup. -- // -- // In 'x.A::B::y', 'A' will undergo unqualified lookup if qualified lookup -- // in the type of 'x' finds nothing. If the lookup context is dependent, -- // we perform the unqualified lookup in the template definition context -- // and store the results so we can replicate the lookup during instantiation. -- if (MayBeNNS && Found.empty() && !ObjectType.isNull()) { -- if (S) { -+ bool AllowFunctionTemplatesInLookup = true; -+ if (LookupCtx) { -+ // Perform "qualified" name lookup into the declaration context we -+ // computed, which is either the type of the base of a member access -+ // expression or the declaration context associated with a prior -+ // nested-name-specifier. -+ LookupQualifiedName(Found, LookupCtx); -+ -+ // FIXME: The C++ standard does not clearly specify what happens in the -+ // case where the object type is dependent, and implementations vary. In -+ // Clang, we treat a name after a . or -> as a template-name if lookup -+ // finds a non-dependent member or member of the current instantiation that -+ // is a type template, or finds no such members and lookup in the context -+ // of the postfix-expression finds a type template. In the latter case, the -+ // name is nonetheless dependent, and we may resolve it to a member of an -+ // unknown specialization when we come to instantiate the template. -+ IsDependent |= Found.wasNotFoundInCurrentInstantiation(); -+ } -+ -+ if (SS.isEmpty() && (ObjectType.isNull() || Found.empty())) { -+ // C++ [basic.lookup.classref]p1: -+ // In a class member access expression (5.2.5), if the . or -> token is -+ // immediately followed by an identifier followed by a <, the -+ // identifier must be looked up to determine whether the < is the -+ // beginning of a template argument list (14.2) or a less-than operator. -+ // The identifier is first looked up in the class of the object -+ // expression. If the identifier is not found, it is then looked up in -+ // the context of the entire postfix-expression and shall name a class -+ // template. -+ if (S) - LookupName(Found, S); -- } else if (!SS.getUnqualifiedLookups().empty()) { -- Found.addAllDecls(SS.getUnqualifiedLookups()); -- Found.resolveKind(); -+ -+ if (!ObjectType.isNull()) { -+ // FIXME: We should filter out all non-type templates here, particularly -+ // variable templates and concepts. But the exclusion of alias templates -+ // and template template parameters is a wording defect. -+ AllowFunctionTemplatesInLookup = false; -+ ObjectTypeSearchedInScope = true; - } -- ObjectTypeSearchedInScope = true; -+ -+ IsDependent |= Found.wasNotFoundInCurrentInstantiation(); - } - - if (Found.isAmbiguous()) -@@ -489,7 +494,7 @@ - getLangOpts().CPlusPlus20 && llvm::all_of(Found, [](NamedDecl *ND) { - return isa(ND->getUnderlyingDecl()); - }); -- if (AllFunctions || Found.empty()) { -+ if (AllFunctions || (Found.empty() && !IsDependent)) { - // If lookup found any functions, or if this is a name that can only be - // used for a function, then strongly assume this is a function - // template-id. -@@ -501,15 +506,11 @@ - } - } - -- if (Found.empty() && AllowTypoCorrection) { -+ if (Found.empty() && !IsDependent && AllowTypoCorrection) { - // If we did not find any names, and this is not a disambiguation, attempt - // to correct any typos. - DeclarationName Name = Found.getLookupName(); - Found.clear(); -- DeclContext *LookupCtx = -- SS.isSet() -- ? computeDeclContext(SS, EnteringContext) -- : (!ObjectType.isNull() ? computeDeclContext(ObjectType) : nullptr); - // Simple filter callback that, for keywords, only accepts the C++ *_cast - DefaultFilterCCC FilterCCC{}; - FilterCCC.WantTypeSpecifiers = false; -@@ -542,8 +543,13 @@ - - NamedDecl *ExampleLookupResult = - Found.empty() ? nullptr : Found.getRepresentativeDecl(); -- FilterAcceptableTemplateNames(Found); -+ FilterAcceptableTemplateNames(Found, AllowFunctionTemplatesInLookup); - if (Found.empty()) { -+ if (IsDependent) { -+ Found.setNotFoundInCurrentInstantiation(); -+ return false; -+ } -+ - // If a 'template' keyword was used, a lookup that finds only non-template - // names is an error. - if (ExampleLookupResult && RequiredTemplate) { -@@ -735,7 +741,7 @@ - /*IsArrow=*/!Context.getLangOpts().HLSL, - /*OperatorLoc=*/SourceLocation(), - /*QualifierLoc=*/NestedNameSpecifierLoc(), TemplateKWLoc, -- /*UnqualifiedLookups=*/std::nullopt, NameInfo, TemplateArgs); -+ /*FirstQualifierFoundInScope=*/nullptr, NameInfo, TemplateArgs); - } - return BuildDependentDeclRefExpr(SS, TemplateKWLoc, NameInfo, TemplateArgs); - } -@@ -5849,10 +5855,14 @@ - return BuildTemplateIdExpr(SS, TemplateKWLoc, R, /*ADL=*/false, TemplateArgs); - } - --TemplateNameKind Sema::ActOnTemplateName( -- Scope *S, CXXScopeSpec &SS, SourceLocation TemplateKWLoc, -- const UnqualifiedId &Name, ParsedType ObjectType, bool EnteringContext, -- TemplateTy &Result, bool AllowInjectedClassName, bool MayBeNNS) { -+TemplateNameKind Sema::ActOnTemplateName(Scope *S, -+ CXXScopeSpec &SS, -+ SourceLocation TemplateKWLoc, -+ const UnqualifiedId &Name, -+ ParsedType ObjectType, -+ bool EnteringContext, -+ TemplateTy &Result, -+ bool AllowInjectedClassName) { - if (TemplateKWLoc.isValid() && S && !S->getTemplateParamParent()) - Diag(TemplateKWLoc, - getLangOpts().CPlusPlus11 ? -@@ -5887,10 +5897,9 @@ - // "template" keyword is now permitted). We follow the C++0x - // rules, even in C++03 mode with a warning, retroactively applying the DR. - bool MemberOfUnknownSpecialization; -- TemplateNameKind TNK = -- isTemplateName(S, SS, TemplateKWLoc.isValid(), Name, ObjectType, -- EnteringContext, Result, MemberOfUnknownSpecialization, -- /*Disambiguation=*/false, MayBeNNS); -+ TemplateNameKind TNK = isTemplateName(S, SS, TemplateKWLoc.isValid(), Name, -+ ObjectType, EnteringContext, Result, -+ MemberOfUnknownSpecialization); - if (TNK != TNK_Non_template) { - // We resolved this to a (non-dependent) template name. Return it. - auto *LookupRD = dyn_cast_or_null(LookupCtx); -@@ -5929,8 +5938,7 @@ - ? RequiredTemplateKind(TemplateKWLoc) - : TemplateNameIsRequired; - if (!LookupTemplateName(R, S, SS, ObjectType.get(), EnteringContext, RTK, -- /*ATK=*/nullptr, /*AllowTypoCorrection=*/false, -- MayBeNNS) && -+ /*ATK=*/nullptr, /*AllowTypoCorrection=*/false) && - !R.isAmbiguous()) { - if (LookupCtx) - Diag(Name.getBeginLoc(), diag::err_no_member) -diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaTemplateInstantiate.cpp b/clang/lib/Sema/SemaTemplateInstantiate.cpp ---- a/clang/lib/Sema/SemaTemplateInstantiate.cpp -+++ b/clang/lib/Sema/SemaTemplateInstantiate.cpp -@@ -1515,11 +1515,12 @@ - NestedNameSpecifierLoc QualifierLoc, - QualType T); - -- TemplateName TransformTemplateName(CXXScopeSpec &SS, TemplateName Name, -- SourceLocation NameLoc, -- QualType ObjectType = QualType(), -- bool AllowInjectedClassName = false, -- bool MayBeNNS = false); -+ TemplateName -+ TransformTemplateName(CXXScopeSpec &SS, TemplateName Name, -+ SourceLocation NameLoc, -+ QualType ObjectType = QualType(), -+ NamedDecl *FirstQualifierInScope = nullptr, -+ bool AllowInjectedClassName = false); - - const CXXAssumeAttr *TransformCXXAssumeAttr(const CXXAssumeAttr *AA); - const LoopHintAttr *TransformLoopHintAttr(const LoopHintAttr *LH); -@@ -1951,7 +1952,8 @@ - - TemplateName TemplateInstantiator::TransformTemplateName( - CXXScopeSpec &SS, TemplateName Name, SourceLocation NameLoc, -- QualType ObjectType, bool AllowInjectedClassName, bool MayBeNNS) { -+ QualType ObjectType, NamedDecl *FirstQualifierInScope, -+ bool AllowInjectedClassName) { - if (TemplateTemplateParmDecl *TTP - = dyn_cast_or_null(Name.getAsTemplateDecl())) { - if (TTP->getDepth() < TemplateArgs.getNumLevels()) { -@@ -2023,7 +2025,8 @@ - } - - return inherited::TransformTemplateName(SS, Name, NameLoc, ObjectType, -- AllowInjectedClassName, MayBeNNS); -+ FirstQualifierInScope, -+ AllowInjectedClassName); - } - - ExprResult -diff -ruN --strip-trailing-cr a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h ---- a/clang/lib/Sema/TreeTransform.h -+++ b/clang/lib/Sema/TreeTransform.h -@@ -541,9 +541,10 @@ - /// By default, transforms all of the types and declarations within the - /// nested-name-specifier. Subclasses may override this function to provide - /// alternate behavior. -- NestedNameSpecifierLoc TransformNestedNameSpecifierLoc( -- NestedNameSpecifierLoc NNS, QualType ObjectType = QualType(), -- ArrayRef UnqualifiedLookups = std::nullopt); -+ NestedNameSpecifierLoc -+ TransformNestedNameSpecifierLoc(NestedNameSpecifierLoc NNS, -+ QualType ObjectType = QualType(), -+ NamedDecl *FirstQualifierInScope = nullptr); - - /// Transform the given declaration name. - /// -@@ -584,11 +585,12 @@ - /// By default, transforms the template name by transforming the declarations - /// and nested-name-specifiers that occur within the template name. - /// Subclasses may override this function to provide alternate behavior. -- TemplateName TransformTemplateName(CXXScopeSpec &SS, TemplateName Name, -- SourceLocation NameLoc, -- QualType ObjectType = QualType(), -- bool AllowInjectedClassName = false, -- bool MayBeNNS = false); -+ TemplateName -+ TransformTemplateName(CXXScopeSpec &SS, TemplateName Name, -+ SourceLocation NameLoc, -+ QualType ObjectType = QualType(), -+ NamedDecl *FirstQualifierInScope = nullptr, -+ bool AllowInjectedClassName = false); - - /// Transform the given template argument. - /// -@@ -1138,8 +1140,8 @@ - CXXScopeSpec SS; - SS.Adopt(QualifierLoc); - TemplateName InstName = getDerived().RebuildTemplateName( -- SS, TemplateKWLoc, *Name, NameLoc, QualType(), AllowInjectedClassName, -- /*MayBeNNS=*/false); -+ SS, TemplateKWLoc, *Name, NameLoc, QualType(), nullptr, -+ AllowInjectedClassName); - - if (InstName.isNull()) - return QualType(); -@@ -1310,7 +1312,8 @@ - SourceLocation TemplateKWLoc, - const IdentifierInfo &Name, - SourceLocation NameLoc, QualType ObjectType, -- bool AllowInjectedClassName, bool MayBeNNS); -+ NamedDecl *FirstQualifierInScope, -+ bool AllowInjectedClassName); - - /// Build a new template name given a nested name specifier and the - /// overloaded operator name that is referred to as a template. -@@ -2846,14 +2849,15 @@ - /// - /// By default, performs semantic analysis to build the new expression. - /// Subclasses may override this routine to provide different behavior. -- ExprResult -- RebuildMemberExpr(Expr *Base, SourceLocation OpLoc, bool isArrow, -- NestedNameSpecifierLoc QualifierLoc, -- SourceLocation TemplateKWLoc, -- const DeclarationNameInfo &MemberNameInfo, -- ValueDecl *Member, NamedDecl *FoundDecl, -- const TemplateArgumentListInfo *ExplicitTemplateArgs, -- ArrayRef UnqualifiedLookups) { -+ ExprResult RebuildMemberExpr(Expr *Base, SourceLocation OpLoc, -+ bool isArrow, -+ NestedNameSpecifierLoc QualifierLoc, -+ SourceLocation TemplateKWLoc, -+ const DeclarationNameInfo &MemberNameInfo, -+ ValueDecl *Member, -+ NamedDecl *FoundDecl, -+ const TemplateArgumentListInfo *ExplicitTemplateArgs, -+ NamedDecl *FirstQualifierInScope) { - ExprResult BaseResult = getSema().PerformMemberExprBaseConversion(Base, - isArrow); - if (!Member->getDeclName()) { -@@ -2890,7 +2894,6 @@ - - CXXScopeSpec SS; - SS.Adopt(QualifierLoc); -- SS.setUnqualifiedLookups(UnqualifiedLookups); - - Base = BaseResult.get(); - if (Base->containsErrors()) -@@ -2923,9 +2926,10 @@ - } - - return getSema().BuildMemberReferenceExpr(Base, BaseType, OpLoc, isArrow, -- SS, TemplateKWLoc, R, -- ExplicitTemplateArgs, -- /*S=*/nullptr); -+ SS, TemplateKWLoc, -+ FirstQualifierInScope, -+ R, ExplicitTemplateArgs, -+ /*S*/nullptr); - } - - /// Build a new binary operator expression. -@@ -2998,9 +3002,10 @@ - CXXScopeSpec SS; - DeclarationNameInfo NameInfo(&Accessor, AccessorLoc); - return getSema().BuildMemberReferenceExpr( -- Base, Base->getType(), OpLoc, IsArrow, SS, -- /*TemplateKWLoc=*/SourceLocation(), NameInfo, -- /*TemplateArgs=*/nullptr, /*S=*/nullptr); -+ Base, Base->getType(), OpLoc, IsArrow, SS, SourceLocation(), -+ /*FirstQualifierInScope*/ nullptr, NameInfo, -+ /* TemplateArgs */ nullptr, -+ /*S*/ nullptr); - } - - /// Build a new initializer list expression. -@@ -3568,37 +3573,46 @@ - /// - /// By default, performs semantic analysis to build the new expression. - /// Subclasses may override this routine to provide different behavior. -- ExprResult RebuildCXXDependentScopeMemberExpr( -- Expr *BaseE, QualType BaseType, bool IsArrow, SourceLocation OperatorLoc, -- NestedNameSpecifierLoc QualifierLoc, SourceLocation TemplateKWLoc, -- ArrayRef UnqualifiedLookups, -- const DeclarationNameInfo &MemberNameInfo, -- const TemplateArgumentListInfo *TemplateArgs) { -+ ExprResult RebuildCXXDependentScopeMemberExpr(Expr *BaseE, -+ QualType BaseType, -+ bool IsArrow, -+ SourceLocation OperatorLoc, -+ NestedNameSpecifierLoc QualifierLoc, -+ SourceLocation TemplateKWLoc, -+ NamedDecl *FirstQualifierInScope, -+ const DeclarationNameInfo &MemberNameInfo, -+ const TemplateArgumentListInfo *TemplateArgs) { - CXXScopeSpec SS; - SS.Adopt(QualifierLoc); -- SS.setUnqualifiedLookups(UnqualifiedLookups); - -- return SemaRef.BuildMemberReferenceExpr( -- BaseE, BaseType, OperatorLoc, IsArrow, SS, TemplateKWLoc, -- MemberNameInfo, TemplateArgs, /*S=*/nullptr); -+ return SemaRef.BuildMemberReferenceExpr(BaseE, BaseType, -+ OperatorLoc, IsArrow, -+ SS, TemplateKWLoc, -+ FirstQualifierInScope, -+ MemberNameInfo, -+ TemplateArgs, /*S*/nullptr); - } - - /// Build a new member reference expression. - /// - /// By default, performs semantic analysis to build the new expression. - /// Subclasses may override this routine to provide different behavior. -- ExprResult RebuildUnresolvedMemberExpr( -- Expr *BaseE, QualType BaseType, SourceLocation OperatorLoc, bool IsArrow, -- NestedNameSpecifierLoc QualifierLoc, SourceLocation TemplateKWLoc, -- ArrayRef UnqualifiedLookups, LookupResult &R, -- const TemplateArgumentListInfo *TemplateArgs) { -+ ExprResult RebuildUnresolvedMemberExpr(Expr *BaseE, QualType BaseType, -+ SourceLocation OperatorLoc, -+ bool IsArrow, -+ NestedNameSpecifierLoc QualifierLoc, -+ SourceLocation TemplateKWLoc, -+ NamedDecl *FirstQualifierInScope, -+ LookupResult &R, -+ const TemplateArgumentListInfo *TemplateArgs) { - CXXScopeSpec SS; - SS.Adopt(QualifierLoc); -- SS.setUnqualifiedLookups(UnqualifiedLookups); - -- return SemaRef.BuildMemberReferenceExpr(BaseE, BaseType, OperatorLoc, -- IsArrow, SS, TemplateKWLoc, R, -- TemplateArgs, /*S=*/nullptr); -+ return SemaRef.BuildMemberReferenceExpr(BaseE, BaseType, -+ OperatorLoc, IsArrow, -+ SS, TemplateKWLoc, -+ FirstQualifierInScope, -+ R, TemplateArgs, /*S*/nullptr); - } - - /// Build a new noexcept expression. -@@ -3817,8 +3831,10 @@ - DeclarationNameInfo NameInfo(Ivar->getDeclName(), IvarLoc); - ExprResult Result = getSema().BuildMemberReferenceExpr( - BaseArg, BaseArg->getType(), -- /*FIXME:*/ IvarLoc, IsArrow, SS, /*TemplateKWLoc=*/SourceLocation(), -- NameInfo, /*TemplateArgs=*/nullptr, /*S=*/nullptr); -+ /*FIXME:*/ IvarLoc, IsArrow, SS, SourceLocation(), -+ /*FirstQualifierInScope=*/nullptr, NameInfo, -+ /*TemplateArgs=*/nullptr, -+ /*S=*/nullptr); - if (IsFreeIvar && Result.isUsable()) - cast(Result.get())->setIsFreeIvar(IsFreeIvar); - return Result; -@@ -3833,12 +3849,14 @@ - SourceLocation PropertyLoc) { - CXXScopeSpec SS; - DeclarationNameInfo NameInfo(Property->getDeclName(), PropertyLoc); -- return getSema().BuildMemberReferenceExpr( -- BaseArg, BaseArg->getType(), -- /*FIXME:*/ PropertyLoc, -- /*IsArrow=*/false, SS, /*TemplateKWLoc=*/SourceLocation(), NameInfo, -- /*TemplateArgs=*/nullptr, -- /*S=*/nullptr); -+ return getSema().BuildMemberReferenceExpr(BaseArg, BaseArg->getType(), -+ /*FIXME:*/PropertyLoc, -+ /*IsArrow=*/false, -+ SS, SourceLocation(), -+ /*FirstQualifierInScope=*/nullptr, -+ NameInfo, -+ /*TemplateArgs=*/nullptr, -+ /*S=*/nullptr); - } - - /// Build a new Objective-C property reference expression. -@@ -3865,11 +3883,13 @@ - SourceLocation OpLoc, bool IsArrow) { - CXXScopeSpec SS; - DeclarationNameInfo NameInfo(&getSema().Context.Idents.get("isa"), IsaLoc); -- return getSema().BuildMemberReferenceExpr( -- BaseArg, BaseArg->getType(), OpLoc, IsArrow, SS, -- /*TemplateKWLoc=*/SourceLocation(), NameInfo, -- /*TemplateArgs=*/nullptr, -- /*S=*/nullptr); -+ return getSema().BuildMemberReferenceExpr(BaseArg, BaseArg->getType(), -+ OpLoc, IsArrow, -+ SS, SourceLocation(), -+ /*FirstQualifierInScope=*/nullptr, -+ NameInfo, -+ /*TemplateArgs=*/nullptr, -+ /*S=*/nullptr); - } - - /// Build a new shuffle vector expression. -@@ -4034,14 +4054,18 @@ - } - - private: -- TypeLoc TransformTypeInObjectScope(TypeLoc TL, QualType ObjectType, -+ TypeLoc TransformTypeInObjectScope(TypeLoc TL, -+ QualType ObjectType, -+ NamedDecl *FirstQualifierInScope, - CXXScopeSpec &SS); - - TypeSourceInfo *TransformTypeInObjectScope(TypeSourceInfo *TSInfo, - QualType ObjectType, -+ NamedDecl *FirstQualifierInScope, - CXXScopeSpec &SS); - - TypeSourceInfo *TransformTSIInObjectScope(TypeLoc TL, QualType ObjectType, -+ NamedDecl *FirstQualifierInScope, - CXXScopeSpec &SS); - - QualType TransformDependentNameType(TypeLocBuilder &TLB, -@@ -4360,7 +4384,7 @@ - template - NestedNameSpecifierLoc TreeTransform::TransformNestedNameSpecifierLoc( - NestedNameSpecifierLoc NNS, QualType ObjectType, -- ArrayRef UnqualifiedLookups) { -+ NamedDecl *FirstQualifierInScope) { - SmallVector Qualifiers; - - auto insertNNS = [&Qualifiers](NestedNameSpecifierLoc NNS) { -@@ -4371,8 +4395,6 @@ - insertNNS(NNS); - - CXXScopeSpec SS; -- SS.setUnqualifiedLookups(UnqualifiedLookups); -- - while (!Qualifiers.empty()) { - NestedNameSpecifierLoc Q = Qualifiers.pop_back_val(); - NestedNameSpecifier *QNNS = Q.getNestedNameSpecifier(); -@@ -4382,9 +4404,8 @@ - Sema::NestedNameSpecInfo IdInfo(QNNS->getAsIdentifier(), - Q.getLocalBeginLoc(), Q.getLocalEndLoc(), - ObjectType); -- if (SemaRef.BuildCXXNestedNameSpecifier(/*Scope=*/nullptr, IdInfo, -- /*EnteringContext=*/false, SS, -- /*ErrorRecoveryLookup=*/false)) -+ if (SemaRef.BuildCXXNestedNameSpecifier(/*Scope=*/nullptr, IdInfo, false, -+ SS, FirstQualifierInScope, false)) - return NestedNameSpecifierLoc(); - break; - } -@@ -4422,7 +4443,8 @@ - - case NestedNameSpecifier::TypeSpecWithTemplate: - case NestedNameSpecifier::TypeSpec: { -- TypeLoc TL = TransformTypeInObjectScope(Q.getTypeLoc(), ObjectType, SS); -+ TypeLoc TL = TransformTypeInObjectScope(Q.getTypeLoc(), ObjectType, -+ FirstQualifierInScope, SS); - - if (!TL) - return NestedNameSpecifierLoc(); -@@ -4455,7 +4477,7 @@ - } - - // The qualifier-in-scope and object type only apply to the leftmost entity. -- SS.setUnqualifiedLookups(std::nullopt); -+ FirstQualifierInScope = nullptr; - ObjectType = QualType(); - } - -@@ -4538,10 +4560,14 @@ - llvm_unreachable("Unknown name kind."); - } - --template --TemplateName TreeTransform::TransformTemplateName( -- CXXScopeSpec &SS, TemplateName Name, SourceLocation NameLoc, -- QualType ObjectType, bool AllowInjectedClassName, bool MayBeNNS) { -+template -+TemplateName -+TreeTransform::TransformTemplateName(CXXScopeSpec &SS, -+ TemplateName Name, -+ SourceLocation NameLoc, -+ QualType ObjectType, -+ NamedDecl *FirstQualifierInScope, -+ bool AllowInjectedClassName) { - if (QualifiedTemplateName *QTN = Name.getAsQualifiedTemplateName()) { - TemplateDecl *Template = QTN->getUnderlyingTemplate().getAsTemplateDecl(); - assert(Template && "qualified template name must refer to a template"); -@@ -4565,7 +4591,7 @@ - if (SS.getScopeRep()) { - // These apply to the scope specifier, not the template. - ObjectType = QualType(); -- SS.setUnqualifiedLookups(std::nullopt); -+ FirstQualifierInScope = nullptr; - } - - if (!getDerived().AlwaysRebuild() && -@@ -4577,9 +4603,13 @@ - SourceLocation TemplateKWLoc = NameLoc; - - if (DTN->isIdentifier()) { -- return getDerived().RebuildTemplateName( -- SS, TemplateKWLoc, *DTN->getIdentifier(), NameLoc, ObjectType, -- AllowInjectedClassName, MayBeNNS); -+ return getDerived().RebuildTemplateName(SS, -+ TemplateKWLoc, -+ *DTN->getIdentifier(), -+ NameLoc, -+ ObjectType, -+ FirstQualifierInScope, -+ AllowInjectedClassName); - } - - return getDerived().RebuildTemplateName(SS, TemplateKWLoc, -@@ -5123,31 +5153,39 @@ - return SemaRef.BuildQualifiedType(T, Loc, Quals); - } - --template --TypeLoc TreeTransform::TransformTypeInObjectScope(TypeLoc TL, -- QualType ObjectType, -- CXXScopeSpec &SS) { -+template -+TypeLoc -+TreeTransform::TransformTypeInObjectScope(TypeLoc TL, -+ QualType ObjectType, -+ NamedDecl *UnqualLookup, -+ CXXScopeSpec &SS) { - if (getDerived().AlreadyTransformed(TL.getType())) - return TL; - -- TypeSourceInfo *TSI = TransformTSIInObjectScope(TL, ObjectType, SS); -+ TypeSourceInfo *TSI = -+ TransformTSIInObjectScope(TL, ObjectType, UnqualLookup, SS); - if (TSI) - return TSI->getTypeLoc(); - return TypeLoc(); - } - --template --TypeSourceInfo *TreeTransform::TransformTypeInObjectScope( -- TypeSourceInfo *TSInfo, QualType ObjectType, CXXScopeSpec &SS) { -+template -+TypeSourceInfo * -+TreeTransform::TransformTypeInObjectScope(TypeSourceInfo *TSInfo, -+ QualType ObjectType, -+ NamedDecl *UnqualLookup, -+ CXXScopeSpec &SS) { - if (getDerived().AlreadyTransformed(TSInfo->getType())) - return TSInfo; - -- return TransformTSIInObjectScope(TSInfo->getTypeLoc(), ObjectType, SS); -+ return TransformTSIInObjectScope(TSInfo->getTypeLoc(), ObjectType, -+ UnqualLookup, SS); - } - - template - TypeSourceInfo *TreeTransform::TransformTSIInObjectScope( -- TypeLoc TL, QualType ObjectType, CXXScopeSpec &SS) { -+ TypeLoc TL, QualType ObjectType, NamedDecl *UnqualLookup, -+ CXXScopeSpec &SS) { - QualType T = TL.getType(); - assert(!getDerived().AlreadyTransformed(T)); - -@@ -5160,7 +5198,7 @@ - - TemplateName Template = getDerived().TransformTemplateName( - SS, SpecTL.getTypePtr()->getTemplateName(), SpecTL.getTemplateNameLoc(), -- ObjectType, /*AllowInjectedClassName=*/true, /*MayBeNNS=*/true); -+ ObjectType, UnqualLookup, /*AllowInjectedClassName*/true); - if (Template.isNull()) - return nullptr; - -@@ -5170,11 +5208,13 @@ - DependentTemplateSpecializationTypeLoc SpecTL = - TL.castAs(); - -- TemplateName Template = getDerived().RebuildTemplateName( -- SS, SpecTL.getTemplateKeywordLoc(), -- *SpecTL.getTypePtr()->getIdentifier(), SpecTL.getTemplateNameLoc(), -- ObjectType, -- /*AllowInjectedClassName=*/true, /*MayBeNNS=*/true); -+ TemplateName Template -+ = getDerived().RebuildTemplateName(SS, -+ SpecTL.getTemplateKeywordLoc(), -+ *SpecTL.getTypePtr()->getIdentifier(), -+ SpecTL.getTemplateNameLoc(), -+ ObjectType, UnqualLookup, -+ /*AllowInjectedClassName*/true); - if (Template.isNull()) - return nullptr; - -@@ -12318,8 +12358,7 @@ - // first-qualifier-in-scope here, just in case we had a dependent - // base (and therefore couldn't do the check) and a - // nested-name-qualifier (and therefore could do the lookup). -- ArrayRef UnqualifiedLookups; -- -+ NamedDecl *FirstQualifierInScope = nullptr; - DeclarationNameInfo MemberNameInfo = E->getMemberNameInfo(); - if (MemberNameInfo.getName()) { - MemberNameInfo = getDerived().TransformDeclarationNameInfo(MemberNameInfo); -@@ -12327,11 +12366,16 @@ - return ExprError(); - } - -- return getDerived().RebuildMemberExpr( -- Base.get(), FakeOperatorLoc, E->isArrow(), QualifierLoc, TemplateKWLoc, -- MemberNameInfo, Member, FoundDecl, -- (E->hasExplicitTemplateArgs() ? &TransArgs : nullptr), -- UnqualifiedLookups); -+ return getDerived().RebuildMemberExpr(Base.get(), FakeOperatorLoc, -+ E->isArrow(), -+ QualifierLoc, -+ TemplateKWLoc, -+ MemberNameInfo, -+ Member, -+ FoundDecl, -+ (E->hasExplicitTemplateArgs() -+ ? &TransArgs : nullptr), -+ FirstQualifierInScope); - } - - template -@@ -13458,8 +13502,9 @@ - - PseudoDestructorTypeStorage Destroyed; - if (E->getDestroyedTypeInfo()) { -- TypeSourceInfo *DestroyedTypeInfo = getDerived().TransformTypeInObjectScope( -- E->getDestroyedTypeInfo(), ObjectType, SS); -+ TypeSourceInfo *DestroyedTypeInfo -+ = getDerived().TransformTypeInObjectScope(E->getDestroyedTypeInfo(), -+ ObjectType, nullptr, SS); - if (!DestroyedTypeInfo) - return ExprError(); - Destroyed = DestroyedTypeInfo; -@@ -13485,7 +13530,7 @@ - if (E->getScopeTypeInfo()) { - CXXScopeSpec EmptySS; - ScopeTypeInfo = getDerived().TransformTypeInObjectScope( -- E->getScopeTypeInfo(), ObjectType, EmptySS); -+ E->getScopeTypeInfo(), ObjectType, nullptr, EmptySS); - if (!ScopeTypeInfo) - return ExprError(); - } -@@ -14746,17 +14791,19 @@ - ObjectType = BaseType->castAs()->getPointeeType(); - } - -- UnresolvedSet<4> UnqualifiedLookups; -- for (auto D : E->unqualified_lookups()) { -- if (NamedDecl *InstD = getDerived().TransformFirstQualifierInScope( -- D.getDecl(), E->getQualifierLoc().getBeginLoc())) -- UnqualifiedLookups.addDecl(InstD); -- } -+ // Transform the first part of the nested-name-specifier that qualifies -+ // the member name. -+ NamedDecl *FirstQualifierInScope -+ = getDerived().TransformFirstQualifierInScope( -+ E->getFirstQualifierFoundInScope(), -+ E->getQualifierLoc().getBeginLoc()); - - NestedNameSpecifierLoc QualifierLoc; - if (E->getQualifier()) { -- QualifierLoc = getDerived().TransformNestedNameSpecifierLoc( -- E->getQualifierLoc(), ObjectType, UnqualifiedLookups.pairs()); -+ QualifierLoc -+ = getDerived().TransformNestedNameSpecifierLoc(E->getQualifierLoc(), -+ ObjectType, -+ FirstQualifierInScope); - if (!QualifierLoc) - return ExprError(); - } -@@ -14775,16 +14822,23 @@ - if (!E->hasExplicitTemplateArgs()) { - // This is a reference to a member without an explicitly-specified - // template argument list. Optimize for this common case. -- if (!getDerived().AlwaysRebuild() && Base.get() == OldBase && -- BaseType == E->getBaseType() && QualifierLoc == E->getQualifierLoc() && -+ if (!getDerived().AlwaysRebuild() && -+ Base.get() == OldBase && -+ BaseType == E->getBaseType() && -+ QualifierLoc == E->getQualifierLoc() && - NameInfo.getName() == E->getMember() && -- UnqualifiedLookups.pairs() == E->unqualified_lookups()) -+ FirstQualifierInScope == E->getFirstQualifierFoundInScope()) - return E; - -- return getDerived().RebuildCXXDependentScopeMemberExpr( -- Base.get(), BaseType, E->isArrow(), E->getOperatorLoc(), QualifierLoc, -- TemplateKWLoc, UnqualifiedLookups.pairs(), NameInfo, -- /*TemplateArgs*/ nullptr); -+ return getDerived().RebuildCXXDependentScopeMemberExpr(Base.get(), -+ BaseType, -+ E->isArrow(), -+ E->getOperatorLoc(), -+ QualifierLoc, -+ TemplateKWLoc, -+ FirstQualifierInScope, -+ NameInfo, -+ /*TemplateArgs*/nullptr); - } - - TemplateArgumentListInfo TransArgs(E->getLAngleLoc(), E->getRAngleLoc()); -@@ -14793,9 +14847,15 @@ - TransArgs)) - return ExprError(); - -- return getDerived().RebuildCXXDependentScopeMemberExpr( -- Base.get(), BaseType, E->isArrow(), E->getOperatorLoc(), QualifierLoc, -- TemplateKWLoc, UnqualifiedLookups.pairs(), NameInfo, &TransArgs); -+ return getDerived().RebuildCXXDependentScopeMemberExpr(Base.get(), -+ BaseType, -+ E->isArrow(), -+ E->getOperatorLoc(), -+ QualifierLoc, -+ TemplateKWLoc, -+ FirstQualifierInScope, -+ NameInfo, -+ &TransArgs); - } - - template -@@ -14856,11 +14916,11 @@ - // first-qualifier-in-scope here, just in case we had a dependent - // base (and therefore couldn't do the check) and a - // nested-name-qualifier (and therefore could do the lookup). -- ArrayRef UnqualifiedLookups; -+ NamedDecl *FirstQualifierInScope = nullptr; - - return getDerived().RebuildUnresolvedMemberExpr( - Base.get(), BaseType, Old->getOperatorLoc(), Old->isArrow(), QualifierLoc, -- TemplateKWLoc, UnqualifiedLookups, R, -+ TemplateKWLoc, FirstQualifierInScope, R, - (Old->hasExplicitTemplateArgs() ? &TransArgs : nullptr)); - } - -@@ -16217,18 +16277,22 @@ - TemplateName(Template)); - } - --template --TemplateName TreeTransform::RebuildTemplateName( -- CXXScopeSpec &SS, SourceLocation TemplateKWLoc, const IdentifierInfo &Name, -- SourceLocation NameLoc, QualType ObjectType, bool AllowInjectedClassName, -- bool MayBeNNS) { -+template -+TemplateName -+TreeTransform::RebuildTemplateName(CXXScopeSpec &SS, -+ SourceLocation TemplateKWLoc, -+ const IdentifierInfo &Name, -+ SourceLocation NameLoc, -+ QualType ObjectType, -+ NamedDecl *FirstQualifierInScope, -+ bool AllowInjectedClassName) { - UnqualifiedId TemplateName; - TemplateName.setIdentifier(&Name, NameLoc); - Sema::TemplateTy Template; - getSema().ActOnTemplateName(/*Scope=*/nullptr, SS, TemplateKWLoc, - TemplateName, ParsedType::make(ObjectType), - /*EnteringContext=*/false, Template, -- AllowInjectedClassName, MayBeNNS); -+ AllowInjectedClassName); - return Template.get(); - } - -@@ -16376,10 +16440,13 @@ - } - - SourceLocation TemplateKWLoc; // FIXME: retrieve it from caller. -- return getSema().BuildMemberReferenceExpr( -- Base, BaseType, OperatorLoc, isArrow, SS, TemplateKWLoc, NameInfo, -- /*TemplateArgs=*/nullptr, -- /*S=*/nullptr); -+ return getSema().BuildMemberReferenceExpr(Base, BaseType, -+ OperatorLoc, isArrow, -+ SS, TemplateKWLoc, -+ /*FIXME: FirstQualifier*/ nullptr, -+ NameInfo, -+ /*TemplateArgs*/ nullptr, -+ /*S*/nullptr); - } - - template -diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp ---- a/clang/lib/Serialization/ASTReaderStmt.cpp -+++ b/clang/lib/Serialization/ASTReaderStmt.cpp -@@ -1993,43 +1993,42 @@ - CXXDependentScopeMemberExpr *E) { - VisitExpr(E); - -- CurrentUnpackingBits.emplace(Record.readInt()); -- bool HasQualifier = CurrentUnpackingBits->getNextBit(); -- bool HasTemplateInfo = CurrentUnpackingBits->getNextBit(); -- unsigned NumUnqualifiedLookups = Record.readInt(); - unsigned NumTemplateArgs = Record.readInt(); -- E->CXXDependentScopeMemberExprBits.HasQualifier = HasQualifier; -- E->CXXDependentScopeMemberExprBits.NumUnqualifiedLookups = -- NumUnqualifiedLookups; -- E->CXXDependentScopeMemberExprBits.HasTemplateKWAndArgsInfo = HasTemplateInfo; -+ CurrentUnpackingBits.emplace(Record.readInt()); -+ bool HasTemplateKWAndArgsInfo = CurrentUnpackingBits->getNextBit(); -+ bool HasFirstQualifierFoundInScope = CurrentUnpackingBits->getNextBit(); -+ -+ assert((HasTemplateKWAndArgsInfo == E->hasTemplateKWAndArgsInfo()) && -+ "Wrong HasTemplateKWAndArgsInfo!"); -+ assert( -+ (HasFirstQualifierFoundInScope == E->hasFirstQualifierFoundInScope()) && -+ "Wrong HasFirstQualifierFoundInScope!"); -+ -+ if (HasTemplateKWAndArgsInfo) -+ ReadTemplateKWAndArgsInfo( -+ *E->getTrailingObjects(), -+ E->getTrailingObjects(), NumTemplateArgs); -+ -+ assert((NumTemplateArgs == E->getNumTemplateArgs()) && -+ "Wrong NumTemplateArgs!"); - -- E->BaseType = Record.readType(); - E->CXXDependentScopeMemberExprBits.IsArrow = - CurrentUnpackingBits->getNextBit(); - -+ E->BaseType = Record.readType(); -+ E->QualifierLoc = Record.readNestedNameSpecifierLoc(); -+ // not ImplicitAccess - if (CurrentUnpackingBits->getNextBit()) - E->Base = Record.readSubExpr(); - else - E->Base = nullptr; - -- E->OperatorLoc = Record.readSourceLocation(); -- E->MemberNameInfo = Record.readDeclarationNameInfo(); -+ E->CXXDependentScopeMemberExprBits.OperatorLoc = readSourceLocation(); - -- if (HasQualifier) -- new (E->getTrailingObjects()) -- NestedNameSpecifierLoc(Record.readNestedNameSpecifierLoc()); -- -- for (unsigned I = 0; I != NumUnqualifiedLookups; ++I) { -- auto *FoundD = Record.readDeclAs(); -- auto AS = (AccessSpecifier)Record.readInt(); -- E->getTrailingObjects()[I] = -- DeclAccessPair::make(FoundD, AS); -- } -+ if (HasFirstQualifierFoundInScope) -+ *E->getTrailingObjects() = readDeclAs(); - -- if (HasTemplateInfo) -- ReadTemplateKWAndArgsInfo( -- *E->getTrailingObjects(), -- E->getTrailingObjects(), NumTemplateArgs); -+ E->MemberNameInfo = Record.readDeclarationNameInfo(); - } - - void -@@ -4076,16 +4075,16 @@ - break; - - case EXPR_CXX_DEPENDENT_SCOPE_MEMBER: { -+ unsigned NumTemplateArgs = Record[ASTStmtReader::NumExprFields]; - BitsUnpacker DependentScopeMemberBits( -- Record[ASTStmtReader::NumExprFields]); -- bool HasQualifier = DependentScopeMemberBits.getNextBit(); -- bool HasTemplateInfo = DependentScopeMemberBits.getNextBit(); -- unsigned NumUnqualifiedLookups = Record[ASTStmtReader::NumExprFields + 1]; -- unsigned NumTemplateArgs = Record[ASTStmtReader::NumExprFields + 2]; -+ Record[ASTStmtReader::NumExprFields + 1]); -+ bool HasTemplateKWAndArgsInfo = DependentScopeMemberBits.getNextBit(); - -+ bool HasFirstQualifierFoundInScope = -+ DependentScopeMemberBits.getNextBit(); - S = CXXDependentScopeMemberExpr::CreateEmpty( -- Context, HasQualifier, NumUnqualifiedLookups, HasTemplateInfo, -- NumTemplateArgs); -+ Context, HasTemplateKWAndArgsInfo, NumTemplateArgs, -+ HasFirstQualifierFoundInScope); - break; ++} // namespace targets ++} // namespace clang ++#endif // LLVM_CLANG_LIB_BASIC_TARGETS_LE64_H +diff -ruN --strip-trailing-cr a/clang/lib/Basic/Targets/OSTargets.h b/clang/lib/Basic/Targets/OSTargets.h +--- a/clang/lib/Basic/Targets/OSTargets.h ++++ b/clang/lib/Basic/Targets/OSTargets.h +@@ -841,6 +841,9 @@ + "i64:64-i128:128-n8:16:32:64-S128"); + } else if (Triple.getArch() == llvm::Triple::mipsel) { + // Handled on mips' setDataLayout. ++ } else { ++ assert(Triple.getArch() == llvm::Triple::le32); ++ this->resetDataLayout("e-p:32:32-i64:64"); + } + } + }; +diff -ruN --strip-trailing-cr a/clang/lib/Basic/Targets.cpp b/clang/lib/Basic/Targets.cpp +--- a/clang/lib/Basic/Targets.cpp ++++ b/clang/lib/Basic/Targets.cpp +@@ -23,6 +23,7 @@ + #include "Targets/DirectX.h" + #include "Targets/Hexagon.h" + #include "Targets/Lanai.h" ++#include "Targets/Le64.h" + #include "Targets/LoongArch.h" + #include "Targets/M68k.h" + #include "Targets/MSP430.h" +@@ -343,6 +344,17 @@ + return std::make_unique(Triple, Opts); } -diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp ---- a/clang/lib/Serialization/ASTWriterStmt.cpp -+++ b/clang/lib/Serialization/ASTWriterStmt.cpp -@@ -1988,41 +1988,34 @@ - CXXDependentScopeMemberExpr *E) { - VisitExpr(E); - -- bool HasQualifier = E->hasQualifier(); -- unsigned NumUnqualifiedLookups = E->getNumUnqualifiedLookups(); -- bool HasTemplateInfo = E->hasTemplateKWAndArgsInfo(); -- unsigned NumTemplateArgs = E->getNumTemplateArgs(); -- -- // Write these first for easy access when deserializing, as they affect the -- // size of the CXXDependentScopeMemberExpr. -+ // Don't emit anything here (or if you do you will have to update -+ // the corresponding deserialization function). -+ Record.push_back(E->getNumTemplateArgs()); - CurrentPackingBits.updateBits(); -- CurrentPackingBits.addBit(HasQualifier); -- CurrentPackingBits.addBit(HasTemplateInfo); -- Record.push_back(NumUnqualifiedLookups); -- Record.push_back(NumTemplateArgs); -+ CurrentPackingBits.addBit(E->hasTemplateKWAndArgsInfo()); -+ CurrentPackingBits.addBit(E->hasFirstQualifierFoundInScope()); ++ case llvm::Triple::le32: ++ switch (os) { ++ case llvm::Triple::NaCl: ++ return std::make_unique>(Triple, Opts); ++ default: ++ return nullptr; ++ } + -+ if (E->hasTemplateKWAndArgsInfo()) { -+ const ASTTemplateKWAndArgsInfo &ArgInfo = -+ *E->getTrailingObjects(); -+ AddTemplateKWAndArgsInfo(ArgInfo, -+ E->getTrailingObjects()); -+ } - -- Record.AddTypeRef(E->getBaseType()); - CurrentPackingBits.addBit(E->isArrow()); ++ case llvm::Triple::le64: ++ return std::make_unique(Triple, Opts); + -+ Record.AddTypeRef(E->getBaseType()); -+ Record.AddNestedNameSpecifierLoc(E->getQualifierLoc()); - CurrentPackingBits.addBit(!E->isImplicitAccess()); - if (!E->isImplicitAccess()) - Record.AddStmt(E->getBase()); + case llvm::Triple::ppc: + switch (os) { + case llvm::Triple::Linux: +diff -ruN --strip-trailing-cr a/clang/lib/CodeGen/CodeGenModule.cpp b/clang/lib/CodeGen/CodeGenModule.cpp +--- a/clang/lib/CodeGen/CodeGenModule.cpp ++++ b/clang/lib/CodeGen/CodeGenModule.cpp +@@ -116,6 +116,8 @@ + default: + return createDefaultTargetCodeGenInfo(CGM); - Record.AddSourceLocation(E->getOperatorLoc()); ++ case llvm::Triple::le32: ++ return createPNaClTargetCodeGenInfo(CGM); + case llvm::Triple::m68k: + return createM68kTargetCodeGenInfo(CGM); + case llvm::Triple::mips: +diff -ruN --strip-trailing-cr a/clang/lib/CodeGen/ItaniumCXXABI.cpp b/clang/lib/CodeGen/ItaniumCXXABI.cpp +--- a/clang/lib/CodeGen/ItaniumCXXABI.cpp ++++ b/clang/lib/CodeGen/ItaniumCXXABI.cpp +@@ -576,6 +576,13 @@ + return new XLCXXABI(CGM); -- Record.AddDeclarationNameInfo(E->MemberNameInfo); -- -- if (HasQualifier) -- Record.AddNestedNameSpecifierLoc(E->getQualifierLoc()); -- -- for (DeclAccessPair D : E->unqualified_lookups()) { -- Record.AddDeclRef(D.getDecl()); -- Record.push_back(D.getAccess()); -- } -- -- if (HasTemplateInfo) -- AddTemplateKWAndArgsInfo(*E->getTrailingObjects(), -- E->getTrailingObjects()); -+ if (E->hasFirstQualifierFoundInScope()) -+ Record.AddDeclRef(E->getFirstQualifierFoundInScope()); + case TargetCXXABI::GenericItanium: ++ if (CGM.getContext().getTargetInfo().getTriple().getArch() ++ == llvm::Triple::le32) { ++ // For PNaCl, use ARM-style method pointers so that PNaCl code ++ // does not assume anything about the alignment of function ++ // pointers. ++ return new ItaniumCXXABI(CGM, /*UseARMMethodPtrABI=*/true); ++ } + return new ItaniumCXXABI(CGM); -+ Record.AddDeclarationNameInfo(E->MemberNameInfo); - Code = serialization::EXPR_CXX_DEPENDENT_SCOPE_MEMBER; + case TargetCXXABI::Microsoft: +diff -ruN --strip-trailing-cr a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp +--- a/clang/lib/Driver/ToolChains/Clang.cpp ++++ b/clang/lib/Driver/ToolChains/Clang.cpp +@@ -3815,6 +3815,12 @@ + if (UseBuiltins) + A->render(Args, CmdArgs); + } ++ ++ // le32-specific flags: ++ // -fno-math-builtin: clang should not convert math builtins to intrinsics ++ // by default. ++ if (TC.getArch() == llvm::Triple::le32) ++ CmdArgs.push_back("-fno-math-builtin"); } + bool Driver::getDefaultModuleCachePath(SmallVectorImpl &Result) { diff -ruN --strip-trailing-cr a/clang/test/CodeGen/bitfield-access-pad.c b/clang/test/CodeGen/bitfield-access-pad.c --- a/clang/test/CodeGen/bitfield-access-pad.c +++ b/clang/test/CodeGen/bitfield-access-pad.c @@ -2862,540 +269,6 @@ diff -ruN --strip-trailing-cr a/clang/test/CodeGenCXX/bitfield-access-tail.cpp b // RUN: %clang_cc1 -triple=loongarch32-elf %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT,LAYOUT32 %s // RUN: %clang_cc1 -triple=nvptx-elf %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT,LAYOUT32 %s // RUN: %clang_cc1 -triple=riscv32 %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT,LAYOUT32 %s -diff -ruN --strip-trailing-cr a/clang/test/CXX/basic/basic.lookup/basic.lookup.classref/p1.cpp b/clang/test/CXX/basic/basic.lookup/basic.lookup.classref/p1.cpp ---- a/clang/test/CXX/basic/basic.lookup/basic.lookup.classref/p1.cpp -+++ b/clang/test/CXX/basic/basic.lookup/basic.lookup.classref/p1.cpp -@@ -86,19 +86,15 @@ - - template T *end(T*); - -- struct X { }; -- struct Y { -- int end; -- }; -+ class X { }; - template - void Foo2() { - T it1; -- if (it1->end < it1->end) { } -+ if (it1->end < it1->end) { -+ } - - X *x; -- if (x->end < 7) { } // expected-error{{no member named 'end' in 'PR11856::X'}} -- -- Y *y; -- if (y->end < 7) { } -+ if (x->end < 7) { // expected-error{{no member named 'end' in 'PR11856::X'}} -+ } - } - } -diff -ruN --strip-trailing-cr a/clang/test/CXX/basic/basic.lookup/basic.lookup.classref/p1-cxx11.cpp b/clang/test/CXX/basic/basic.lookup/basic.lookup.classref/p1-cxx11.cpp ---- a/clang/test/CXX/basic/basic.lookup/basic.lookup.classref/p1-cxx11.cpp -+++ b/clang/test/CXX/basic/basic.lookup/basic.lookup.classref/p1-cxx11.cpp -@@ -55,19 +55,15 @@ - - template T *end(T*); - -- struct X { }; -- struct Y { -- int end; -- }; -+ class X { }; - template - void Foo2() { - T it1; -- if (it1->end < it1->end) { } -+ if (it1->end < it1->end) { -+ } - - X *x; -- if (x->end < 7) { } // expected-error{{no member named 'end' in 'PR11856::X'}} -- -- Y *y; -- if (y->end < 7) { } -+ if (x->end < 7) { // expected-error{{no member named 'end' in 'PR11856::X'}} -+ } - } - } -diff -ruN --strip-trailing-cr a/clang/test/CXX/basic/basic.lookup/basic.lookup.qual/basic.lookup.qual.general/p3.cpp b/clang/test/CXX/basic/basic.lookup/basic.lookup.qual/basic.lookup.qual.general/p3.cpp ---- a/clang/test/CXX/basic/basic.lookup/basic.lookup.qual/basic.lookup.qual.general/p3.cpp -+++ b/clang/test/CXX/basic/basic.lookup/basic.lookup.qual/basic.lookup.qual.general/p3.cpp -@@ -1,98 +0,0 @@ --// RUN: %clang_cc1 -std=c++23 -Wno-unused %s -verify -- --namespace Unambiguous { -- struct A { -- int x; -- -- template -- using C = A; -- }; -- -- using B = A; -- -- template -- using D = A; -- -- using E = void; -- -- struct F : A { -- void non_template() { -- this->x; -- this->A::x; -- this->B::x; -- this->C::x; -- this->D::x; -- this->E::x; // expected-error {{'Unambiguous::E' (aka 'void') is not a class, namespace, or enumeration}} -- } -- }; -- -- template -- void not_instantiated(T t) { -- t.x; -- t.A::x; -- t.B::x; -- t.C::x; // expected-warning {{use 'template' keyword to treat 'C' as a dependent template name}} -- t.template C::x; -- t.D::x; // expected-warning {{use 'template' keyword to treat 'D' as a dependent template name}} -- t.template D::x; -- t.E::x; -- } -- -- template -- void instantiated_valid(T t) { -- t.x; -- t.A::x; -- t.B::x; -- t.template C::x; -- t.template D::x; -- t.E::x; -- } -- -- template -- void instantiated_invalid(T t) { -- t.x; -- t.A::x; -- t.B::x; // expected-error {{'Unambiguous::Invalid::B' (aka 'void') is not a class, namespace, or enumeration}} -- t.template C::x; -- t.template D::x; // expected-error {{'D' following the 'template' keyword does not refer to a template}} -- t.E::x; // expected-error {{'Unambiguous::E' (aka 'void') is not a class, namespace, or enumeration}} -- } -- -- struct Valid : A { -- using E = A; -- }; -- -- template void instantiated_valid(Valid); -- -- struct Invalid : A { -- using B = void; -- using D = A; // expected-note {{declared as a non-template here}} -- }; -- -- template void instantiated_invalid(Invalid); // expected-note {{in instantiation of}} --} // namespace Unambiguous -- --namespace Ambiguous { -- inline namespace N { -- struct A { }; // expected-note {{candidate found by name lookup is 'Ambiguous::N::A'}} -- } -- -- struct A { }; // expected-note {{candidate found by name lookup is 'Ambiguous::A'}} -- -- template -- void f(T t) { -- t.A::x; // expected-error {{reference to 'A' is ambiguous}} -- } -- -- struct B { -- using A = B; -- -- int x; -- }; -- -- struct C { }; -- -- template void f(B); -- template void f(C); // expected-note {{in instantiation of}} -- --} // namespace Ambiguous -diff -ruN --strip-trailing-cr a/clang/test/CXX/basic/basic.lookup/basic.lookup.qual/basic.lookup.qual.general/p3-example3.cpp b/clang/test/CXX/basic/basic.lookup/basic.lookup.qual/basic.lookup.qual.general/p3-example3.cpp ---- a/clang/test/CXX/basic/basic.lookup/basic.lookup.qual/basic.lookup.qual.general/p3-example3.cpp -+++ b/clang/test/CXX/basic/basic.lookup/basic.lookup.qual/basic.lookup.qual.general/p3-example3.cpp -@@ -1,27 +0,0 @@ --// RUN: %clang_cc1 -std=c++23 %s -verify -- --int f(); -- --struct A { -- int B, C; // expected-note {{declared as a non-template here}} -- template using D = void; -- using T = void; -- void f(); --}; -- --using B = A; --template using C = A; --template using D = A; --template using X = A; -- --template --void g(T *p) { -- p->X<0>::f(); // expected-error {{no member named 'X' in 'A'}} -- p->template X<0>::f(); -- p->B::f(); -- p->template C<0>::f(); // expected-error {{'C' following the 'template' keyword does not refer to a template}} -- p->template D<0>::f(); // expected-error {{type 'template D<0>' (aka 'void') cannot be used prior to '::' because it has no members}} -- p->T::f(); // expected-error {{'A::T' (aka 'void') is not a class, namespace, or enumeration}} --} -- --template void g(A*); // expected-note {{in instantiation of}} -diff -ruN --strip-trailing-cr a/clang/test/CXX/class.derived/class.member.lookup/p8.cpp b/clang/test/CXX/class.derived/class.member.lookup/p8.cpp ---- a/clang/test/CXX/class.derived/class.member.lookup/p8.cpp -+++ b/clang/test/CXX/class.derived/class.member.lookup/p8.cpp -@@ -47,8 +47,8 @@ - void DerivedT::Inner() { - Derived1T::Foo(); - Derived2T::Member = 42; -- this->Derived1T::Foo(); // expected-warning{{use 'template' keyword to treat 'Derived1T' as a dependent template name}} -- this->Derived2T::Member = 42; // expected-warning{{use 'template' keyword to treat 'Derived2T' as a dependent template name}} -+ this->Derived1T::Foo(); -+ this->Derived2T::Member = 42; - this->Foo(); // expected-error{{non-static member 'Foo' found in multiple base-class subobjects of type 'BaseT'}} - } - -diff -ruN --strip-trailing-cr a/clang/test/CXX/drs/cwg1xx.cpp b/clang/test/CXX/drs/cwg1xx.cpp ---- a/clang/test/CXX/drs/cwg1xx.cpp -+++ b/clang/test/CXX/drs/cwg1xx.cpp -@@ -615,8 +615,10 @@ - // cxx98-note@#cwg141-S {{lookup from the current scope refers here}} - // expected-error@#cwg141-a {{no member named 'n' in 'cwg141::A::S'; did you mean '::cwg141::S::n'?}} - // expected-note@#cwg141-S {{'::cwg141::S::n' declared here}} -+ // FIXME: we issue a useful diagnostic first, then some bogus ones. - b.f(); - // expected-error@-1 {{no member named 'f' in 'cwg141::B'}} -+ // expected-error@-2 +{{}} - (void)b.S::n; - } - template struct C { -@@ -626,12 +628,10 @@ - // expected-error@-1 {{use 'template' keyword to treat 'f' as a dependent template name}} - } - void h() { -- (void)t.S::n; -- // expected-error@-1 {{use 'template' keyword to treat 'S' as a dependent template name}} -+ (void)t.S::n; // ok - } - void i() { -- (void)t.S(); -- // expected-error@-1 {{use 'template' keyword to treat 'S' as a dependent template name}} -+ (void)t.S(); // ok! - } - }; - void h() { C().h(); } // ok -diff -ruN --strip-trailing-cr a/clang/test/CXX/temp/temp.names/p3-23.cpp b/clang/test/CXX/temp/temp.names/p3-23.cpp ---- a/clang/test/CXX/temp/temp.names/p3-23.cpp -+++ b/clang/test/CXX/temp/temp.names/p3-23.cpp -@@ -1,237 +0,0 @@ --// RUN: %clang_cc1 -std=c++23 -Wno-unused %s -verify -- --namespace FoundNothing { -- template -- void f0(T &t) { -- t.x<0; -- t.x<0>; // expected-error {{expected expression}} -- t.x<0>1; -- } -- -- template -- struct A { -- void f1() { -- this->x<0; // expected-error {{no member named 'x' in 'A'}} -- this->x<0>; // expected-error {{no member named 'x' in 'A'}} -- // expected-error@-1 {{expected expression}} -- this->x<0>1; // expected-error {{no member named 'x' in 'A'}} -- } -- }; --} // namespace FoundNothing -- --namespace FoundSingleNonTemplate { -- void f0(); -- -- struct A0; -- -- template -- void g0(T &t) { -- t.f0<0; -- t.f0<0>; // expected-error {{expected expression}} -- t.f0<0>1; -- -- t.A0<0; -- t.A0<0>; // expected-error {{expected expression}} -- t.A0<0>1; -- } -- -- template -- struct B { -- void f1(); -- -- struct A1; // expected-note 3{{member 'A1' declared here}} -- -- void g1() { -- this->f0<0; // expected-error {{no member named 'f0' in 'B'}} -- this->f0<0>; // expected-error {{no member named 'f0' in 'B'}} -- // expected-error@-1 {{expected expression}} -- this->f0<0>1; // expected-error {{no member named 'f0' in 'B'}} -- -- this->A0<0; // expected-error {{no member named 'A0' in 'B'}} -- this->A0<0>; // expected-error {{no member named 'A0' in 'B'}} -- // expected-error@-1 {{expected expression}} -- this->A0<0>1; // expected-error {{no member named 'A0' in 'B'}} -- -- this->f1<0; // expected-error {{reference to non-static member function must be called}} -- this->f1<0>; // expected-error {{reference to non-static member function must be called}} -- // expected-error@-1 {{expected expression}} -- this->f1<0>1; // expected-error {{reference to non-static member function must be called}} -- -- this->A1<0; // expected-error {{cannot refer to type member 'A1' in 'B' with '->'}} -- this->A1<0>; // expected-error {{cannot refer to type member 'A1' in 'B' with '->'}} -- // expected-error@-1 {{expected expression}} -- this->A1<0>1; // expected-error {{cannot refer to type member 'A1' in 'B' with '->'}} -- } -- }; --} // namespace FoundSingleNonTemplate -- --namespace FoundSingleTemplate { -- template -- void f0(); -- -- template -- struct A0; -- -- template -- void g0(T &t) { -- t.f0<0; -- t.f0<0>; // expected-error {{expected expression}} -- t.f0<0>1; -- -- t.A0<0; -- t.A0<0>; // expected-error {{expected expression}} -- t.A0<0>1; -- } -- -- template -- struct B { -- template -- void f1(); // expected-note 2{{possible target for call}} -- -- template -- struct A1; // expected-note 2{{member 'A1' declared here}} -- -- void g1() { -- this->f0<0; // expected-error {{no member named 'f0' in 'B'}} -- this->f0<0>; // expected-error {{no member named 'f0' in 'B'}} -- this->f0<0>1; // expected-error {{no member named 'f0' in 'B'}} -- // expected-error@-1 {{expected ';' after expression}} -- -- this->A0<0; // expected-error {{no member named 'A0' in 'B'}} -- this->A0<0>; // expected-error {{no member named 'A0' in 'B'}} -- this->A0<0>1; // expected-error {{no member named 'A0' in 'B'}} -- // expected-error@-1 {{expected ';' after expression}} -- -- -- this->f1<0; // expected-error {{expected '>'}} -- // expected-note@-1 {{to match this '<'}} -- this->f1<0>; // expected-error {{reference to non-static member function must be called}} -- this->f1<0>1; // expected-error {{reference to non-static member function must be called}} -- // expected-error@-1 {{expected ';' after expression}} -- -- this->A1<0; // expected-error {{expected '>'}} -- // expected-note@-1 {{to match this '<'}} -- this->A1<0>; // expected-error {{cannot refer to member 'A1' in 'B' with '->'}} -- this->A1<0>1; // expected-error {{cannot refer to member 'A1' in 'B' with '->'}} -- // expected-error@-1 {{expected ';' after expression}} -- } -- }; --} // namespace FoundSingleTemplate -- --namespace FoundAmbiguousNonTemplate { -- inline namespace N { -- int f0; -- -- struct A0; -- } // namespace N -- -- void f0(); -- -- struct A0; -- -- template -- void g0(T &t) { -- t.f0<0; -- t.f0<0>; // expected-error {{expected expression}} -- t.f0<0>1; -- -- t.A0<0; -- t.A0<0>; // expected-error {{expected expression}} -- t.A0<0>1; -- } -- -- template -- struct B { -- void f1(); -- -- struct A1; // expected-note 3{{member 'A1' declared here}} -- -- void g1() { -- this->f0<0; // expected-error {{no member named 'f0' in 'B'}} -- this->f0<0>; // expected-error {{no member named 'f0' in 'B'}} -- // expected-error@-1 {{expected expression}} -- this->f0<0>1; // expected-error {{no member named 'f0' in 'B'}} -- -- this->A0<0; // expected-error {{no member named 'A0' in 'B'}} -- this->A0<0>; // expected-error {{no member named 'A0' in 'B'}} -- // expected-error@-1 {{expected expression}} -- this->A0<0>1; // expected-error {{no member named 'A0' in 'B'}} -- -- this->f1<0; // expected-error {{reference to non-static member function must be called}} -- this->f1<0>; // expected-error {{reference to non-static member function must be called}} -- // expected-error@-1 {{expected expression}} -- this->f1<0>1; // expected-error {{reference to non-static member function must be called}} -- -- this->A1<0; // expected-error {{cannot refer to type member 'A1' in 'B' with '->'}} -- this->A1<0>; // expected-error {{cannot refer to type member 'A1' in 'B' with '->'}} -- // expected-error@-1 {{expected expression}} -- this->A1<0>1; // expected-error {{cannot refer to type member 'A1' in 'B' with '->'}} -- } -- }; --} // namespace FoundAmbiguousNonTemplates -- --namespace FoundAmbiguousTemplate { -- inline namespace N { -- template -- int f0; // expected-note 3{{candidate found by name lookup is 'FoundAmbiguousTemplate::N::f0'}} -- -- template -- struct A0; // expected-note 3{{candidate found by name lookup is 'FoundAmbiguousTemplate::N::A0'}} -- } // namespace N -- -- template -- void f0(); // expected-note 3{{candidate found by name lookup is 'FoundAmbiguousTemplate::f0'}} -- -- template -- struct A0; // expected-note 3{{candidate found by name lookup is 'FoundAmbiguousTemplate::A0'}} -- -- template -- void g0(T &t) { -- t.f0<0; -- t.f0<0>; // expected-error {{expected expression}} -- t.f0<0>1; -- -- t.A0<0; -- t.A0<0>; // expected-error {{expected expression}} -- t.A0<0>1; -- } -- -- template -- struct B { -- template -- void f1(); // expected-note 2{{possible target for call}} -- -- template -- struct A1; // expected-note 2{{member 'A1' declared here}} -- -- void g1() { -- this->f0<0; // expected-error {{no member named 'f0' in 'B'}} -- // expected-error@-1 {{reference to 'f0' is ambiguous}} -- this->f0<0>; // expected-error {{no member named 'f0' in 'B'}} -- // expected-error@-1 {{reference to 'f0' is ambiguous}} -- this->f0<0>1; // expected-error {{no member named 'f0' in 'B'}} -- // expected-error@-1 {{expected ';' after expression}} -- // expected-error@-2 {{reference to 'f0' is ambiguous}} -- -- this->A0<0; // expected-error {{no member named 'A0' in 'B'}} -- // expected-error@-1 {{reference to 'A0' is ambiguous}} -- this->A0<0>; // expected-error {{no member named 'A0' in 'B'}} -- // expected-error@-1 {{reference to 'A0' is ambiguous}} -- this->A0<0>1; // expected-error {{no member named 'A0' in 'B'}} -- // expected-error@-1 {{expected ';' after expression}} -- // expected-error@-2 {{reference to 'A0' is ambiguous}} -- -- this->f1<0; // expected-error {{expected '>'}} -- // expected-note@-1 {{to match this '<'}} -- this->f1<0>; // expected-error {{reference to non-static member function must be called}} -- this->f1<0>1; // expected-error {{reference to non-static member function must be called}} -- // expected-error@-1 {{expected ';' after expression}} -- -- this->A1<0; // expected-error {{expected '>'}} -- // expected-note@-1 {{to match this '<'}} -- this->A1<0>; // expected-error {{cannot refer to member 'A1' in 'B' with '->'}} -- this->A1<0>1; // expected-error {{cannot refer to member 'A1' in 'B' with '->'}} -- // expected-error@-1 {{expected ';' after expression}} -- } -- }; --} // namespace FoundAmbiguousTemplate -diff -ruN --strip-trailing-cr a/clang/test/CXX/temp/temp.res/p3.cpp b/clang/test/CXX/temp/temp.res/p3.cpp ---- a/clang/test/CXX/temp/temp.res/p3.cpp -+++ b/clang/test/CXX/temp/temp.res/p3.cpp -@@ -30,6 +30,6 @@ - template template struct A::B { - friend A::C f6(); // ok, same as 'friend T f6();' - -- friend A::C f7(); // expected-warning {{use 'template' keyword to treat 'C' as a dependent template name}} expected-warning {{missing 'typename'}} -+ friend A::C f7(); // expected-error {{use 'template' keyword to treat 'C' as a dependent template name}} expected-warning {{missing 'typename'}} - friend A::template C f8(); // expected-warning {{missing 'typename'}} - }; -diff -ruN --strip-trailing-cr a/clang/test/FixIt/fixit.cpp b/clang/test/FixIt/fixit.cpp ---- a/clang/test/FixIt/fixit.cpp -+++ b/clang/test/FixIt/fixit.cpp -@@ -158,12 +158,12 @@ - - template - class F2 { -- typename F1:: /*template*/ Iterator<0> Mypos; // expected-warning {{use 'template' keyword to treat 'Iterator' as a dependent template name}} -+ typename F1:: /*template*/ Iterator<0> Mypos; // expected-error {{use 'template' keyword to treat 'Iterator' as a dependent template name}} - }; - - template - void f(){ -- typename F1:: /*template*/ Iterator<0> Mypos; // expected-warning {{use 'template' keyword to treat 'Iterator' as a dependent template name}} -+ typename F1:: /*template*/ Iterator<0> Mypos; // expected-error {{use 'template' keyword to treat 'Iterator' as a dependent template name}} - } - - // Tests for &/* fixits -diff -ruN --strip-trailing-cr a/clang/test/Misc/warning-flags.c b/clang/test/Misc/warning-flags.c ---- a/clang/test/Misc/warning-flags.c -+++ b/clang/test/Misc/warning-flags.c -@@ -18,7 +18,7 @@ - - The list of warnings below should NEVER grow. It should gradually shrink to 0. - --CHECK: Warnings without flags (64): -+CHECK: Warnings without flags (65): - - CHECK-NEXT: ext_expected_semi_decl_list - CHECK-NEXT: ext_missing_whitespace_after_macro_name -@@ -61,6 +61,7 @@ - CHECK-NEXT: warn_maynot_respond - CHECK-NEXT: warn_method_param_redefinition - CHECK-NEXT: warn_missing_case_for_condition -+CHECK-NEXT: warn_missing_dependent_template_keyword - CHECK-NEXT: warn_missing_whitespace_after_macro_name - CHECK-NEXT: warn_mt_message - CHECK-NEXT: warn_no_constructor_for_refconst -diff -ruN --strip-trailing-cr a/clang/test/Parser/cxx2a-concepts-requires-expr.cpp b/clang/test/Parser/cxx2a-concepts-requires-expr.cpp ---- a/clang/test/Parser/cxx2a-concepts-requires-expr.cpp -+++ b/clang/test/Parser/cxx2a-concepts-requires-expr.cpp -@@ -78,7 +78,7 @@ - - template - bool r23 = requires { typename identity::temp; }; --// expected-warning@-1 {{use 'template' keyword to treat 'temp' as a dependent template name}} -+// expected-error@-1 {{use 'template' keyword to treat 'temp' as a dependent template name}} - - template - bool r24 = requires { diff -ruN --strip-trailing-cr a/clang/test/Preprocessor/predefined-macros-no-warnings.c b/clang/test/Preprocessor/predefined-macros-no-warnings.c --- a/clang/test/Preprocessor/predefined-macros-no-warnings.c +++ b/clang/test/Preprocessor/predefined-macros-no-warnings.c @@ -3408,288 +281,67 @@ diff -ruN --strip-trailing-cr a/clang/test/Preprocessor/predefined-macros-no-war // RUN: %clang_cc1 %s -Eonly -Wsystem-headers -Werror -triple ppc // RUN: %clang_cc1 %s -Eonly -Wsystem-headers -Werror -triple ppc-freebsd // RUN: %clang_cc1 %s -Eonly -Wsystem-headers -Werror -triple ppc-netbsd -diff -ruN --strip-trailing-cr a/clang/test/SemaCXX/cxx0x-noexcept-expression.cpp b/clang/test/SemaCXX/cxx0x-noexcept-expression.cpp ---- a/clang/test/SemaCXX/cxx0x-noexcept-expression.cpp -+++ b/clang/test/SemaCXX/cxx0x-noexcept-expression.cpp -@@ -127,7 +127,7 @@ - // `dependent` should be type-dependent because the noexcept-expression should be value-dependent - // (it is true if T is int*, false if T is Polymorphic* for example) - dependent.f(); // This should need to be `.template f` to parse as a template -- // expected-warning@-1 {{use 'template' keyword to treat 'f' as a dependent template name}} -+ // expected-error@-1 {{use 'template' keyword to treat 'f' as a dependent template name}} - } - template - void f2() { -@@ -135,14 +135,14 @@ - // X when T...[0] is a type with some operator&& which returns int* - // X when sizeof...(T) == 0 - dependent.f(); -- // expected-warning@-1 {{use 'template' keyword to treat 'f' as a dependent template name}} -+ // expected-error@-1 {{use 'template' keyword to treat 'f' as a dependent template name}} - } - template - void f3() { - X(nullptr)))> dependent; - // X when T is int, X when T is Polymorphic - dependent.f(); -- // expected-warning@-1 {{use 'template' keyword to treat 'f' as a dependent template name}} -+ // expected-error@-1 {{use 'template' keyword to treat 'f' as a dependent template name}} +diff -ruN --strip-trailing-cr a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h +--- a/llvm/include/llvm/IR/PatternMatch.h ++++ b/llvm/include/llvm/IR/PatternMatch.h +@@ -1550,27 +1550,23 @@ + template + struct CmpClass_match { +- PredicateTy *Predicate; ++ PredicateTy &Predicate; + LHS_t L; + RHS_t R; + + // The evaluation order is always stable, regardless of Commutability. + // The LHS is always matched first. + CmpClass_match(PredicateTy &Pred, const LHS_t &LHS, const RHS_t &RHS) +- : Predicate(&Pred), L(LHS), R(RHS) {} +- CmpClass_match(const LHS_t &LHS, const RHS_t &RHS) +- : Predicate(nullptr), L(LHS), R(RHS) {} ++ : Predicate(Pred), L(LHS), R(RHS) {} + + template bool match(OpTy *V) { + if (auto *I = dyn_cast(V)) { + if (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) { +- if (Predicate) +- *Predicate = I->getPredicate(); ++ Predicate = I->getPredicate(); + return true; + } else if (Commutable && L.match(I->getOperand(1)) && + R.match(I->getOperand(0))) { +- if (Predicate) +- *Predicate = I->getSwappedPredicate(); ++ Predicate = I->getSwappedPredicate(); + return true; + } + } +@@ -1599,19 +1595,22 @@ + template + inline CmpClass_match + m_Cmp(const LHS &L, const RHS &R) { +- return CmpClass_match(L, R); ++ CmpInst::Predicate Unused; ++ return CmpClass_match(Unused, L, R); } - template - void f4() { -diff -ruN --strip-trailing-cr a/clang/test/SemaCXX/pseudo-destructors.cpp b/clang/test/SemaCXX/pseudo-destructors.cpp ---- a/clang/test/SemaCXX/pseudo-destructors.cpp -+++ b/clang/test/SemaCXX/pseudo-destructors.cpp -@@ -22,21 +22,21 @@ - void f(A* a, Foo *f, int *i, double *d, int ii) { - a->~A(); - a->A::~A(); -- -+ - a->~foo(); // expected-error{{undeclared identifier 'foo' in destructor name}} -- -+ - a->~Bar(); // expected-error{{destructor type 'Bar' (aka 'Foo') in object destruction expression does not match the type 'A' of the object being destroyed}} -- -+ - f->~Bar(); - f->~Foo(); - i->~Bar(); // expected-error{{does not match}} -- -+ - g().~Bar(); // expected-error{{non-scalar}} -- -+ - f->::~Bar(); // expected-error {{not a structure or union}} - f->::Bar::~Bar(); - f->N::~Wibble(); // expected-error{{'N' does not refer to a type}} expected-error{{'Wibble' does not refer to a type}} -- -+ - f->Bar::~Bar(17, 42); // expected-error{{cannot have any arguments}} - - i->~Integer(); -@@ -148,12 +148,12 @@ - namespace Template { - template struct Y {}; - template using G = Y; -- template void f(T *p) { p->~G(); } // expected-error {{no member named 'G'}} -+ template void f(T *p) { p->~G(); } // expected-error {{no member named '~Y'}} - void h1(Y *p) { p->~G(); } -- void h2(Y *p) { f(p); } // expected-note {{instantiation of}} -+ void h2(Y *p) { f(p); } - namespace N { template struct G {}; } - void h3(N::G *p) { p->~G(); } -- void h4(N::G *p) { f(p); } -+ void h4(N::G *p) { f(p); } // expected-note {{instantiation of}} - } - - namespace TemplateUndeclared { -diff -ruN --strip-trailing-cr a/clang/test/SemaCXX/static-assert-cxx17.cpp b/clang/test/SemaCXX/static-assert-cxx17.cpp ---- a/clang/test/SemaCXX/static-assert-cxx17.cpp -+++ b/clang/test/SemaCXX/static-assert-cxx17.cpp -@@ -96,7 +96,7 @@ - // expected-error@-1{{static assertion failed due to requirement 'static_cast *>(nullptr)'}} - static_assert((const X[]){} == nullptr); - // expected-error@-1{{static assertion failed due to requirement '(const X[0]){} == nullptr'}} -- static_assert(sizeof(X().template X::~X())>) == 0); -+ static_assert(sizeof(X().X::~X())>) == 0); - // expected-error@-1{{static assertion failed due to requirement 'sizeof(X) == 0'}} \ - // expected-note@-1 {{evaluates to '8 == 0'}} - static_assert(constexpr_return_false()); -diff -ruN --strip-trailing-cr a/clang/test/SemaTemplate/dependent-base-classes.cpp b/clang/test/SemaTemplate/dependent-base-classes.cpp ---- a/clang/test/SemaTemplate/dependent-base-classes.cpp -+++ b/clang/test/SemaTemplate/dependent-base-classes.cpp -@@ -1,12 +1,12 @@ - // RUN: %clang_cc1 -fsyntax-only -verify %s - - template --struct X0 : T::template apply { -+struct X0 : T::template apply { - X0(U u) : T::template apply(u) { } - }; - template --struct X1 : T::apply { }; // expected-warning{{use 'template' keyword to treat 'apply' as a dependent template name}} -+struct X1 : T::apply { }; // expected-error{{use 'template' keyword to treat 'apply' as a dependent template name}} - - template - struct X2 : vector { }; // expected-error{{no template named 'vector'}} -@@ -85,7 +85,7 @@ - struct A { }; - - template -- class B : public A -+ class B : public A - { - public: - template< class X > -@@ -109,9 +109,9 @@ - - namespace PR6413 { - template class Base_A { }; -- -+ - class Base_B { }; -- -+ - template - class Derived - : public virtual Base_A -@@ -120,12 +120,12 @@ + template + inline CmpClass_match + m_ICmp(const LHS &L, const RHS &R) { +- return CmpClass_match(L, R); ++ ICmpInst::Predicate Unused; ++ return CmpClass_match(Unused, L, R); } - namespace PR5812 { -- template struct Base { -- Base* p; -- }; -+ template struct Base { -+ Base* p; -+ }; - -- template struct Derived: public Base { -- typename Derived::Base* p; // meaning Derived::Base -+ template struct Derived: public Base { -+ typename Derived::Base* p; // meaning Derived::Base - }; - - Derived di; -diff -ruN --strip-trailing-cr a/clang/test/SemaTemplate/dependent-template-recover.cpp b/clang/test/SemaTemplate/dependent-template-recover.cpp ---- a/clang/test/SemaTemplate/dependent-template-recover.cpp -+++ b/clang/test/SemaTemplate/dependent-template-recover.cpp -@@ -2,15 +2,15 @@ - template - struct X { - void f(T* t) { -- t->f0(); // expected-warning{{use 'template' keyword to treat 'f0' as a dependent template name}} -- t->f0(); // expected-warning{{use 'template' keyword to treat 'f0' as a dependent template name}} -+ t->f0(); // expected-error{{use 'template' keyword to treat 'f0' as a dependent template name}} -+ t->f0(); // expected-error{{use 'template' keyword to treat 'f0' as a dependent template name}} - -- t->operator+(1); // expected-warning{{use 'template' keyword to treat 'operator +' as a dependent template name}} -- t->f1(1); // expected-warning{{use 'template' keyword to treat 'f1' as a dependent template name}} -+ t->operator+(1); // expected-error{{use 'template' keyword to treat 'operator +' as a dependent template name}} -+ t->f1(1); // expected-error{{use 'template' keyword to treat 'f1' as a dependent template name}} - t->f1<3, int const>(1); // expected-error{{missing 'template' keyword prior to dependent template name 'f1'}} - -- T::getAs(); // expected-warning{{use 'template' keyword to treat 'getAs' as a dependent template name}} -- t->T::getAs(); // expected-warning{{use 'template' keyword to treat 'getAs' as a dependent template name}} -+ T::getAs(); // expected-error{{use 'template' keyword to treat 'getAs' as a dependent template name}} -+ t->T::getAs(); // expected-error{{use 'template' keyword to treat 'getAs' as a dependent template name}} - - (*t).f2(); // expected-error{{missing 'template' keyword prior to dependent template name 'f2'}} - (*t).f2<0>(); // expected-error{{missing 'template' keyword prior to dependent template name 'f2'}} -diff -ruN --strip-trailing-cr a/clang/test/SemaTemplate/temp_arg_nontype_cxx20.cpp b/clang/test/SemaTemplate/temp_arg_nontype_cxx20.cpp ---- a/clang/test/SemaTemplate/temp_arg_nontype_cxx20.cpp -+++ b/clang/test/SemaTemplate/temp_arg_nontype_cxx20.cpp -@@ -115,7 +115,7 @@ - static_assert(f(X()) == 0); - - template struct Y { void f(); }; -- template void g(Y y) { y.template Y::f(); } -+ template void g(Y y) { y.Y::f(); } - void h() { constexpr A a; g(Y{}); } - - template struct Z { -diff -ruN --strip-trailing-cr a/clang/test/SemaTemplate/template-id-expr.cpp b/clang/test/SemaTemplate/template-id-expr.cpp ---- a/clang/test/SemaTemplate/template-id-expr.cpp -+++ b/clang/test/SemaTemplate/template-id-expr.cpp -@@ -19,7 +19,7 @@ - struct X0 { - template - void f1(); -- -+ - template - void f2(U) { - f1(); -@@ -39,9 +39,9 @@ - template - struct X { - X(int, int); -- void f() { -- Y >(X(0, 0)); -- Y >(::X(0, 0)); -+ void f() { -+ Y >(X(0, 0)); -+ Y >(::X(0, 0)); - } - }; - -@@ -149,11 +149,11 @@ - - int x; - x = Y1::f4(0); -- x = Y1::f4(0); // expected-warning {{use 'template'}} expected-error {{assigning to 'int' from incompatible type 'void'}} -+ x = Y1::f4(0); // expected-error {{use 'template'}} expected-error {{assigning to 'int' from incompatible type 'void'}} - x = Y1::template f4(0); // expected-error {{assigning to 'int' from incompatible type 'void'}} expected-error {{a template argument list is expected after a name prefixed by the template keyword}} - - x = p->f4(0); -- x = p->f4(0); // expected-error {{assigning to 'int' from incompatible type 'void'}} expected-warning {{use 'template'}} -+ x = p->f4(0); // expected-error {{assigning to 'int' from incompatible type 'void'}} expected-error {{use 'template'}} - x = p->template f4(0); // expected-error {{assigning to 'int' from incompatible type 'void'}} expected-error {{a template argument list is expected after a name prefixed by the template keyword}} - } - }; -@@ -184,7 +184,7 @@ - #if __cplusplus <= 199711L - // expected-warning@+2 {{extension}} - #endif --template using D = int; // expected-note {{declared here}} -+template using D = int; // expected-note {{declared here}} - E ed; // expected-note {{instantiation of}} - - namespace non_functions { -diff -ruN --strip-trailing-cr a/clang/test/SemaTemplate/typename-specifier-3.cpp b/clang/test/SemaTemplate/typename-specifier-3.cpp ---- a/clang/test/SemaTemplate/typename-specifier-3.cpp -+++ b/clang/test/SemaTemplate/typename-specifier-3.cpp -@@ -46,7 +46,7 @@ - typedef int arg; - }; - struct C { -- typedef typename B::X x; // expected-warning {{use 'template'}} expected-error {{refers to non-type}} -+ typedef typename B::X x; // expected-error {{use 'template'}} expected-error {{refers to non-type}} - }; - }; - -diff -ruN --strip-trailing-cr a/libc/src/__support/macros/config.h b/libc/src/__support/macros/config.h ---- a/libc/src/__support/macros/config.h -+++ b/libc/src/__support/macros/config.h -@@ -15,7 +15,6 @@ - - // Workaround for compilers that do not support builtin detection. - // FIXME: This is only required for the GPU portion which should be moved. --#include "src/__support/macros/config.h" - #ifndef __has_builtin - #define __has_builtin(b) 0 - #endif -diff -ruN --strip-trailing-cr a/libcxx/include/regex b/libcxx/include/regex ---- a/libcxx/include/regex -+++ b/libcxx/include/regex -@@ -4214,7 +4214,7 @@ - _LIBCPP_HIDE_FROM_ABI int compare(const value_type* __s) const { return str().compare(__s); } - - _LIBCPP_HIDE_FROM_ABI void swap(sub_match& __s) _NOEXCEPT_(__is_nothrow_swappable_v<_BidirectionalIterator>) { -- this->template pair<_BidirectionalIterator, _BidirectionalIterator>::swap(__s); -+ this->pair<_BidirectionalIterator, _BidirectionalIterator>::swap(__s); - std::swap(matched, __s.matched); - } - }; -diff -ruN --strip-trailing-cr a/llvm/include/llvm/ADT/ArrayRef.h b/llvm/include/llvm/ADT/ArrayRef.h ---- a/llvm/include/llvm/ADT/ArrayRef.h -+++ b/llvm/include/llvm/ADT/ArrayRef.h -@@ -460,11 +460,8 @@ - - OwningArrayRef &operator=(OwningArrayRef &&Other) { - delete[] this->data(); -- using Base = MutableArrayRef; -- // GCC versions prior to 11.1 incorrectly reject if the 'template' keyword -- // is used prior to the nested-name-specifier here. -- this->Base::operator=(Other); -- Other.Base::operator=(Base()); -+ this->MutableArrayRef::operator=(Other); -+ Other.MutableArrayRef::operator=(MutableArrayRef()); - return *this; - } + template + inline CmpClass_match + m_FCmp(const LHS &L, const RHS &R) { +- return CmpClass_match(L, R); ++ FCmpInst::Predicate Unused; ++ return CmpClass_match(Unused, L, R); + } + // Same as CmpClass, but instead of saving Pred as out output variable, match a diff -ruN --strip-trailing-cr a/llvm/include/llvm/TargetParser/Triple.h b/llvm/include/llvm/TargetParser/Triple.h --- a/llvm/include/llvm/TargetParser/Triple.h +++ b/llvm/include/llvm/TargetParser/Triple.h @@ -3702,18 +354,6 @@ diff -ruN --strip-trailing-cr a/llvm/include/llvm/TargetParser/Triple.h b/llvm/i amdil, // AMDIL amdil64, // AMDIL with 64-bit pointers hsail, // AMD HSAIL -diff -ruN --strip-trailing-cr a/llvm/lib/CodeGen/MachineSink.cpp b/llvm/lib/CodeGen/MachineSink.cpp ---- a/llvm/lib/CodeGen/MachineSink.cpp -+++ b/llvm/lib/CodeGen/MachineSink.cpp -@@ -961,7 +961,7 @@ - MachineBasicBlock *ToBB, - bool BreakPHIEdge) { - // Avoid breaking back edge. From == To means backedge for single BB cycle. -- if (!SplitEdges || FromBB == ToBB) -+ if (!SplitEdges || FromBB == ToBB || !FromBB->isSuccessor(ToBB)) - return false; - - MachineCycle *FromCycle = CI->getCycle(FromBB); diff -ruN --strip-trailing-cr a/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp b/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp --- a/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp +++ b/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp @@ -4008,86 +648,6 @@ diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AArch64/aarch64-mull-masks.ll ; CHECK-NEXT: ret entry: %ext64 = load i32, ptr %x0 -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/MachineSink-Issue98477.ll b/llvm/test/CodeGen/X86/MachineSink-Issue98477.ll ---- a/llvm/test/CodeGen/X86/MachineSink-Issue98477.ll -+++ b/llvm/test/CodeGen/X86/MachineSink-Issue98477.ll -@@ -0,0 +1,76 @@ -+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 -+; RUN: llc < %s | FileCheck %s -+ -+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128" -+target triple = "x86_64-unknown-linux-gnu" -+ -+define i32 @main(i1 %tobool.not, i32 %0) { -+; CHECK-LABEL: main: -+; CHECK: # %bb.0: # %entry -+; CHECK-NEXT: movl $1, %r8d -+; CHECK-NEXT: testb $1, %dil -+; CHECK-NEXT: jne .LBB0_8 -+; CHECK-NEXT: .LBB0_1: # %j.preheader -+; CHECK-NEXT: xorl %r9d, %r9d -+; CHECK-NEXT: jmp .LBB0_2 -+; CHECK-NEXT: .p2align 4, 0x90 -+; CHECK-NEXT: .LBB0_5: # %if.then4 -+; CHECK-NEXT: # in Loop: Header=BB0_2 Depth=1 -+; CHECK-NEXT: movl $1, %eax -+; CHECK-NEXT: xorl %edx, %edx -+; CHECK-NEXT: divl %r8d -+; CHECK-NEXT: testb $1, %dil -+; CHECK-NEXT: jne .LBB0_6 -+; CHECK-NEXT: .LBB0_2: # %j -+; CHECK-NEXT: # =>This Inner Loop Header: Depth=1 -+; CHECK-NEXT: movl $1, %eax -+; CHECK-NEXT: xorl %edx, %edx -+; CHECK-NEXT: idivl %esi -+; CHECK-NEXT: movl %edx, %ecx -+; CHECK-NEXT: testb %r9b, %r9b -+; CHECK-NEXT: jne .LBB0_5 -+; CHECK-NEXT: # %bb.3: # %j -+; CHECK-NEXT: # in Loop: Header=BB0_2 Depth=1 -+; CHECK-NEXT: testl %r9d, %r9d -+; CHECK-NEXT: js .LBB0_5 -+; CHECK-NEXT: # %bb.4: -+; CHECK-NEXT: movl %r9d, %edx -+; CHECK-NEXT: .LBB0_6: # %if.end9 -+; CHECK-NEXT: testl %edx, %edx -+; CHECK-NEXT: jne .LBB0_7 -+; CHECK-NEXT: .LBB0_8: # %if.end13 -+; CHECK-NEXT: xorl %r8d, %r8d -+; CHECK-NEXT: jmp .LBB0_1 -+; CHECK-NEXT: .LBB0_7: # %while.body.lr.ph -+; CHECK-NEXT: movl %ecx, %eax -+; CHECK-NEXT: retq -+entry: -+ br i1 %tobool.not, label %if.end13, label %j.preheader -+ -+ j.preheader: ; preds = %if.end13, %entry -+ %h.0.ph = phi i32 [ 1, %entry ], [ 0, %if.end13 ] -+ br label %j -+ -+ j: ; preds = %if.then4, %j.preheader -+ %1 = phi i32 [ %div2, %if.then4 ], [ 0, %j.preheader ] -+ %rem1 = srem i32 1, %0 -+ %cmp = icmp slt i32 %1, 0 -+ %or.cond = select i1 false, i1 true, i1 %cmp -+ br i1 %or.cond, label %if.then4, label %if.end9 -+ -+ if.then4: ; preds = %j -+ %div2 = sdiv i32 1, 0 -+ %rem5 = srem i32 1, %h.0.ph -+ br i1 %tobool.not, label %if.end9, label %j -+ -+ if.end9: ; preds = %if.then4, %j -+ %2 = phi i32 [ 0, %j ], [ %rem5, %if.then4 ] -+ %tobool10.not = icmp eq i32 %2, 0 -+ br i1 %tobool10.not, label %if.end13, label %while.body.lr.ph -+ -+ while.body.lr.ph: ; preds = %if.end9 -+ ret i32 %rem1 -+ -+ if.end13: ; preds = %if.end9, %entry -+ br label %j.preheader -+} diff -ruN --strip-trailing-cr a/llvm/test/Transforms/LoopVectorize/X86/interleave-cost.ll b/llvm/test/Transforms/LoopVectorize/X86/interleave-cost.ll --- a/llvm/test/Transforms/LoopVectorize/X86/interleave-cost.ll +++ b/llvm/test/Transforms/LoopVectorize/X86/interleave-cost.ll @@ -4524,6 +1084,27 @@ diff -ruN --strip-trailing-cr a/llvm/test/Transforms/LoopVectorize/X86/interleav +; CHECK: [[LOOP9]] = distinct !{[[LOOP9]], [[META1]], [[META2]]} +; CHECK: [[LOOP10]] = distinct !{[[LOOP10]], [[META1]]} ;. +diff -ruN --strip-trailing-cr a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp +--- a/llvm/unittests/IR/PatternMatch.cpp ++++ b/llvm/unittests/IR/PatternMatch.cpp +@@ -2235,7 +2235,7 @@ + MutableConstTestTypes; + TYPED_TEST_SUITE(MutableConstTest, MutableConstTestTypes, ); + +-TYPED_TEST(MutableConstTest, ICmp) { ++TYPED_TEST(MutableConstTest, /* FIXME: UAR bug */ DISABLED_ICmp) { + auto &IRB = PatternMatchTest::IRB; + + typedef std::tuple_element_t<0, TypeParam> ValueType; +@@ -2319,7 +2319,7 @@ + .match((InstructionType)IRB.CreateICmp(Pred, L, R))); + } + +-TYPED_TEST(MutableConstTest, FCmp) { ++TYPED_TEST(MutableConstTest, /* FIXME: UAR bug */ DISABLED_FCmp) { + auto &IRB = PatternMatchTest::IRB; + + typedef std::tuple_element_t<0, TypeParam> ValueType; diff -ruN --strip-trailing-cr a/llvm/utils/gn/secondary/clang/lib/Basic/BUILD.gn b/llvm/utils/gn/secondary/clang/lib/Basic/BUILD.gn --- a/llvm/utils/gn/secondary/clang/lib/Basic/BUILD.gn +++ b/llvm/utils/gn/secondary/clang/lib/Basic/BUILD.gn diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index b93276bcc4c566..6c8da928bb4d1a 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "5ff3ff33ff930e4ec49da7910612d8a41eb068cb" - LLVM_SHA256 = "15fd6dcf22fdf549831d8d490970f66965988f1116dcc4ac04ab2570d9399aba" + LLVM_COMMIT = "dd7d81ea49bf39e1d69bbb84bd3f31bd95519369" + LLVM_SHA256 = "fbd43ef20f4209b0619e209e48c431f76008917714a8c5336063e1ff51d8d084" tf_http_archive( name = name, diff --git a/third_party/tsl/third_party/llvm/generated.patch b/third_party/tsl/third_party/llvm/generated.patch index d6f26a04468fd2..ed3d58f027f90b 100644 --- a/third_party/tsl/third_party/llvm/generated.patch +++ b/third_party/tsl/third_party/llvm/generated.patch @@ -11,17 +11,6 @@ diff -ruN --strip-trailing-cr a/clang/docs/ReleaseNotes.rst b/clang/docs/Release C/C++ Language Potentially Breaking Changes ------------------------------------------- -@@ -313,10 +311,6 @@ - - Clang now considers ``noexcept(typeid(expr))`` more carefully, instead of always assuming that ``std::bad_typeid`` can be thrown. - (`CWG2191: Incorrect result for noexcept(typeid(v)) `_). - --- Clang now correctly implements lookup for the terminal name of a member-qualified nested-name-specifier. -- (`CWG1835: Dependent member lookup before < `_). -- The warning can be disabled via `-Wno-missing-dependent-template-keyword`. -- - C Language Changes - ------------------ - diff -ruN --strip-trailing-cr a/clang/docs/tools/clang-formatted-files.txt b/clang/docs/tools/clang-formatted-files.txt --- a/clang/docs/tools/clang-formatted-files.txt +++ b/clang/docs/tools/clang-formatted-files.txt @@ -33,642 +22,6 @@ diff -ruN --strip-trailing-cr a/clang/docs/tools/clang-formatted-files.txt b/cla clang/lib/Basic/Targets/M68k.h clang/lib/Basic/Targets/MSP430.h clang/lib/Basic/Targets/NVPTX.cpp -diff -ruN --strip-trailing-cr a/clang/include/clang/AST/ExprCXX.h b/clang/include/clang/AST/ExprCXX.h ---- a/clang/include/clang/AST/ExprCXX.h -+++ b/clang/include/clang/AST/ExprCXX.h -@@ -3676,9 +3676,9 @@ - /// an implicit access if a qualifier is provided. - class CXXDependentScopeMemberExpr final - : public Expr, -- private llvm::TrailingObjects< -- CXXDependentScopeMemberExpr, NestedNameSpecifierLoc, DeclAccessPair, -- ASTTemplateKWAndArgsInfo, TemplateArgumentLoc> { -+ private llvm::TrailingObjects { - friend class ASTStmtReader; - friend class ASTStmtWriter; - friend TrailingObjects; -@@ -3691,15 +3691,17 @@ - /// implicit accesses. - QualType BaseType; - -+ /// The nested-name-specifier that precedes the member name, if any. -+ /// FIXME: This could be in principle store as a trailing object. -+ /// However the performance impact of doing so should be investigated first. -+ NestedNameSpecifierLoc QualifierLoc; -+ - /// The member to which this member expression refers, which - /// can be name, overloaded operator, or destructor. - /// - /// FIXME: could also be a template-id - DeclarationNameInfo MemberNameInfo; - -- /// The location of the '->' or '.' operator. -- SourceLocation OperatorLoc; -- - // CXXDependentScopeMemberExpr is followed by several trailing objects, - // some of which optional. They are in order: - // -@@ -3719,16 +3721,8 @@ - return CXXDependentScopeMemberExprBits.HasTemplateKWAndArgsInfo; - } - -- unsigned getNumUnqualifiedLookups() const { -- return CXXDependentScopeMemberExprBits.NumUnqualifiedLookups; -- } -- -- unsigned numTrailingObjects(OverloadToken) const { -- return hasQualifier(); -- } -- -- unsigned numTrailingObjects(OverloadToken) const { -- return getNumUnqualifiedLookups(); -+ bool hasFirstQualifierFoundInScope() const { -+ return CXXDependentScopeMemberExprBits.HasFirstQualifierFoundInScope; - } - - unsigned numTrailingObjects(OverloadToken) const { -@@ -3739,32 +3733,33 @@ - return getNumTemplateArgs(); - } - -+ unsigned numTrailingObjects(OverloadToken) const { -+ return hasFirstQualifierFoundInScope(); -+ } -+ - CXXDependentScopeMemberExpr(const ASTContext &Ctx, Expr *Base, - QualType BaseType, bool IsArrow, - SourceLocation OperatorLoc, - NestedNameSpecifierLoc QualifierLoc, - SourceLocation TemplateKWLoc, -- ArrayRef UnqualifiedLookups, -+ NamedDecl *FirstQualifierFoundInScope, - DeclarationNameInfo MemberNameInfo, - const TemplateArgumentListInfo *TemplateArgs); - -- CXXDependentScopeMemberExpr(EmptyShell Empty, bool HasQualifier, -- unsigned NumUnqualifiedLookups, -- bool HasTemplateKWAndArgsInfo); -+ CXXDependentScopeMemberExpr(EmptyShell Empty, bool HasTemplateKWAndArgsInfo, -+ bool HasFirstQualifierFoundInScope); - - public: - static CXXDependentScopeMemberExpr * - Create(const ASTContext &Ctx, Expr *Base, QualType BaseType, bool IsArrow, - SourceLocation OperatorLoc, NestedNameSpecifierLoc QualifierLoc, -- SourceLocation TemplateKWLoc, -- ArrayRef UnqualifiedLookups, -+ SourceLocation TemplateKWLoc, NamedDecl *FirstQualifierFoundInScope, - DeclarationNameInfo MemberNameInfo, - const TemplateArgumentListInfo *TemplateArgs); - - static CXXDependentScopeMemberExpr * -- CreateEmpty(const ASTContext &Ctx, bool HasQualifier, -- unsigned NumUnqualifiedLookups, bool HasTemplateKWAndArgsInfo, -- unsigned NumTemplateArgs); -+ CreateEmpty(const ASTContext &Ctx, bool HasTemplateKWAndArgsInfo, -+ unsigned NumTemplateArgs, bool HasFirstQualifierFoundInScope); - - /// True if this is an implicit access, i.e. one in which the - /// member being accessed was not written in the source. The source -@@ -3789,35 +3784,34 @@ - bool isArrow() const { return CXXDependentScopeMemberExprBits.IsArrow; } - - /// Retrieve the location of the '->' or '.' operator. -- SourceLocation getOperatorLoc() const { return OperatorLoc; } -- -- /// Determines whether this member expression had a nested-name-specifier -- /// prior to the name of the member, e.g., x->Base::foo. -- bool hasQualifier() const { -- return CXXDependentScopeMemberExprBits.HasQualifier; -- } -- -- /// If the member name was qualified, retrieves the nested-name-specifier -- /// that precedes the member name, with source-location information. -- NestedNameSpecifierLoc getQualifierLoc() const { -- if (!hasQualifier()) -- return NestedNameSpecifierLoc(); -- return *getTrailingObjects(); -+ SourceLocation getOperatorLoc() const { -+ return CXXDependentScopeMemberExprBits.OperatorLoc; - } - -- /// If the member name was qualified, retrieves the -- /// nested-name-specifier that precedes the member name. Otherwise, returns -- /// NULL. -+ /// Retrieve the nested-name-specifier that qualifies the member name. - NestedNameSpecifier *getQualifier() const { -- return getQualifierLoc().getNestedNameSpecifier(); -+ return QualifierLoc.getNestedNameSpecifier(); - } - -- /// Retrieve the declarations found by unqualified lookup for the first -- /// component name of the nested-name-specifier, if any. -- ArrayRef unqualified_lookups() const { -- if (!getNumUnqualifiedLookups()) -- return std::nullopt; -- return {getTrailingObjects(), getNumUnqualifiedLookups()}; -+ /// Retrieve the nested-name-specifier that qualifies the member -+ /// name, with source location information. -+ NestedNameSpecifierLoc getQualifierLoc() const { return QualifierLoc; } -+ -+ /// Retrieve the first part of the nested-name-specifier that was -+ /// found in the scope of the member access expression when the member access -+ /// was initially parsed. -+ /// -+ /// This function only returns a useful result when member access expression -+ /// uses a qualified member name, e.g., "x.Base::f". Here, the declaration -+ /// returned by this function describes what was found by unqualified name -+ /// lookup for the identifier "Base" within the scope of the member access -+ /// expression itself. At template instantiation time, this information is -+ /// combined with the results of name lookup into the type of the object -+ /// expression itself (the class type of x). -+ NamedDecl *getFirstQualifierFoundInScope() const { -+ if (!hasFirstQualifierFoundInScope()) -+ return nullptr; -+ return *getTrailingObjects(); - } - - /// Retrieve the name of the member that this expression refers to. -diff -ruN --strip-trailing-cr a/clang/include/clang/AST/Stmt.h b/clang/include/clang/AST/Stmt.h ---- a/clang/include/clang/AST/Stmt.h -+++ b/clang/include/clang/AST/Stmt.h -@@ -1020,19 +1020,18 @@ - LLVM_PREFERRED_TYPE(bool) - unsigned IsArrow : 1; - -- /// True if this member expression used a nested-name-specifier to -- /// refer to the member, e.g., "x->Base::f". -- LLVM_PREFERRED_TYPE(bool) -- unsigned HasQualifier : 1; -- - /// Whether this member expression has info for explicit template - /// keyword and arguments. - LLVM_PREFERRED_TYPE(bool) - unsigned HasTemplateKWAndArgsInfo : 1; - -- /// Number of declarations found by unqualified lookup for the -- /// first component name of the nested-name-specifier. -- unsigned NumUnqualifiedLookups; -+ /// See getFirstQualifierFoundInScope() and the comment listing -+ /// the trailing objects. -+ LLVM_PREFERRED_TYPE(bool) -+ unsigned HasFirstQualifierFoundInScope : 1; -+ -+ /// The location of the '->' or '.' operator. -+ SourceLocation OperatorLoc; - }; - - class OverloadExprBitfields { -diff -ruN --strip-trailing-cr a/clang/include/clang/AST/UnresolvedSet.h b/clang/include/clang/AST/UnresolvedSet.h ---- a/clang/include/clang/AST/UnresolvedSet.h -+++ b/clang/include/clang/AST/UnresolvedSet.h -@@ -97,10 +97,6 @@ - decls().push_back(DeclAccessPair::make(D, AS)); - } - -- void addAllDecls(ArrayRef Other) { -- append(iterator(Other.begin()), iterator(Other.end())); -- } -- - /// Replaces the given declaration with the new one, once. - /// - /// \return true if the set changed -diff -ruN --strip-trailing-cr a/clang/include/clang/Basic/DiagnosticParseKinds.td b/clang/include/clang/Basic/DiagnosticParseKinds.td ---- a/clang/include/clang/Basic/DiagnosticParseKinds.td -+++ b/clang/include/clang/Basic/DiagnosticParseKinds.td -@@ -895,9 +895,10 @@ - "keyword">, InGroup>, - DefaultError; - --def ext_missing_dependent_template_keyword : ExtWarn< -- "use 'template' keyword to treat '%0' as a dependent template name">, -- InGroup>; -+def err_missing_dependent_template_keyword : Error< -+ "use 'template' keyword to treat '%0' as a dependent template name">; -+def warn_missing_dependent_template_keyword : ExtWarn< -+ "use 'template' keyword to treat '%0' as a dependent template name">; - - def ext_extern_template : Extension< - "extern templates are a C++11 extension">, InGroup; -diff -ruN --strip-trailing-cr a/clang/include/clang/Parse/Parser.h b/clang/include/clang/Parse/Parser.h ---- a/clang/include/clang/Parse/Parser.h -+++ b/clang/include/clang/Parse/Parser.h -@@ -3368,11 +3368,15 @@ - BaseResult ParseBaseSpecifier(Decl *ClassDecl); - AccessSpecifier getAccessSpecifierIfPresent() const; - -- bool ParseUnqualifiedIdTemplateId( -- CXXScopeSpec &SS, ParsedType ObjectType, bool ObjectHadErrors, -- SourceLocation TemplateKWLoc, SourceLocation TildeLoc, -- IdentifierInfo *Name, SourceLocation NameLoc, bool EnteringContext, -- UnqualifiedId &Id, bool AssumeTemplateId); -+ bool ParseUnqualifiedIdTemplateId(CXXScopeSpec &SS, -+ ParsedType ObjectType, -+ bool ObjectHadErrors, -+ SourceLocation TemplateKWLoc, -+ IdentifierInfo *Name, -+ SourceLocation NameLoc, -+ bool EnteringContext, -+ UnqualifiedId &Id, -+ bool AssumeTemplateId); - bool ParseUnqualifiedIdOperator(CXXScopeSpec &SS, bool EnteringContext, - ParsedType ObjectType, - UnqualifiedId &Result); -diff -ruN --strip-trailing-cr a/clang/include/clang/Sema/DeclSpec.h b/clang/include/clang/Sema/DeclSpec.h ---- a/clang/include/clang/Sema/DeclSpec.h -+++ b/clang/include/clang/Sema/DeclSpec.h -@@ -75,7 +75,6 @@ - SourceRange Range; - NestedNameSpecifierLocBuilder Builder; - ArrayRef TemplateParamLists; -- ArrayRef UnqualifiedLookups; - - public: - SourceRange getRange() const { return Range; } -@@ -92,13 +91,6 @@ - return TemplateParamLists; - } - -- void setUnqualifiedLookups(ArrayRef Found) { -- UnqualifiedLookups = Found; -- } -- ArrayRef getUnqualifiedLookups() const { -- return UnqualifiedLookups; -- } -- - /// Retrieve the representation of the nested-name-specifier. - NestedNameSpecifier *getScopeRep() const { - return Builder.getRepresentation(); -diff -ruN --strip-trailing-cr a/clang/include/clang/Sema/Lookup.h b/clang/include/clang/Sema/Lookup.h ---- a/clang/include/clang/Sema/Lookup.h -+++ b/clang/include/clang/Sema/Lookup.h -@@ -483,15 +483,11 @@ - ResultKind = Found; - } - -- void addAllDecls(ArrayRef Other) { -- Decls.addAllDecls(Other); -- ResultKind = Found; -- } -- - /// Add all the declarations from another set of lookup - /// results. - void addAllDecls(const LookupResult &Other) { -- addAllDecls(Other.Decls.pairs()); -+ Decls.append(Other.Decls.begin(), Other.Decls.end()); -+ ResultKind = Found; - } - - /// Determine whether no result was found because we could not -diff -ruN --strip-trailing-cr a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h ---- a/clang/include/clang/Sema/Sema.h -+++ b/clang/include/clang/Sema/Sema.h -@@ -2802,8 +2802,7 @@ - /// (e.g., Base::), perform name lookup for that identifier as a - /// nested-name-specifier within the given scope, and return the result of - /// that name lookup. -- bool LookupFirstQualifierInScope(Scope *S, NestedNameSpecifier *NNS, -- UnresolvedSetImpl &R); -+ NamedDecl *FindFirstQualifierInScope(Scope *S, NestedNameSpecifier *NNS); - - /// Keeps information about an identifier in a nested-name-spec. - /// -@@ -2843,6 +2842,9 @@ - /// \param EnteringContext If true, enter the context specified by the - /// nested-name-specifier. - /// \param SS Optional nested name specifier preceding the identifier. -+ /// \param ScopeLookupResult Provides the result of name lookup within the -+ /// scope of the nested-name-specifier that was computed at template -+ /// definition time. - /// \param ErrorRecoveryLookup Specifies if the method is called to improve - /// error recovery and what kind of recovery is performed. - /// \param IsCorrectedToColon If not null, suggestion of replace '::' -> ':' -@@ -2851,6 +2853,11 @@ - /// not '::'. - /// \param OnlyNamespace If true, only considers namespaces in lookup. - /// -+ /// This routine differs only slightly from ActOnCXXNestedNameSpecifier, in -+ /// that it contains an extra parameter \p ScopeLookupResult, which provides -+ /// the result of name lookup within the scope of the nested-name-specifier -+ /// that was computed at template definition time. -+ /// - /// If ErrorRecoveryLookup is true, then this call is used to improve error - /// recovery. This means that it should not emit diagnostics, it should - /// just return true on failure. It also means it should only return a valid -@@ -2859,6 +2866,7 @@ - /// specifier. - bool BuildCXXNestedNameSpecifier(Scope *S, NestedNameSpecInfo &IdInfo, - bool EnteringContext, CXXScopeSpec &SS, -+ NamedDecl *ScopeLookupResult, - bool ErrorRecoveryLookup, - bool *IsCorrectedToColon = nullptr, - bool OnlyNamespace = false); -@@ -8558,12 +8566,11 @@ - const TemplateArgumentListInfo *TemplateArgs, - bool IsDefiniteInstance, const Scope *S); - -- ExprResult -- ActOnDependentMemberExpr(Expr *Base, QualType BaseType, bool IsArrow, -- SourceLocation OpLoc, const CXXScopeSpec &SS, -- SourceLocation TemplateKWLoc, -- const DeclarationNameInfo &NameInfo, -- const TemplateArgumentListInfo *TemplateArgs); -+ ExprResult ActOnDependentMemberExpr( -+ Expr *Base, QualType BaseType, bool IsArrow, SourceLocation OpLoc, -+ const CXXScopeSpec &SS, SourceLocation TemplateKWLoc, -+ NamedDecl *FirstQualifierInScope, const DeclarationNameInfo &NameInfo, -+ const TemplateArgumentListInfo *TemplateArgs); - - /// The main callback when the parser finds something like - /// expression . [nested-name-specifier] identifier -@@ -8619,14 +8626,15 @@ - ExprResult BuildMemberReferenceExpr( - Expr *Base, QualType BaseType, SourceLocation OpLoc, bool IsArrow, - CXXScopeSpec &SS, SourceLocation TemplateKWLoc, -- const DeclarationNameInfo &NameInfo, -+ NamedDecl *FirstQualifierInScope, const DeclarationNameInfo &NameInfo, - const TemplateArgumentListInfo *TemplateArgs, const Scope *S, - ActOnMemberAccessExtraArgs *ExtraArgs = nullptr); - - ExprResult - BuildMemberReferenceExpr(Expr *Base, QualType BaseType, SourceLocation OpLoc, - bool IsArrow, const CXXScopeSpec &SS, -- SourceLocation TemplateKWLoc, LookupResult &R, -+ SourceLocation TemplateKWLoc, -+ NamedDecl *FirstQualifierInScope, LookupResult &R, - const TemplateArgumentListInfo *TemplateArgs, - const Scope *S, bool SuppressQualifierCheck = false, - ActOnMemberAccessExtraArgs *ExtraArgs = nullptr); -@@ -11114,14 +11122,15 @@ - QualType ObjectType, bool EnteringContext, - RequiredTemplateKind RequiredTemplate = SourceLocation(), - AssumedTemplateKind *ATK = nullptr, -- bool AllowTypoCorrection = true, bool MayBeNNS = false); -+ bool AllowTypoCorrection = true); - -- TemplateNameKind -- isTemplateName(Scope *S, CXXScopeSpec &SS, bool hasTemplateKeyword, -- const UnqualifiedId &Name, ParsedType ObjectType, -- bool EnteringContext, TemplateTy &Template, -- bool &MemberOfUnknownSpecialization, -- bool Disambiguation = false, bool MayBeNNS = false); -+ TemplateNameKind isTemplateName(Scope *S, CXXScopeSpec &SS, -+ bool hasTemplateKeyword, -+ const UnqualifiedId &Name, -+ ParsedType ObjectType, bool EnteringContext, -+ TemplateTy &Template, -+ bool &MemberOfUnknownSpecialization, -+ bool Disambiguation = false); - - /// Try to resolve an undeclared template name as a type template. - /// -@@ -11450,11 +11459,12 @@ - /// For example, given "x.MetaFun::template apply", the scope specifier - /// \p SS will be "MetaFun::", \p TemplateKWLoc contains the location - /// of the "template" keyword, and "apply" is the \p Name. -- TemplateNameKind -- ActOnTemplateName(Scope *S, CXXScopeSpec &SS, SourceLocation TemplateKWLoc, -- const UnqualifiedId &Name, ParsedType ObjectType, -- bool EnteringContext, TemplateTy &Template, -- bool AllowInjectedClassName = false, bool MayBeNNS = false); -+ TemplateNameKind ActOnTemplateName(Scope *S, CXXScopeSpec &SS, -+ SourceLocation TemplateKWLoc, -+ const UnqualifiedId &Name, -+ ParsedType ObjectType, -+ bool EnteringContext, TemplateTy &Template, -+ bool AllowInjectedClassName = false); - - DeclResult ActOnClassTemplateSpecialization( - Scope *S, unsigned TagSpec, TagUseKind TUK, SourceLocation KWLoc, -diff -ruN --strip-trailing-cr a/clang/lib/AST/ASTImporter.cpp b/clang/lib/AST/ASTImporter.cpp ---- a/clang/lib/AST/ASTImporter.cpp -+++ b/clang/lib/AST/ASTImporter.cpp -@@ -8439,14 +8439,8 @@ - auto ToOperatorLoc = importChecked(Err, E->getOperatorLoc()); - auto ToQualifierLoc = importChecked(Err, E->getQualifierLoc()); - auto ToTemplateKeywordLoc = importChecked(Err, E->getTemplateKeywordLoc()); -- -- UnresolvedSet<8> ToUnqualifiedLookups; -- for (auto D : E->unqualified_lookups()) -- if (auto ToDOrErr = import(D.getDecl())) -- ToUnqualifiedLookups.addDecl(*ToDOrErr); -- else -- return ToDOrErr.takeError(); -- -+ auto ToFirstQualifierFoundInScope = -+ importChecked(Err, E->getFirstQualifierFoundInScope()); - if (Err) - return std::move(Err); - -@@ -8480,7 +8474,7 @@ - - return CXXDependentScopeMemberExpr::Create( - Importer.getToContext(), ToBase, ToType, E->isArrow(), ToOperatorLoc, -- ToQualifierLoc, ToTemplateKeywordLoc, ToUnqualifiedLookups.pairs(), -+ ToQualifierLoc, ToTemplateKeywordLoc, ToFirstQualifierFoundInScope, - ToMemberNameInfo, ResInfo); - } - -diff -ruN --strip-trailing-cr a/clang/lib/AST/ExprCXX.cpp b/clang/lib/AST/ExprCXX.cpp ---- a/clang/lib/AST/ExprCXX.cpp -+++ b/clang/lib/AST/ExprCXX.cpp -@@ -1489,27 +1489,19 @@ - CXXDependentScopeMemberExpr::CXXDependentScopeMemberExpr( - const ASTContext &Ctx, Expr *Base, QualType BaseType, bool IsArrow, - SourceLocation OperatorLoc, NestedNameSpecifierLoc QualifierLoc, -- SourceLocation TemplateKWLoc, ArrayRef UnqualifiedLookups, -+ SourceLocation TemplateKWLoc, NamedDecl *FirstQualifierFoundInScope, - DeclarationNameInfo MemberNameInfo, - const TemplateArgumentListInfo *TemplateArgs) - : Expr(CXXDependentScopeMemberExprClass, Ctx.DependentTy, VK_LValue, - OK_Ordinary), -- Base(Base), BaseType(BaseType), MemberNameInfo(MemberNameInfo), -- OperatorLoc(OperatorLoc) { -+ Base(Base), BaseType(BaseType), QualifierLoc(QualifierLoc), -+ MemberNameInfo(MemberNameInfo) { - CXXDependentScopeMemberExprBits.IsArrow = IsArrow; -- CXXDependentScopeMemberExprBits.HasQualifier = QualifierLoc.hasQualifier(); -- CXXDependentScopeMemberExprBits.NumUnqualifiedLookups = -- UnqualifiedLookups.size(); - CXXDependentScopeMemberExprBits.HasTemplateKWAndArgsInfo = - (TemplateArgs != nullptr) || TemplateKWLoc.isValid(); -- -- if (hasQualifier()) -- new (getTrailingObjects()) -- NestedNameSpecifierLoc(QualifierLoc); -- -- std::uninitialized_copy_n(UnqualifiedLookups.data(), -- UnqualifiedLookups.size(), -- getTrailingObjects()); -+ CXXDependentScopeMemberExprBits.HasFirstQualifierFoundInScope = -+ FirstQualifierFoundInScope != nullptr; -+ CXXDependentScopeMemberExprBits.OperatorLoc = OperatorLoc; - - if (TemplateArgs) { - auto Deps = TemplateArgumentDependence::None; -@@ -1521,59 +1513,54 @@ - TemplateKWLoc); - } - -+ if (hasFirstQualifierFoundInScope()) -+ *getTrailingObjects() = FirstQualifierFoundInScope; - setDependence(computeDependence(this)); - } - - CXXDependentScopeMemberExpr::CXXDependentScopeMemberExpr( -- EmptyShell Empty, bool HasQualifier, unsigned NumUnqualifiedLookups, -- bool HasTemplateKWAndArgsInfo) -+ EmptyShell Empty, bool HasTemplateKWAndArgsInfo, -+ bool HasFirstQualifierFoundInScope) - : Expr(CXXDependentScopeMemberExprClass, Empty) { -- CXXDependentScopeMemberExprBits.HasQualifier = HasQualifier; -- CXXDependentScopeMemberExprBits.NumUnqualifiedLookups = NumUnqualifiedLookups; - CXXDependentScopeMemberExprBits.HasTemplateKWAndArgsInfo = - HasTemplateKWAndArgsInfo; -+ CXXDependentScopeMemberExprBits.HasFirstQualifierFoundInScope = -+ HasFirstQualifierFoundInScope; - } - - CXXDependentScopeMemberExpr *CXXDependentScopeMemberExpr::Create( - const ASTContext &Ctx, Expr *Base, QualType BaseType, bool IsArrow, - SourceLocation OperatorLoc, NestedNameSpecifierLoc QualifierLoc, -- SourceLocation TemplateKWLoc, ArrayRef UnqualifiedLookups, -+ SourceLocation TemplateKWLoc, NamedDecl *FirstQualifierFoundInScope, - DeclarationNameInfo MemberNameInfo, - const TemplateArgumentListInfo *TemplateArgs) { -- bool HasQualifier = QualifierLoc.hasQualifier(); -- unsigned NumUnqualifiedLookups = UnqualifiedLookups.size(); -- assert(!NumUnqualifiedLookups || HasQualifier); - bool HasTemplateKWAndArgsInfo = - (TemplateArgs != nullptr) || TemplateKWLoc.isValid(); - unsigned NumTemplateArgs = TemplateArgs ? TemplateArgs->size() : 0; -- unsigned Size = -- totalSizeToAlloc( -- HasQualifier, NumUnqualifiedLookups, HasTemplateKWAndArgsInfo, -- NumTemplateArgs); -+ bool HasFirstQualifierFoundInScope = FirstQualifierFoundInScope != nullptr; -+ -+ unsigned Size = totalSizeToAlloc( -+ HasTemplateKWAndArgsInfo, NumTemplateArgs, HasFirstQualifierFoundInScope); - - void *Mem = Ctx.Allocate(Size, alignof(CXXDependentScopeMemberExpr)); - return new (Mem) CXXDependentScopeMemberExpr( - Ctx, Base, BaseType, IsArrow, OperatorLoc, QualifierLoc, TemplateKWLoc, -- UnqualifiedLookups, MemberNameInfo, TemplateArgs); -+ FirstQualifierFoundInScope, MemberNameInfo, TemplateArgs); - } - - CXXDependentScopeMemberExpr *CXXDependentScopeMemberExpr::CreateEmpty( -- const ASTContext &Ctx, bool HasQualifier, unsigned NumUnqualifiedLookups, -- bool HasTemplateKWAndArgsInfo, unsigned NumTemplateArgs) { -- assert(!NumTemplateArgs || HasTemplateKWAndArgsInfo); -- assert(!NumUnqualifiedLookups || HasQualifier); -- -- unsigned Size = -- totalSizeToAlloc( -- HasQualifier, NumUnqualifiedLookups, HasTemplateKWAndArgsInfo, -- NumTemplateArgs); -+ const ASTContext &Ctx, bool HasTemplateKWAndArgsInfo, -+ unsigned NumTemplateArgs, bool HasFirstQualifierFoundInScope) { -+ assert(NumTemplateArgs == 0 || HasTemplateKWAndArgsInfo); -+ -+ unsigned Size = totalSizeToAlloc( -+ HasTemplateKWAndArgsInfo, NumTemplateArgs, HasFirstQualifierFoundInScope); - - void *Mem = Ctx.Allocate(Size, alignof(CXXDependentScopeMemberExpr)); -- return new (Mem) CXXDependentScopeMemberExpr(EmptyShell(), HasQualifier, -- NumUnqualifiedLookups, -- HasTemplateKWAndArgsInfo); -+ return new (Mem) CXXDependentScopeMemberExpr( -+ EmptyShell(), HasTemplateKWAndArgsInfo, HasFirstQualifierFoundInScope); - } - - CXXThisExpr *CXXThisExpr::Create(const ASTContext &Ctx, SourceLocation L, -diff -ruN --strip-trailing-cr a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp ---- a/clang/lib/AST/ItaniumMangle.cpp -+++ b/clang/lib/AST/ItaniumMangle.cpp -@@ -594,10 +594,11 @@ - void mangleMemberExprBase(const Expr *base, bool isArrow); - void mangleMemberExpr(const Expr *base, bool isArrow, - NestedNameSpecifier *qualifier, -- ArrayRef UnqualifiedLookups, -+ NamedDecl *firstQualifierLookup, - DeclarationName name, - const TemplateArgumentLoc *TemplateArgs, -- unsigned NumTemplateArgs, unsigned knownArity); -+ unsigned NumTemplateArgs, -+ unsigned knownArity); - void mangleCastExpression(const Expr *E, StringRef CastEncoding); - void mangleInitListElements(const InitListExpr *InitList); - void mangleRequirement(SourceLocation RequiresExprLoc, -@@ -4495,11 +4496,14 @@ - } - - /// Mangles a member expression. --void CXXNameMangler::mangleMemberExpr( -- const Expr *base, bool isArrow, NestedNameSpecifier *qualifier, -- ArrayRef UnqualifiedLookups, DeclarationName member, -- const TemplateArgumentLoc *TemplateArgs, unsigned NumTemplateArgs, -- unsigned arity) { -+void CXXNameMangler::mangleMemberExpr(const Expr *base, -+ bool isArrow, -+ NestedNameSpecifier *qualifier, -+ NamedDecl *firstQualifierLookup, -+ DeclarationName member, -+ const TemplateArgumentLoc *TemplateArgs, -+ unsigned NumTemplateArgs, -+ unsigned arity) { - // ::= dt - // ::= pt - if (base) -@@ -4981,9 +4985,11 @@ - case Expr::MemberExprClass: { - NotPrimaryExpr(); - const MemberExpr *ME = cast(E); -- mangleMemberExpr(ME->getBase(), ME->isArrow(), ME->getQualifier(), -- std::nullopt, ME->getMemberDecl()->getDeclName(), -- ME->getTemplateArgs(), ME->getNumTemplateArgs(), Arity); -+ mangleMemberExpr(ME->getBase(), ME->isArrow(), -+ ME->getQualifier(), nullptr, -+ ME->getMemberDecl()->getDeclName(), -+ ME->getTemplateArgs(), ME->getNumTemplateArgs(), -+ Arity); - break; - } - -@@ -4991,9 +4997,10 @@ - NotPrimaryExpr(); - const UnresolvedMemberExpr *ME = cast(E); - mangleMemberExpr(ME->isImplicitAccess() ? nullptr : ME->getBase(), -- ME->isArrow(), ME->getQualifier(), std::nullopt, -- ME->getMemberName(), ME->getTemplateArgs(), -- ME->getNumTemplateArgs(), Arity); -+ ME->isArrow(), ME->getQualifier(), nullptr, -+ ME->getMemberName(), -+ ME->getTemplateArgs(), ME->getNumTemplateArgs(), -+ Arity); - break; - } - -@@ -5003,8 +5010,10 @@ - = cast(E); - mangleMemberExpr(ME->isImplicitAccess() ? nullptr : ME->getBase(), - ME->isArrow(), ME->getQualifier(), -- ME->unqualified_lookups(), ME->getMember(), -- ME->getTemplateArgs(), ME->getNumTemplateArgs(), Arity); -+ ME->getFirstQualifierFoundInScope(), -+ ME->getMember(), -+ ME->getTemplateArgs(), ME->getNumTemplateArgs(), -+ Arity); - break; - } - diff -ruN --strip-trailing-cr a/clang/lib/Basic/CMakeLists.txt b/clang/lib/Basic/CMakeLists.txt --- a/clang/lib/Basic/CMakeLists.txt +++ b/clang/lib/Basic/CMakeLists.txt @@ -779,2042 +132,96 @@ diff -ruN --strip-trailing-cr a/clang/lib/Basic/Targets/Le64.h b/clang/lib/Basic + bool hasProtectedVisibility() const override { return false; } +}; + -+} // namespace targets -+} // namespace clang -+#endif // LLVM_CLANG_LIB_BASIC_TARGETS_LE64_H -diff -ruN --strip-trailing-cr a/clang/lib/Basic/Targets/OSTargets.h b/clang/lib/Basic/Targets/OSTargets.h ---- a/clang/lib/Basic/Targets/OSTargets.h -+++ b/clang/lib/Basic/Targets/OSTargets.h -@@ -841,6 +841,9 @@ - "i64:64-i128:128-n8:16:32:64-S128"); - } else if (Triple.getArch() == llvm::Triple::mipsel) { - // Handled on mips' setDataLayout. -+ } else { -+ assert(Triple.getArch() == llvm::Triple::le32); -+ this->resetDataLayout("e-p:32:32-i64:64"); - } - } - }; -diff -ruN --strip-trailing-cr a/clang/lib/Basic/Targets.cpp b/clang/lib/Basic/Targets.cpp ---- a/clang/lib/Basic/Targets.cpp -+++ b/clang/lib/Basic/Targets.cpp -@@ -23,6 +23,7 @@ - #include "Targets/DirectX.h" - #include "Targets/Hexagon.h" - #include "Targets/Lanai.h" -+#include "Targets/Le64.h" - #include "Targets/LoongArch.h" - #include "Targets/M68k.h" - #include "Targets/MSP430.h" -@@ -343,6 +344,17 @@ - return std::make_unique(Triple, Opts); - } - -+ case llvm::Triple::le32: -+ switch (os) { -+ case llvm::Triple::NaCl: -+ return std::make_unique>(Triple, Opts); -+ default: -+ return nullptr; -+ } -+ -+ case llvm::Triple::le64: -+ return std::make_unique(Triple, Opts); -+ - case llvm::Triple::ppc: - switch (os) { - case llvm::Triple::Linux: -diff -ruN --strip-trailing-cr a/clang/lib/CodeGen/CodeGenModule.cpp b/clang/lib/CodeGen/CodeGenModule.cpp ---- a/clang/lib/CodeGen/CodeGenModule.cpp -+++ b/clang/lib/CodeGen/CodeGenModule.cpp -@@ -116,6 +116,8 @@ - default: - return createDefaultTargetCodeGenInfo(CGM); - -+ case llvm::Triple::le32: -+ return createPNaClTargetCodeGenInfo(CGM); - case llvm::Triple::m68k: - return createM68kTargetCodeGenInfo(CGM); - case llvm::Triple::mips: -diff -ruN --strip-trailing-cr a/clang/lib/CodeGen/ItaniumCXXABI.cpp b/clang/lib/CodeGen/ItaniumCXXABI.cpp ---- a/clang/lib/CodeGen/ItaniumCXXABI.cpp -+++ b/clang/lib/CodeGen/ItaniumCXXABI.cpp -@@ -576,6 +576,13 @@ - return new XLCXXABI(CGM); - - case TargetCXXABI::GenericItanium: -+ if (CGM.getContext().getTargetInfo().getTriple().getArch() -+ == llvm::Triple::le32) { -+ // For PNaCl, use ARM-style method pointers so that PNaCl code -+ // does not assume anything about the alignment of function -+ // pointers. -+ return new ItaniumCXXABI(CGM, /*UseARMMethodPtrABI=*/true); -+ } - return new ItaniumCXXABI(CGM); - - case TargetCXXABI::Microsoft: -diff -ruN --strip-trailing-cr a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp ---- a/clang/lib/Driver/ToolChains/Clang.cpp -+++ b/clang/lib/Driver/ToolChains/Clang.cpp -@@ -3815,6 +3815,12 @@ - if (UseBuiltins) - A->render(Args, CmdArgs); - } -+ -+ // le32-specific flags: -+ // -fno-math-builtin: clang should not convert math builtins to intrinsics -+ // by default. -+ if (TC.getArch() == llvm::Triple::le32) -+ CmdArgs.push_back("-fno-math-builtin"); - } - - bool Driver::getDefaultModuleCachePath(SmallVectorImpl &Result) { -diff -ruN --strip-trailing-cr a/clang/lib/Parse/ParseExpr.cpp b/clang/lib/Parse/ParseExpr.cpp ---- a/clang/lib/Parse/ParseExpr.cpp -+++ b/clang/lib/Parse/ParseExpr.cpp -@@ -2343,9 +2343,10 @@ - } - - if (!LHS.isInvalid()) -- LHS = Actions.ActOnMemberAccessExpr( -- getCurScope(), LHS.get(), OpLoc, OpKind, SS, TemplateKWLoc, Name, -- CurParsedObjCImpl ? CurParsedObjCImpl->Dcl : nullptr); -+ LHS = Actions.ActOnMemberAccessExpr(getCurScope(), LHS.get(), OpLoc, -+ OpKind, SS, TemplateKWLoc, Name, -+ CurParsedObjCImpl ? CurParsedObjCImpl->Dcl -+ : nullptr); - if (!LHS.isInvalid()) { - if (Tok.is(tok::less)) - checkPotentialAngleBracket(LHS); -diff -ruN --strip-trailing-cr a/clang/lib/Parse/ParseExprCXX.cpp b/clang/lib/Parse/ParseExprCXX.cpp ---- a/clang/lib/Parse/ParseExprCXX.cpp -+++ b/clang/lib/Parse/ParseExprCXX.cpp -@@ -100,8 +100,7 @@ - bool MemberOfUnknownSpecialization; - if (!Actions.isTemplateName(getCurScope(), SS, /*hasTemplateKeyword=*/false, - TemplateName, ObjectType, EnteringContext, -- Template, MemberOfUnknownSpecialization, -- /*Disambiguation=*/false, /*MayBeNNS=*/true)) -+ Template, MemberOfUnknownSpecialization)) - return; - - FixDigraph(*this, PP, Next, SecondToken, tok::unknown, -@@ -354,8 +353,7 @@ - TemplateTy Template; - TemplateNameKind TNK = Actions.ActOnTemplateName( - getCurScope(), SS, TemplateKWLoc, TemplateName, ObjectType, -- EnteringContext, Template, /*AllowInjectedClassName*/ true, -- /*MayBeNNS=*/true); -+ EnteringContext, Template, /*AllowInjectedClassName*/ true); - if (AnnotateTemplateIdToken(Template, TNK, SS, TemplateKWLoc, - TemplateName, false)) - return true; -@@ -407,6 +405,7 @@ - : TemplateId->TemplateNameLoc; - SS.SetInvalid(SourceRange(StartLoc, CCLoc)); - } -+ - continue; - } - -@@ -529,19 +528,18 @@ - UnqualifiedId TemplateName; - TemplateName.setIdentifier(&II, Tok.getLocation()); - bool MemberOfUnknownSpecialization; -- if (TemplateNameKind TNK = Actions.isTemplateName( -- getCurScope(), SS, -- /*hasTemplateKeyword=*/false, TemplateName, ObjectType, -- EnteringContext, Template, MemberOfUnknownSpecialization, -- /*Disambiguation=*/false, -- /*MayBeNNS=*/true)) { -+ if (TemplateNameKind TNK = Actions.isTemplateName(getCurScope(), SS, -+ /*hasTemplateKeyword=*/false, -+ TemplateName, -+ ObjectType, -+ EnteringContext, -+ Template, -+ MemberOfUnknownSpecialization)) { - // If lookup didn't find anything, we treat the name as a template-name - // anyway. C++20 requires this, and in prior language modes it improves - // error recovery. But before we commit to this, check that we actually - // have something that looks like a template-argument-list next. -- if (!IsTypename && -- (TNK == TNK_Undeclared_template || -- (!HasScopeSpecifier && ObjectType)) && -+ if (!IsTypename && TNK == TNK_Undeclared_template && - isTemplateArgumentList(1) == TPResult::False) - break; - -@@ -568,7 +566,11 @@ - // member of an unknown specialization. However, this will only - // parse correctly as a template, so suggest the keyword 'template' - // before 'getAs' and treat this as a dependent template name. -- Diag(Tok.getLocation(), diag::ext_missing_dependent_template_keyword) -+ unsigned DiagID = diag::err_missing_dependent_template_keyword; -+ if (getLangOpts().MicrosoftExt) -+ DiagID = diag::warn_missing_dependent_template_keyword; -+ -+ Diag(Tok.getLocation(), DiagID) - << II.getName() - << FixItHint::CreateInsertion(Tok.getLocation(), "template "); - } -@@ -1918,12 +1920,12 @@ - // argument list. This affects examples such as - // void f(auto *p) { p->~X(); } - // ... but there's no ambiguity, and nowhere to write 'template' in such an -- // example, so we accept it anyway -- if (Tok.is(tok::less) && ParseUnqualifiedIdTemplateId( -- SS, ObjectType, Base && Base->containsErrors(), -- /*TemplateKWLoc=*/SourceLocation(), TildeLoc, -- Name, NameLoc, false, SecondTypeName, -- /*AssumeTemplateId=*/true)) -+ // example, so we accept it anyway. -+ if (Tok.is(tok::less) && -+ ParseUnqualifiedIdTemplateId( -+ SS, ObjectType, Base && Base->containsErrors(), SourceLocation(), -+ Name, NameLoc, false, SecondTypeName, -+ /*AssumeTemplateId=*/true)) - return ExprError(); - - return Actions.ActOnPseudoDestructorExpr(getCurScope(), Base, OpLoc, OpKind, -@@ -2530,9 +2532,8 @@ - /// \returns true if a parse error occurred, false otherwise. - bool Parser::ParseUnqualifiedIdTemplateId( - CXXScopeSpec &SS, ParsedType ObjectType, bool ObjectHadErrors, -- SourceLocation TemplateKWLoc, SourceLocation TildeLoc, IdentifierInfo *Name, -- SourceLocation NameLoc, bool EnteringContext, UnqualifiedId &Id, -- bool AssumeTemplateId) { -+ SourceLocation TemplateKWLoc, IdentifierInfo *Name, SourceLocation NameLoc, -+ bool EnteringContext, UnqualifiedId &Id, bool AssumeTemplateId) { - assert(Tok.is(tok::less) && "Expected '<' to finish parsing a template-id"); - - TemplateTy Template; -@@ -2546,14 +2547,13 @@ - // this template-id is used to form a nested-name-specifier or not. - TNK = Actions.ActOnTemplateName(getCurScope(), SS, TemplateKWLoc, Id, - ObjectType, EnteringContext, Template, -- /*AllowInjectedClassName=*/true, -- TildeLoc.isValid()); -+ /*AllowInjectedClassName*/ true); - } else { - bool MemberOfUnknownSpecialization; -- TNK = Actions.isTemplateName( -- getCurScope(), SS, TemplateKWLoc.isValid(), Id, ObjectType, -- EnteringContext, Template, MemberOfUnknownSpecialization, -- /*Disambiguation=*/false, TildeLoc.isValid()); -+ TNK = Actions.isTemplateName(getCurScope(), SS, -+ TemplateKWLoc.isValid(), Id, -+ ObjectType, EnteringContext, Template, -+ MemberOfUnknownSpecialization); - // If lookup found nothing but we're assuming that this is a template - // name, double-check that makes sense syntactically before committing - // to it. -@@ -2580,13 +2580,13 @@ - else - Name += Id.Identifier->getName(); - } -- Diag(Id.StartLocation, diag::ext_missing_dependent_template_keyword) -+ Diag(Id.StartLocation, diag::err_missing_dependent_template_keyword) - << Name - << FixItHint::CreateInsertion(Id.StartLocation, "template "); - } - TNK = Actions.ActOnTemplateName( - getCurScope(), SS, TemplateKWLoc, Id, ObjectType, EnteringContext, -- Template, /*AllowInjectedClassName=*/true, TildeLoc.isValid()); -+ Template, /*AllowInjectedClassName*/ true); - } else if (TNK == TNK_Non_template) { - return false; - } -@@ -2611,16 +2611,14 @@ - bool MemberOfUnknownSpecialization; - TemplateName.setIdentifier(Name, NameLoc); - if (ObjectType) { -- TNK = Actions.ActOnTemplateName(getCurScope(), SS, TemplateKWLoc, -- TemplateName, ObjectType, EnteringContext, -- Template, /*AllowInjectedClassName=*/true, -- /*MayBeNNS=*/true); -+ TNK = Actions.ActOnTemplateName( -+ getCurScope(), SS, TemplateKWLoc, TemplateName, ObjectType, -+ EnteringContext, Template, /*AllowInjectedClassName*/ true); - } else { - TNK = Actions.isTemplateName(getCurScope(), SS, TemplateKWLoc.isValid(), -- TemplateName, ObjectType, EnteringContext, -- Template, MemberOfUnknownSpecialization, -- /*Disambiguation=*/false, -- /*MayBeNNS=*/true); -+ TemplateName, ObjectType, -+ EnteringContext, Template, -+ MemberOfUnknownSpecialization); - - if (TNK == TNK_Non_template && !Id.DestructorName.get()) { - Diag(NameLoc, diag::err_destructor_template_id) -@@ -2682,7 +2680,7 @@ - if (Id.getKind() == UnqualifiedIdKind::IK_ConstructorName) - Id.setConstructorName(Type.get(), NameLoc, RAngleLoc); - else -- Id.setDestructorName(TildeLoc, Type.get(), RAngleLoc); -+ Id.setDestructorName(Id.StartLocation, Type.get(), RAngleLoc); - - return false; - } -@@ -3030,9 +3028,8 @@ - if (Tok.is(tok::less)) - return ParseUnqualifiedIdTemplateId( - SS, ObjectType, ObjectHadErrors, -- TemplateKWLoc ? *TemplateKWLoc : SourceLocation(), -- /*TildeLoc=*/SourceLocation(), Id, IdLoc, EnteringContext, Result, -- TemplateSpecified); -+ TemplateKWLoc ? *TemplateKWLoc : SourceLocation(), Id, IdLoc, -+ EnteringContext, Result, TemplateSpecified); - - if (TemplateSpecified) { - TemplateNameKind TNK = -@@ -3127,15 +3124,13 @@ - Tok.is(tok::less)) - return ParseUnqualifiedIdTemplateId( - SS, ObjectType, ObjectHadErrors, -- TemplateKWLoc ? *TemplateKWLoc : SourceLocation(), -- /*TildeLoc=*/SourceLocation(), /*Name=*/nullptr, -- /*NameLoc=*/SourceLocation(), EnteringContext, Result, -- TemplateSpecified); -+ TemplateKWLoc ? *TemplateKWLoc : SourceLocation(), nullptr, -+ SourceLocation(), EnteringContext, Result, TemplateSpecified); - else if (TemplateSpecified && - Actions.ActOnTemplateName( - getCurScope(), SS, *TemplateKWLoc, Result, ObjectType, - EnteringContext, Template, -- /*AllowInjectedClassName=*/true) == TNK_Non_template) -+ /*AllowInjectedClassName*/ true) == TNK_Non_template) - return true; - - return false; -@@ -3225,8 +3220,8 @@ - Result.setDestructorName(TildeLoc, nullptr, ClassNameLoc); - return ParseUnqualifiedIdTemplateId( - SS, ObjectType, ObjectHadErrors, -- TemplateKWLoc ? *TemplateKWLoc : SourceLocation(), TildeLoc, -- ClassName, ClassNameLoc, EnteringContext, Result, TemplateSpecified); -+ TemplateKWLoc ? *TemplateKWLoc : SourceLocation(), ClassName, -+ ClassNameLoc, EnteringContext, Result, TemplateSpecified); - } - - // Note that this is a destructor name. -diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp ---- a/clang/lib/Sema/SemaCoroutine.cpp -+++ b/clang/lib/Sema/SemaCoroutine.cpp -@@ -306,8 +306,8 @@ - // FIXME: Fix BuildMemberReferenceExpr to take a const CXXScopeSpec&. - CXXScopeSpec SS; - ExprResult Result = S.BuildMemberReferenceExpr( -- Base, Base->getType(), Loc, /*IsPtr=*/false, SS, SourceLocation(), -- NameInfo, /*TemplateArgs=*/nullptr, -+ Base, Base->getType(), Loc, /*IsPtr=*/false, SS, -+ SourceLocation(), nullptr, NameInfo, /*TemplateArgs=*/nullptr, - /*Scope=*/nullptr); - if (Result.isInvalid()) - return ExprError(); -diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaCXXScopeSpec.cpp b/clang/lib/Sema/SemaCXXScopeSpec.cpp ---- a/clang/lib/Sema/SemaCXXScopeSpec.cpp -+++ b/clang/lib/Sema/SemaCXXScopeSpec.cpp -@@ -356,41 +356,29 @@ - return false; - } - --/// If the given nested-name-specifier begins with a bare identifier --/// (e.g., Base::), perform name lookup for that identifier as a --/// nested-name-specifier within the given scope, and return the result of that --/// name lookup. --bool Sema::LookupFirstQualifierInScope(Scope *S, NestedNameSpecifier *NNS, -- UnresolvedSetImpl &R) { -- if (!S) -- return false; -+NamedDecl *Sema::FindFirstQualifierInScope(Scope *S, NestedNameSpecifier *NNS) { -+ if (!S || !NNS) -+ return nullptr; - - while (NNS->getPrefix()) - NNS = NNS->getPrefix(); - -- // FIXME: This is a rather nasty hack! Ideally we should get the results -- // from LookupTemplateName/BuildCXXNestedNameSpecifier. -- const IdentifierInfo *II = NNS->getAsIdentifier(); -- if (!II) { -- if (const auto *DTST = -- dyn_cast_if_present( -- NNS->getAsType())) -- II = DTST->getIdentifier(); -- else -- return false; -- } -- assert(II && "Missing first qualifier in scope"); -- LookupResult Found(*this, II, SourceLocation(), -- NNS->getAsIdentifier() ? LookupNestedNameSpecifierName -- : LookupOrdinaryName); -+ if (NNS->getKind() != NestedNameSpecifier::Identifier) -+ return nullptr; -+ -+ LookupResult Found(*this, NNS->getAsIdentifier(), SourceLocation(), -+ LookupNestedNameSpecifierName); - LookupName(Found, S); -+ assert(!Found.isAmbiguous() && "Cannot handle ambiguities here yet"); - -- if (Found.empty()) -- return false; -+ if (!Found.isSingleResult()) -+ return nullptr; - -- R.addAllDecls(Found.asUnresolvedSet().pairs()); -- Found.suppressDiagnostics(); -- return true; -+ NamedDecl *Result = Found.getFoundDecl(); -+ if (isAcceptableNestedNameSpecifier(Result)) -+ return Result; -+ -+ return nullptr; - } - - namespace { -@@ -419,82 +407,112 @@ - - bool Sema::BuildCXXNestedNameSpecifier(Scope *S, NestedNameSpecInfo &IdInfo, - bool EnteringContext, CXXScopeSpec &SS, -+ NamedDecl *ScopeLookupResult, - bool ErrorRecoveryLookup, - bool *IsCorrectedToColon, - bool OnlyNamespace) { - if (IdInfo.Identifier->isEditorPlaceholder()) - return true; -- if (IsCorrectedToColon) -- *IsCorrectedToColon = false; -- -- QualType ObjectType = GetTypeFromParser(IdInfo.ObjectType); - LookupResult Found(*this, IdInfo.Identifier, IdInfo.IdentifierLoc, - OnlyNamespace ? LookupNamespaceName - : LookupNestedNameSpecifierName); -+ QualType ObjectType = GetTypeFromParser(IdInfo.ObjectType); - -- // C++ [basic.lookup.qual.general]p3: -- // Qualified name lookup in a class, namespace, or enumeration performs a -- // search of the scope associated with it except as specified below. -- LookupParsedName(Found, S, &SS, ObjectType, -- /*AllowBuiltinCreation=*/false, EnteringContext); -- -- // C++ [basic.lookup.qual.general]p3: -- // [...] Unless otherwise specified, a qualified name undergoes qualified -- // name lookup in its lookup context from the point where it appears unless -- // the lookup context either is dependent and is not the current -- // instantiation or is not a class or class template. -- if (Found.wasNotFoundInCurrentInstantiation()) { -- // Don't speculate if we're just trying to improve error recovery. -- if (ErrorRecoveryLookup) -- return true; -- -- // The lookup context is dependent and either: -- // - it is not the current instantiation, or -- // - it is the current instantiation, it has at least one dependent base -- // class, and qualified lookup found nothing. -- // Build a dependent nested-name-specifier. We will lookup the name again -- // during instantiation. -- SS.Extend(Context, IdInfo.Identifier, IdInfo.IdentifierLoc, IdInfo.CCLoc); -- return false; -+ // Determine where to perform name lookup -+ DeclContext *LookupCtx = nullptr; -+ bool isDependent = false; -+ if (IsCorrectedToColon) -+ *IsCorrectedToColon = false; -+ if (!ObjectType.isNull()) { -+ // This nested-name-specifier occurs in a member access expression, e.g., -+ // x->B::f, and we are looking into the type of the object. -+ assert(!SS.isSet() && "ObjectType and scope specifier cannot coexist"); -+ LookupCtx = computeDeclContext(ObjectType); -+ isDependent = ObjectType->isDependentType(); -+ } else if (SS.isSet()) { -+ // This nested-name-specifier occurs after another nested-name-specifier, -+ // so look into the context associated with the prior nested-name-specifier. -+ LookupCtx = computeDeclContext(SS, EnteringContext); -+ isDependent = isDependentScopeSpecifier(SS); -+ Found.setContextRange(SS.getRange()); - } - - bool ObjectTypeSearchedInScope = false; -+ if (LookupCtx) { -+ // Perform "qualified" name lookup into the declaration context we -+ // computed, which is either the type of the base of a member access -+ // expression or the declaration context associated with a prior -+ // nested-name-specifier. -+ -+ // The declaration context must be complete. -+ if (!LookupCtx->isDependentContext() && -+ RequireCompleteDeclContext(SS, LookupCtx)) -+ return true; - -- // C++ [basic.lookup.qual.general]p2: -- // A member-qualified name is the (unique) component name, if any, of -- // - an unqualified-id or -- // - a nested-name-specifier of the form type-name :: or namespace-name :: -- // in the id-expression of a class member access expression. -- // -- // C++ [basic.lookup.qual.general]p3: -- // [...] If nothing is found by qualified lookup for a member-qualified -- // name that is the terminal name of a nested-name-specifier and is not -- // dependent, it undergoes unqualified lookup. -- // -- // In 'x.A::B::y', 'A' will undergo unqualified lookup if qualified lookup -- // in the type of 'x' finds nothing. If the lookup context is dependent, -- // we perform the unqualified lookup in the template definition context -- // and store the results so we can replicate the lookup during instantiation. -- if (Found.empty() && !ObjectType.isNull()) { -- if (S) { -- LookupName(Found, S); -- } else if (!SS.getUnqualifiedLookups().empty()) { -- Found.addAllDecls(SS.getUnqualifiedLookups()); -- Found.resolveKind(); -+ LookupQualifiedName(Found, LookupCtx); -+ -+ if (!ObjectType.isNull() && Found.empty()) { -+ // C++ [basic.lookup.classref]p4: -+ // If the id-expression in a class member access is a qualified-id of -+ // the form -+ // -+ // class-name-or-namespace-name::... -+ // -+ // the class-name-or-namespace-name following the . or -> operator is -+ // looked up both in the context of the entire postfix-expression and in -+ // the scope of the class of the object expression. If the name is found -+ // only in the scope of the class of the object expression, the name -+ // shall refer to a class-name. If the name is found only in the -+ // context of the entire postfix-expression, the name shall refer to a -+ // class-name or namespace-name. [...] -+ // -+ // Qualified name lookup into a class will not find a namespace-name, -+ // so we do not need to diagnose that case specifically. However, -+ // this qualified name lookup may find nothing. In that case, perform -+ // unqualified name lookup in the given scope (if available) or -+ // reconstruct the result from when name lookup was performed at template -+ // definition time. -+ if (S) -+ LookupName(Found, S); -+ else if (ScopeLookupResult) -+ Found.addDecl(ScopeLookupResult); -+ -+ ObjectTypeSearchedInScope = true; - } -- ObjectTypeSearchedInScope = true; -+ } else if (!isDependent) { -+ // Perform unqualified name lookup in the current scope. -+ LookupName(Found, S); - } - - if (Found.isAmbiguous()) - return true; - -+ // If we performed lookup into a dependent context and did not find anything, -+ // that's fine: just build a dependent nested-name-specifier. -+ if (Found.empty() && isDependent && -+ !(LookupCtx && LookupCtx->isRecord() && -+ (!cast(LookupCtx)->hasDefinition() || -+ !cast(LookupCtx)->hasAnyDependentBases()))) { -+ // Don't speculate if we're just trying to improve error recovery. -+ if (ErrorRecoveryLookup) -+ return true; -+ -+ // We were not able to compute the declaration context for a dependent -+ // base object type or prior nested-name-specifier, so this -+ // nested-name-specifier refers to an unknown specialization. Just build -+ // a dependent nested-name-specifier. -+ SS.Extend(Context, IdInfo.Identifier, IdInfo.IdentifierLoc, IdInfo.CCLoc); -+ return false; -+ } -+ - if (Found.empty() && !ErrorRecoveryLookup) { - // If identifier is not found as class-name-or-namespace-name, but is found - // as other entity, don't look for typos. - LookupResult R(*this, Found.getLookupNameInfo(), LookupOrdinaryName); -- LookupParsedName(R, S, &SS, ObjectType, -- /*AllowBuiltinCreation=*/false, EnteringContext); -- -+ if (LookupCtx) -+ LookupQualifiedName(R, LookupCtx); -+ else if (S && !isDependent) -+ LookupName(R, S); - if (!R.empty()) { - // Don't diagnose problems with this speculative lookup. - R.suppressDiagnostics(); -@@ -521,11 +539,6 @@ - } - } - -- DeclContext *LookupCtx = -- SS.isSet() -- ? computeDeclContext(SS, EnteringContext) -- : (!ObjectType.isNull() ? computeDeclContext(ObjectType) : nullptr); -- - if (Found.empty() && !ErrorRecoveryLookup && !getLangOpts().MSVCCompat) { - // We haven't found anything, and we're not recovering from a - // different kind of error, so look for typos. -@@ -581,14 +594,14 @@ - // scope, reconstruct the result from the template instantiation itself. - // - // Note that C++11 does *not* perform this redundant lookup. -- NamedDecl *OuterDecl = nullptr; -+ NamedDecl *OuterDecl; - if (S) { - LookupResult FoundOuter(*this, IdInfo.Identifier, IdInfo.IdentifierLoc, - LookupNestedNameSpecifierName); - LookupName(FoundOuter, S); - OuterDecl = FoundOuter.getAsSingle(); -- } else if (!SS.getUnqualifiedLookups().empty()) -- OuterDecl = SS.getUnqualifiedLookups().front().getDecl(); -+ } else -+ OuterDecl = ScopeLookupResult; - - if (isAcceptableNestedNameSpecifier(OuterDecl) && - OuterDecl->getCanonicalDecl() != SD->getCanonicalDecl() && -@@ -766,7 +779,7 @@ - return true; - - return BuildCXXNestedNameSpecifier(S, IdInfo, EnteringContext, SS, -- /*ErrorRecoveryLookup=*/false, -+ /*ScopeLookupResult=*/nullptr, false, - IsCorrectedToColon, OnlyNamespace); - } - -@@ -827,7 +840,7 @@ - return false; - - return !BuildCXXNestedNameSpecifier(S, IdInfo, EnteringContext, SS, -- /*ErrorRecoveryLookup=*/true); -+ /*ScopeLookupResult=*/nullptr, true); - } - - bool Sema::ActOnCXXNestedNameSpecifier(Scope *S, -diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaDeclCXX.cpp b/clang/lib/Sema/SemaDeclCXX.cpp ---- a/clang/lib/Sema/SemaDeclCXX.cpp -+++ b/clang/lib/Sema/SemaDeclCXX.cpp -@@ -1275,11 +1275,9 @@ - if (UseMemberGet) { - // if [lookup of member get] finds at least one declaration, the - // initializer is e.get(). -- E = S.BuildMemberReferenceExpr(E.get(), DecompType, Loc, -- /*IsArrow=*/false, -- /*SS=*/CXXScopeSpec(), -- /*TemplateKWLoc=*/SourceLocation(), -- MemberGet, &Args, /*S=*/nullptr); -+ E = S.BuildMemberReferenceExpr(E.get(), DecompType, Loc, false, -+ CXXScopeSpec(), SourceLocation(), nullptr, -+ MemberGet, &Args, nullptr); - if (E.isInvalid()) - return true; - -@@ -4903,12 +4901,16 @@ - MemberLookup.addDecl(Indirect ? cast(Indirect) - : cast(Field), AS_public); - MemberLookup.resolveKind(); -- ExprResult CtorArg = SemaRef.BuildMemberReferenceExpr( -- MemberExprBase, ParamType, Loc, -- /*IsArrow=*/false, SS, -- /*TemplateKWLoc=*/SourceLocation(), MemberLookup, -- /*TemplateArgs=*/nullptr, -- /*S=*/nullptr); -+ ExprResult CtorArg -+ = SemaRef.BuildMemberReferenceExpr(MemberExprBase, -+ ParamType, Loc, -+ /*IsArrow=*/false, -+ SS, -+ /*TemplateKWLoc=*/SourceLocation(), -+ /*FirstQualifierInScope=*/nullptr, -+ MemberLookup, -+ /*TemplateArgs=*/nullptr, -+ /*S*/nullptr); - if (CtorArg.isInvalid()) - return true; - -@@ -14334,10 +14336,8 @@ - public: - Expr *build(Sema &S, SourceLocation Loc) const override { - return assertNotNull(S.BuildMemberReferenceExpr( -- Builder.build(S, Loc), Type, Loc, IsArrow, SS, -- /*TemplateKwLoc=*/SourceLocation(), MemberLookup, -- /*TemplateArgs=*/nullptr, /*S=*/nullptr) -- .get()); -+ Builder.build(S, Loc), Type, Loc, IsArrow, SS, SourceLocation(), -+ nullptr, MemberLookup, nullptr, nullptr).get()); - } - - MemberBuilder(const ExprBuilder &Builder, QualType Type, bool IsArrow, -@@ -14543,11 +14543,13 @@ - Loc); - - // Create the reference to operator=. -- ExprResult OpEqualRef = S.BuildMemberReferenceExpr( -- To.build(S, Loc), T, Loc, /*IsArrow=*/false, SS, -- /*TemplateKWLoc=*/SourceLocation(), OpLookup, -- /*TemplateArgs=*/nullptr, /*S*/ nullptr, -- /*SuppressQualifierCheck=*/true); -+ ExprResult OpEqualRef -+ = S.BuildMemberReferenceExpr(To.build(S, Loc), T, Loc, /*IsArrow=*/false, -+ SS, /*TemplateKWLoc=*/SourceLocation(), -+ /*FirstQualifierInScope=*/nullptr, -+ OpLookup, -+ /*TemplateArgs=*/nullptr, /*S*/nullptr, -+ /*SuppressQualifierCheck=*/true); - if (OpEqualRef.isInvalid()) - return StmtError(); - -@@ -17153,9 +17155,8 @@ - - auto BuildExpr = [&](LookupResult &LR) { - ExprResult Res = BuildMemberReferenceExpr( -- Message, Message->getType(), Message->getBeginLoc(), /*IsArrow=*/false, -- /*SS=*/CXXScopeSpec(), /*TemplateKWLoc=*/SourceLocation(), LR, -- /*TemplateArgs=*/nullptr, /*S=*/nullptr); -+ Message, Message->getType(), Message->getBeginLoc(), false, -+ CXXScopeSpec(), SourceLocation(), nullptr, LR, nullptr, nullptr); - if (Res.isInvalid()) - return ExprError(); - Res = BuildCallExpr(nullptr, Res.get(), Loc, std::nullopt, Loc, nullptr, -diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp ---- a/clang/lib/Sema/SemaExpr.cpp -+++ b/clang/lib/Sema/SemaExpr.cpp -@@ -2624,7 +2624,7 @@ - return CXXDependentScopeMemberExpr::Create( - Context, /*This=*/nullptr, ThisType, /*IsArrow=*/true, - /*Op=*/SourceLocation(), NestedNameSpecifierLoc(), TemplateKWLoc, -- /*UnqualifiedLookups=*/std::nullopt, NameInfo, TemplateArgs); -+ /*FirstQualifierFoundInScope=*/nullptr, NameInfo, TemplateArgs); - } - - // Synthesize a fake NNS that points to the derived class. This will -diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaExprMember.cpp b/clang/lib/Sema/SemaExprMember.cpp ---- a/clang/lib/Sema/SemaExprMember.cpp -+++ b/clang/lib/Sema/SemaExprMember.cpp -@@ -552,9 +552,11 @@ - } - - ExprResult --Sema::ActOnDependentMemberExpr(Expr *BaseExpr, QualType BaseType, bool IsArrow, -- SourceLocation OpLoc, const CXXScopeSpec &SS, -+Sema::ActOnDependentMemberExpr(Expr *BaseExpr, QualType BaseType, -+ bool IsArrow, SourceLocation OpLoc, -+ const CXXScopeSpec &SS, - SourceLocation TemplateKWLoc, -+ NamedDecl *FirstQualifierInScope, - const DeclarationNameInfo &NameInfo, - const TemplateArgumentListInfo *TemplateArgs) { - // Even in dependent contexts, try to diagnose base expressions with -@@ -588,8 +590,8 @@ - // must have pointer type, and the accessed type is the pointee. - return CXXDependentScopeMemberExpr::Create( - Context, BaseExpr, BaseType, IsArrow, OpLoc, -- SS.getWithLocInContext(Context), TemplateKWLoc, -- SS.getUnqualifiedLookups(), NameInfo, TemplateArgs); -+ SS.getWithLocInContext(Context), TemplateKWLoc, FirstQualifierInScope, -+ NameInfo, TemplateArgs); - } - - /// We know that the given qualified member reference points only to -@@ -765,9 +767,8 @@ - R.addDecl(ND); - R.resolveKind(); - return SemaRef.BuildMemberReferenceExpr( -- BaseExpr, BaseExpr->getType(), OpLoc, IsArrow, SS, -- /*TemplateKWLoc=*/SourceLocation(), R, /*TemplateArgs=*/nullptr, -- /*S=*/nullptr); -+ BaseExpr, BaseExpr->getType(), OpLoc, IsArrow, SS, SourceLocation(), -+ nullptr, R, nullptr, nullptr); - }, - Sema::CTK_ErrorRecovery, DC); - -@@ -783,7 +784,7 @@ - ExprResult Sema::BuildMemberReferenceExpr( - Expr *Base, QualType BaseType, SourceLocation OpLoc, bool IsArrow, - CXXScopeSpec &SS, SourceLocation TemplateKWLoc, -- const DeclarationNameInfo &NameInfo, -+ NamedDecl *FirstQualifierInScope, const DeclarationNameInfo &NameInfo, - const TemplateArgumentListInfo *TemplateArgs, const Scope *S, - ActOnMemberAccessExtraArgs *ExtraArgs) { - LookupResult R(*this, NameInfo, LookupMemberName); -@@ -827,9 +828,10 @@ - if (SS.isInvalid()) - return ExprError(); - -- return BuildMemberReferenceExpr(Base, BaseType, OpLoc, IsArrow, SS, -- TemplateKWLoc, R, TemplateArgs, S, -- /*SuppressQualifierCheck=*/false, ExtraArgs); -+ return BuildMemberReferenceExpr(Base, BaseType, -+ OpLoc, IsArrow, SS, TemplateKWLoc, -+ FirstQualifierInScope, R, TemplateArgs, S, -+ false, ExtraArgs); - } - - ExprResult -@@ -967,11 +969,17 @@ - return false; - } - --ExprResult Sema::BuildMemberReferenceExpr( -- Expr *BaseExpr, QualType BaseExprType, SourceLocation OpLoc, bool IsArrow, -- const CXXScopeSpec &SS, SourceLocation TemplateKWLoc, LookupResult &R, -- const TemplateArgumentListInfo *TemplateArgs, const Scope *S, -- bool SuppressQualifierCheck, ActOnMemberAccessExtraArgs *ExtraArgs) { -+ExprResult -+Sema::BuildMemberReferenceExpr(Expr *BaseExpr, QualType BaseExprType, -+ SourceLocation OpLoc, bool IsArrow, -+ const CXXScopeSpec &SS, -+ SourceLocation TemplateKWLoc, -+ NamedDecl *FirstQualifierInScope, -+ LookupResult &R, -+ const TemplateArgumentListInfo *TemplateArgs, -+ const Scope *S, -+ bool SuppressQualifierCheck, -+ ActOnMemberAccessExtraArgs *ExtraArgs) { - assert(!SS.isInvalid() && "nested-name-specifier cannot be invalid"); - // If the member wasn't found in the current instantiation, or if the - // arrow operator was used with a dependent non-pointer object expression, -@@ -981,8 +989,8 @@ - (SS.isSet() ? SS.getScopeRep()->isDependent() - : BaseExprType->isDependentType()))) - return ActOnDependentMemberExpr(BaseExpr, BaseExprType, IsArrow, OpLoc, SS, -- TemplateKWLoc, R.getLookupNameInfo(), -- TemplateArgs); -+ TemplateKWLoc, FirstQualifierInScope, -+ R.getLookupNameInfo(), TemplateArgs); - - QualType BaseType = BaseExprType; - if (IsArrow) { -@@ -1187,9 +1195,9 @@ - - // Non-dependent member, but dependent template arguments. - if (!VDecl.get()) -- return ActOnDependentMemberExpr(BaseExpr, BaseExpr->getType(), IsArrow, -- OpLoc, SS, TemplateKWLoc, MemberNameInfo, -- TemplateArgs); -+ return ActOnDependentMemberExpr( -+ BaseExpr, BaseExpr->getType(), IsArrow, OpLoc, SS, TemplateKWLoc, -+ FirstQualifierInScope, MemberNameInfo, TemplateArgs); - - VarDecl *Var = cast(VDecl.get()); - if (!Var->getTemplateSpecializationKind()) -@@ -1755,16 +1763,15 @@ - const TemplateArgumentListInfo *TemplateArgs; - DecomposeUnqualifiedId(Id, TemplateArgsBuffer, - NameInfo, TemplateArgs); -- bool IsArrow = OpKind == tok::arrow; -+ -+ bool IsArrow = (OpKind == tok::arrow); - - if (getLangOpts().HLSL && IsArrow) - return ExprError(Diag(OpLoc, diag::err_hlsl_operator_unsupported) << 2); - -- UnresolvedSet<4> UnqualifiedLookups; -- if (SS.isValid() && -- LookupFirstQualifierInScope(S, SS.getScopeRep(), UnqualifiedLookups)) { -- SS.setUnqualifiedLookups(UnqualifiedLookups.pairs()); -- } -+ NamedDecl *FirstQualifierInScope -+ = (!SS.isSet() ? nullptr : FindFirstQualifierInScope(S, SS.getScopeRep())); -+ - // This is a postfix expression, so get rid of ParenListExprs. - ExprResult Result = MaybeConvertParenListExprToParenExpr(S, Base); - if (Result.isInvalid()) return ExprError(); -@@ -1772,8 +1779,8 @@ - - ActOnMemberAccessExtraArgs ExtraArgs = {S, Id, ObjCImpDecl}; - ExprResult Res = BuildMemberReferenceExpr( -- Base, Base->getType(), OpLoc, IsArrow, SS, TemplateKWLoc, NameInfo, -- TemplateArgs, S, &ExtraArgs); -+ Base, Base->getType(), OpLoc, IsArrow, SS, TemplateKWLoc, -+ FirstQualifierInScope, NameInfo, TemplateArgs, S, &ExtraArgs); - - if (!Res.isInvalid() && isa(Res.get())) - CheckMemberAccessOfNoDeref(cast(Res.get())); -@@ -1917,8 +1924,9 @@ - baseExpr = BuildCXXThisExpr(loc, ThisTy, /*IsImplicit=*/true); - } - -- return BuildMemberReferenceExpr(baseExpr, ThisTy, -- /*OpLoc=*/SourceLocation(), -- /*IsArrow=*/!getLangOpts().HLSL, SS, -- TemplateKWLoc, R, TemplateArgs, S); -+ return BuildMemberReferenceExpr( -+ baseExpr, ThisTy, -+ /*OpLoc=*/SourceLocation(), -+ /*IsArrow=*/!getLangOpts().HLSL, SS, TemplateKWLoc, -+ /*FirstQualifierInScope=*/nullptr, R, TemplateArgs, S); - } -diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp ---- a/clang/lib/Sema/SemaOverload.cpp -+++ b/clang/lib/Sema/SemaOverload.cpp -@@ -16043,11 +16043,13 @@ - - CandidateSet->clear(OverloadCandidateSet::CSK_Normal); - if (!MemberLookup.empty()) { -- ExprResult MemberRef = BuildMemberReferenceExpr( -- Range, Range->getType(), Loc, -- /*IsPtr=*/false, /*SS=*/CXXScopeSpec(), -- /*TemplateKWLoc=*/SourceLocation(), MemberLookup, -- /*TemplateArgs=*/nullptr, S); -+ ExprResult MemberRef = -+ BuildMemberReferenceExpr(Range, Range->getType(), Loc, -+ /*IsPtr=*/false, CXXScopeSpec(), -+ /*TemplateKWLoc=*/SourceLocation(), -+ /*FirstQualifierInScope=*/nullptr, -+ MemberLookup, -+ /*TemplateArgs=*/nullptr, S); - if (MemberRef.isInvalid()) { - *CallExpr = ExprError(); - return FRS_DiagnosticIssued; -diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaStmtAsm.cpp b/clang/lib/Sema/SemaStmtAsm.cpp ---- a/clang/lib/Sema/SemaStmtAsm.cpp -+++ b/clang/lib/Sema/SemaStmtAsm.cpp -@@ -900,8 +900,7 @@ - return CXXDependentScopeMemberExpr::Create( - Context, E, T, /*IsArrow=*/false, AsmLoc, NestedNameSpecifierLoc(), - SourceLocation(), -- /*UnqualifiedLookups=*/std::nullopt, NameInfo, -- /*TemplateArgs=*/nullptr); -+ /*FirstQualifierFoundInScope=*/nullptr, NameInfo, /*TemplateArgs=*/nullptr); - } - - const RecordType *RT = T->getAs(); -@@ -924,9 +923,8 @@ - - // Make an Expr to thread through OpDecl. - ExprResult Result = BuildMemberReferenceExpr( -- E, E->getType(), AsmLoc, /*IsArrow=*/false, /*SS=*/CXXScopeSpec(), -- /*TemplateKWLoc*/ SourceLocation(), FieldResult, -- /*TemplateArgs=*/nullptr, /*S=*/nullptr); -+ E, E->getType(), AsmLoc, /*IsArrow=*/false, CXXScopeSpec(), -+ SourceLocation(), nullptr, FieldResult, nullptr, nullptr); - - return Result; - } -diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaTemplate.cpp b/clang/lib/Sema/SemaTemplate.cpp ---- a/clang/lib/Sema/SemaTemplate.cpp -+++ b/clang/lib/Sema/SemaTemplate.cpp -@@ -174,12 +174,15 @@ - return false; - } - --TemplateNameKind --Sema::isTemplateName(Scope *S, CXXScopeSpec &SS, bool hasTemplateKeyword, -- const UnqualifiedId &Name, ParsedType ObjectTypePtr, -- bool EnteringContext, TemplateTy &TemplateResult, -- bool &MemberOfUnknownSpecialization, bool Disambiguation, -- bool MayBeNNS) { -+TemplateNameKind Sema::isTemplateName(Scope *S, -+ CXXScopeSpec &SS, -+ bool hasTemplateKeyword, -+ const UnqualifiedId &Name, -+ ParsedType ObjectTypePtr, -+ bool EnteringContext, -+ TemplateTy &TemplateResult, -+ bool &MemberOfUnknownSpecialization, -+ bool Disambiguation) { - assert(getLangOpts().CPlusPlus && "No template names in C!"); - - DeclarationName TName; -@@ -210,9 +213,8 @@ - if (LookupTemplateName(R, S, SS, ObjectType, EnteringContext, - /*RequiredTemplate=*/SourceLocation(), - &AssumedTemplate, -- /*AllowTypoCorrection=*/!Disambiguation, MayBeNNS)) -+ /*AllowTypoCorrection=*/!Disambiguation)) - return TNK_Non_template; -- - MemberOfUnknownSpecialization = R.wasNotFoundInCurrentInstantiation(); - - if (AssumedTemplate != AssumedTemplateKind::None) { -@@ -378,7 +380,7 @@ - QualType ObjectType, bool EnteringContext, - RequiredTemplateKind RequiredTemplate, - AssumedTemplateKind *ATK, -- bool AllowTypoCorrection, bool MayBeNNS) { -+ bool AllowTypoCorrection) { - if (ATK) - *ATK = AssumedTemplateKind::None; - -@@ -387,89 +389,92 @@ - - Found.setTemplateNameLookup(true); - -- // Template names cannot appear inside an Objective-C class or object type -- // or a vector type. -- // -- // FIXME: This is wrong. For example: -- // -- // template using Vec = T __attribute__((ext_vector_type(4))); -- // Vec vi; -- // vi.Vec::~Vec(); -- // -- // ... should be accepted but we will not treat 'Vec' as a template name -- // here. The right thing to do would be to check if the name is a valid -- // vector component name, and look up a template name if not. And similarly -- // for lookups into Objective-C class and object types, where the same -- // problem can arise. -- if (!ObjectType.isNull() && (ObjectType->isVectorType() || -- ObjectType->isObjCObjectOrInterfaceType())) { -- Found.clear(); -- return false; -- } -+ // Determine where to perform name lookup -+ DeclContext *LookupCtx = nullptr; -+ bool IsDependent = false; -+ if (!ObjectType.isNull()) { -+ // This nested-name-specifier occurs in a member access expression, e.g., -+ // x->B::f, and we are looking into the type of the object. -+ assert(SS.isEmpty() && "ObjectType and scope specifier cannot coexist"); -+ LookupCtx = computeDeclContext(ObjectType); -+ IsDependent = !LookupCtx && ObjectType->isDependentType(); -+ assert((IsDependent || !ObjectType->isIncompleteType() || -+ !ObjectType->getAs() || -+ ObjectType->castAs()->isBeingDefined()) && -+ "Caller should have completed object type"); - -- LookupParsedName(Found, S, &SS, ObjectType, -- /*AllowBuiltinCreation=*/false, EnteringContext); -+ // Template names cannot appear inside an Objective-C class or object type -+ // or a vector type. -+ // -+ // FIXME: This is wrong. For example: -+ // -+ // template using Vec = T __attribute__((ext_vector_type(4))); -+ // Vec vi; -+ // vi.Vec::~Vec(); -+ // -+ // ... should be accepted but we will not treat 'Vec' as a template name -+ // here. The right thing to do would be to check if the name is a valid -+ // vector component name, and look up a template name if not. And similarly -+ // for lookups into Objective-C class and object types, where the same -+ // problem can arise. -+ if (ObjectType->isObjCObjectOrInterfaceType() || -+ ObjectType->isVectorType()) { -+ Found.clear(); -+ return false; -+ } -+ } else if (SS.isNotEmpty()) { -+ // This nested-name-specifier occurs after another nested-name-specifier, -+ // so long into the context associated with the prior nested-name-specifier. -+ LookupCtx = computeDeclContext(SS, EnteringContext); -+ IsDependent = !LookupCtx && isDependentScopeSpecifier(SS); - -- // C++ [basic.lookup.qual.general]p3: -- // [...] Unless otherwise specified, a qualified name undergoes qualified -- // name lookup in its lookup context from the point where it appears unless -- // the lookup context either is dependent and is not the current -- // instantiation or is not a class or class template. -- // -- // The lookup context is dependent and either: -- // - it is not the current instantiation, or -- // - it is the current instantiation, it has at least one dependent base -- // class, and qualified lookup found nothing. -- // -- // If this is a member-qualified name that is the terminal name of a -- // nested-name-specifier, we perform unqualified lookup and store the results -- // so we can replicate the lookup during instantiation. The results of the -- // unqualified loookup are *not* used to determine whether '<' is interpreted -- // as the delimiter of a template-argument-list. -- // -- // For example: -- // -- // template -- // struct A { -- // int x; -- // }; -- // -- // template -- // using B = A; -- // -- // template -- // void f(A a, A b) { -- // a.B::x; // error: missing 'template' before 'B' -- // b.B::x; // ok, lookup context is not dependent -- // } -- if (Found.wasNotFoundInCurrentInstantiation()) -- return false; -+ // The declaration context must be complete. -+ if (LookupCtx && RequireCompleteDeclContext(SS, LookupCtx)) -+ return true; -+ } - - bool ObjectTypeSearchedInScope = false; -- -- // C++ [basic.lookup.qual.general]p2: -- // A member-qualified name is the (unique) component name, if any, of -- // - an unqualified-id or -- // - a nested-name-specifier of the form type-name :: or namespace-name :: -- // in the id-expression of a class member access expression. -- // -- // C++ [basic.lookup.qual.general]p3: -- // [...] If nothing is found by qualified lookup for a member-qualified -- // name that is the terminal name of a nested-name-specifier and is not -- // dependent, it undergoes unqualified lookup. -- // -- // In 'x.A::B::y', 'A' will undergo unqualified lookup if qualified lookup -- // in the type of 'x' finds nothing. If the lookup context is dependent, -- // we perform the unqualified lookup in the template definition context -- // and store the results so we can replicate the lookup during instantiation. -- if (MayBeNNS && Found.empty() && !ObjectType.isNull()) { -- if (S) { -+ bool AllowFunctionTemplatesInLookup = true; -+ if (LookupCtx) { -+ // Perform "qualified" name lookup into the declaration context we -+ // computed, which is either the type of the base of a member access -+ // expression or the declaration context associated with a prior -+ // nested-name-specifier. -+ LookupQualifiedName(Found, LookupCtx); -+ -+ // FIXME: The C++ standard does not clearly specify what happens in the -+ // case where the object type is dependent, and implementations vary. In -+ // Clang, we treat a name after a . or -> as a template-name if lookup -+ // finds a non-dependent member or member of the current instantiation that -+ // is a type template, or finds no such members and lookup in the context -+ // of the postfix-expression finds a type template. In the latter case, the -+ // name is nonetheless dependent, and we may resolve it to a member of an -+ // unknown specialization when we come to instantiate the template. -+ IsDependent |= Found.wasNotFoundInCurrentInstantiation(); -+ } -+ -+ if (SS.isEmpty() && (ObjectType.isNull() || Found.empty())) { -+ // C++ [basic.lookup.classref]p1: -+ // In a class member access expression (5.2.5), if the . or -> token is -+ // immediately followed by an identifier followed by a <, the -+ // identifier must be looked up to determine whether the < is the -+ // beginning of a template argument list (14.2) or a less-than operator. -+ // The identifier is first looked up in the class of the object -+ // expression. If the identifier is not found, it is then looked up in -+ // the context of the entire postfix-expression and shall name a class -+ // template. -+ if (S) - LookupName(Found, S); -- } else if (!SS.getUnqualifiedLookups().empty()) { -- Found.addAllDecls(SS.getUnqualifiedLookups()); -- Found.resolveKind(); -+ -+ if (!ObjectType.isNull()) { -+ // FIXME: We should filter out all non-type templates here, particularly -+ // variable templates and concepts. But the exclusion of alias templates -+ // and template template parameters is a wording defect. -+ AllowFunctionTemplatesInLookup = false; -+ ObjectTypeSearchedInScope = true; - } -- ObjectTypeSearchedInScope = true; -+ -+ IsDependent |= Found.wasNotFoundInCurrentInstantiation(); - } - - if (Found.isAmbiguous()) -@@ -489,7 +494,7 @@ - getLangOpts().CPlusPlus20 && llvm::all_of(Found, [](NamedDecl *ND) { - return isa(ND->getUnderlyingDecl()); - }); -- if (AllFunctions || Found.empty()) { -+ if (AllFunctions || (Found.empty() && !IsDependent)) { - // If lookup found any functions, or if this is a name that can only be - // used for a function, then strongly assume this is a function - // template-id. -@@ -501,15 +506,11 @@ - } - } - -- if (Found.empty() && AllowTypoCorrection) { -+ if (Found.empty() && !IsDependent && AllowTypoCorrection) { - // If we did not find any names, and this is not a disambiguation, attempt - // to correct any typos. - DeclarationName Name = Found.getLookupName(); - Found.clear(); -- DeclContext *LookupCtx = -- SS.isSet() -- ? computeDeclContext(SS, EnteringContext) -- : (!ObjectType.isNull() ? computeDeclContext(ObjectType) : nullptr); - // Simple filter callback that, for keywords, only accepts the C++ *_cast - DefaultFilterCCC FilterCCC{}; - FilterCCC.WantTypeSpecifiers = false; -@@ -542,8 +543,13 @@ - - NamedDecl *ExampleLookupResult = - Found.empty() ? nullptr : Found.getRepresentativeDecl(); -- FilterAcceptableTemplateNames(Found); -+ FilterAcceptableTemplateNames(Found, AllowFunctionTemplatesInLookup); - if (Found.empty()) { -+ if (IsDependent) { -+ Found.setNotFoundInCurrentInstantiation(); -+ return false; -+ } -+ - // If a 'template' keyword was used, a lookup that finds only non-template - // names is an error. - if (ExampleLookupResult && RequiredTemplate) { -@@ -735,7 +741,7 @@ - /*IsArrow=*/!Context.getLangOpts().HLSL, - /*OperatorLoc=*/SourceLocation(), - /*QualifierLoc=*/NestedNameSpecifierLoc(), TemplateKWLoc, -- /*UnqualifiedLookups=*/std::nullopt, NameInfo, TemplateArgs); -+ /*FirstQualifierFoundInScope=*/nullptr, NameInfo, TemplateArgs); - } - return BuildDependentDeclRefExpr(SS, TemplateKWLoc, NameInfo, TemplateArgs); - } -@@ -5849,10 +5855,14 @@ - return BuildTemplateIdExpr(SS, TemplateKWLoc, R, /*ADL=*/false, TemplateArgs); - } - --TemplateNameKind Sema::ActOnTemplateName( -- Scope *S, CXXScopeSpec &SS, SourceLocation TemplateKWLoc, -- const UnqualifiedId &Name, ParsedType ObjectType, bool EnteringContext, -- TemplateTy &Result, bool AllowInjectedClassName, bool MayBeNNS) { -+TemplateNameKind Sema::ActOnTemplateName(Scope *S, -+ CXXScopeSpec &SS, -+ SourceLocation TemplateKWLoc, -+ const UnqualifiedId &Name, -+ ParsedType ObjectType, -+ bool EnteringContext, -+ TemplateTy &Result, -+ bool AllowInjectedClassName) { - if (TemplateKWLoc.isValid() && S && !S->getTemplateParamParent()) - Diag(TemplateKWLoc, - getLangOpts().CPlusPlus11 ? -@@ -5887,10 +5897,9 @@ - // "template" keyword is now permitted). We follow the C++0x - // rules, even in C++03 mode with a warning, retroactively applying the DR. - bool MemberOfUnknownSpecialization; -- TemplateNameKind TNK = -- isTemplateName(S, SS, TemplateKWLoc.isValid(), Name, ObjectType, -- EnteringContext, Result, MemberOfUnknownSpecialization, -- /*Disambiguation=*/false, MayBeNNS); -+ TemplateNameKind TNK = isTemplateName(S, SS, TemplateKWLoc.isValid(), Name, -+ ObjectType, EnteringContext, Result, -+ MemberOfUnknownSpecialization); - if (TNK != TNK_Non_template) { - // We resolved this to a (non-dependent) template name. Return it. - auto *LookupRD = dyn_cast_or_null(LookupCtx); -@@ -5929,8 +5938,7 @@ - ? RequiredTemplateKind(TemplateKWLoc) - : TemplateNameIsRequired; - if (!LookupTemplateName(R, S, SS, ObjectType.get(), EnteringContext, RTK, -- /*ATK=*/nullptr, /*AllowTypoCorrection=*/false, -- MayBeNNS) && -+ /*ATK=*/nullptr, /*AllowTypoCorrection=*/false) && - !R.isAmbiguous()) { - if (LookupCtx) - Diag(Name.getBeginLoc(), diag::err_no_member) -diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaTemplateInstantiate.cpp b/clang/lib/Sema/SemaTemplateInstantiate.cpp ---- a/clang/lib/Sema/SemaTemplateInstantiate.cpp -+++ b/clang/lib/Sema/SemaTemplateInstantiate.cpp -@@ -1515,11 +1515,12 @@ - NestedNameSpecifierLoc QualifierLoc, - QualType T); - -- TemplateName TransformTemplateName(CXXScopeSpec &SS, TemplateName Name, -- SourceLocation NameLoc, -- QualType ObjectType = QualType(), -- bool AllowInjectedClassName = false, -- bool MayBeNNS = false); -+ TemplateName -+ TransformTemplateName(CXXScopeSpec &SS, TemplateName Name, -+ SourceLocation NameLoc, -+ QualType ObjectType = QualType(), -+ NamedDecl *FirstQualifierInScope = nullptr, -+ bool AllowInjectedClassName = false); - - const CXXAssumeAttr *TransformCXXAssumeAttr(const CXXAssumeAttr *AA); - const LoopHintAttr *TransformLoopHintAttr(const LoopHintAttr *LH); -@@ -1951,7 +1952,8 @@ - - TemplateName TemplateInstantiator::TransformTemplateName( - CXXScopeSpec &SS, TemplateName Name, SourceLocation NameLoc, -- QualType ObjectType, bool AllowInjectedClassName, bool MayBeNNS) { -+ QualType ObjectType, NamedDecl *FirstQualifierInScope, -+ bool AllowInjectedClassName) { - if (TemplateTemplateParmDecl *TTP - = dyn_cast_or_null(Name.getAsTemplateDecl())) { - if (TTP->getDepth() < TemplateArgs.getNumLevels()) { -@@ -2023,7 +2025,8 @@ - } - - return inherited::TransformTemplateName(SS, Name, NameLoc, ObjectType, -- AllowInjectedClassName, MayBeNNS); -+ FirstQualifierInScope, -+ AllowInjectedClassName); - } - - ExprResult -diff -ruN --strip-trailing-cr a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h ---- a/clang/lib/Sema/TreeTransform.h -+++ b/clang/lib/Sema/TreeTransform.h -@@ -541,9 +541,10 @@ - /// By default, transforms all of the types and declarations within the - /// nested-name-specifier. Subclasses may override this function to provide - /// alternate behavior. -- NestedNameSpecifierLoc TransformNestedNameSpecifierLoc( -- NestedNameSpecifierLoc NNS, QualType ObjectType = QualType(), -- ArrayRef UnqualifiedLookups = std::nullopt); -+ NestedNameSpecifierLoc -+ TransformNestedNameSpecifierLoc(NestedNameSpecifierLoc NNS, -+ QualType ObjectType = QualType(), -+ NamedDecl *FirstQualifierInScope = nullptr); - - /// Transform the given declaration name. - /// -@@ -584,11 +585,12 @@ - /// By default, transforms the template name by transforming the declarations - /// and nested-name-specifiers that occur within the template name. - /// Subclasses may override this function to provide alternate behavior. -- TemplateName TransformTemplateName(CXXScopeSpec &SS, TemplateName Name, -- SourceLocation NameLoc, -- QualType ObjectType = QualType(), -- bool AllowInjectedClassName = false, -- bool MayBeNNS = false); -+ TemplateName -+ TransformTemplateName(CXXScopeSpec &SS, TemplateName Name, -+ SourceLocation NameLoc, -+ QualType ObjectType = QualType(), -+ NamedDecl *FirstQualifierInScope = nullptr, -+ bool AllowInjectedClassName = false); - - /// Transform the given template argument. - /// -@@ -1138,8 +1140,8 @@ - CXXScopeSpec SS; - SS.Adopt(QualifierLoc); - TemplateName InstName = getDerived().RebuildTemplateName( -- SS, TemplateKWLoc, *Name, NameLoc, QualType(), AllowInjectedClassName, -- /*MayBeNNS=*/false); -+ SS, TemplateKWLoc, *Name, NameLoc, QualType(), nullptr, -+ AllowInjectedClassName); - - if (InstName.isNull()) - return QualType(); -@@ -1310,7 +1312,8 @@ - SourceLocation TemplateKWLoc, - const IdentifierInfo &Name, - SourceLocation NameLoc, QualType ObjectType, -- bool AllowInjectedClassName, bool MayBeNNS); -+ NamedDecl *FirstQualifierInScope, -+ bool AllowInjectedClassName); - - /// Build a new template name given a nested name specifier and the - /// overloaded operator name that is referred to as a template. -@@ -2846,14 +2849,15 @@ - /// - /// By default, performs semantic analysis to build the new expression. - /// Subclasses may override this routine to provide different behavior. -- ExprResult -- RebuildMemberExpr(Expr *Base, SourceLocation OpLoc, bool isArrow, -- NestedNameSpecifierLoc QualifierLoc, -- SourceLocation TemplateKWLoc, -- const DeclarationNameInfo &MemberNameInfo, -- ValueDecl *Member, NamedDecl *FoundDecl, -- const TemplateArgumentListInfo *ExplicitTemplateArgs, -- ArrayRef UnqualifiedLookups) { -+ ExprResult RebuildMemberExpr(Expr *Base, SourceLocation OpLoc, -+ bool isArrow, -+ NestedNameSpecifierLoc QualifierLoc, -+ SourceLocation TemplateKWLoc, -+ const DeclarationNameInfo &MemberNameInfo, -+ ValueDecl *Member, -+ NamedDecl *FoundDecl, -+ const TemplateArgumentListInfo *ExplicitTemplateArgs, -+ NamedDecl *FirstQualifierInScope) { - ExprResult BaseResult = getSema().PerformMemberExprBaseConversion(Base, - isArrow); - if (!Member->getDeclName()) { -@@ -2890,7 +2894,6 @@ - - CXXScopeSpec SS; - SS.Adopt(QualifierLoc); -- SS.setUnqualifiedLookups(UnqualifiedLookups); - - Base = BaseResult.get(); - if (Base->containsErrors()) -@@ -2923,9 +2926,10 @@ - } - - return getSema().BuildMemberReferenceExpr(Base, BaseType, OpLoc, isArrow, -- SS, TemplateKWLoc, R, -- ExplicitTemplateArgs, -- /*S=*/nullptr); -+ SS, TemplateKWLoc, -+ FirstQualifierInScope, -+ R, ExplicitTemplateArgs, -+ /*S*/nullptr); - } - - /// Build a new binary operator expression. -@@ -2998,9 +3002,10 @@ - CXXScopeSpec SS; - DeclarationNameInfo NameInfo(&Accessor, AccessorLoc); - return getSema().BuildMemberReferenceExpr( -- Base, Base->getType(), OpLoc, IsArrow, SS, -- /*TemplateKWLoc=*/SourceLocation(), NameInfo, -- /*TemplateArgs=*/nullptr, /*S=*/nullptr); -+ Base, Base->getType(), OpLoc, IsArrow, SS, SourceLocation(), -+ /*FirstQualifierInScope*/ nullptr, NameInfo, -+ /* TemplateArgs */ nullptr, -+ /*S*/ nullptr); - } - - /// Build a new initializer list expression. -@@ -3568,37 +3573,46 @@ - /// - /// By default, performs semantic analysis to build the new expression. - /// Subclasses may override this routine to provide different behavior. -- ExprResult RebuildCXXDependentScopeMemberExpr( -- Expr *BaseE, QualType BaseType, bool IsArrow, SourceLocation OperatorLoc, -- NestedNameSpecifierLoc QualifierLoc, SourceLocation TemplateKWLoc, -- ArrayRef UnqualifiedLookups, -- const DeclarationNameInfo &MemberNameInfo, -- const TemplateArgumentListInfo *TemplateArgs) { -+ ExprResult RebuildCXXDependentScopeMemberExpr(Expr *BaseE, -+ QualType BaseType, -+ bool IsArrow, -+ SourceLocation OperatorLoc, -+ NestedNameSpecifierLoc QualifierLoc, -+ SourceLocation TemplateKWLoc, -+ NamedDecl *FirstQualifierInScope, -+ const DeclarationNameInfo &MemberNameInfo, -+ const TemplateArgumentListInfo *TemplateArgs) { - CXXScopeSpec SS; - SS.Adopt(QualifierLoc); -- SS.setUnqualifiedLookups(UnqualifiedLookups); - -- return SemaRef.BuildMemberReferenceExpr( -- BaseE, BaseType, OperatorLoc, IsArrow, SS, TemplateKWLoc, -- MemberNameInfo, TemplateArgs, /*S=*/nullptr); -+ return SemaRef.BuildMemberReferenceExpr(BaseE, BaseType, -+ OperatorLoc, IsArrow, -+ SS, TemplateKWLoc, -+ FirstQualifierInScope, -+ MemberNameInfo, -+ TemplateArgs, /*S*/nullptr); - } - - /// Build a new member reference expression. - /// - /// By default, performs semantic analysis to build the new expression. - /// Subclasses may override this routine to provide different behavior. -- ExprResult RebuildUnresolvedMemberExpr( -- Expr *BaseE, QualType BaseType, SourceLocation OperatorLoc, bool IsArrow, -- NestedNameSpecifierLoc QualifierLoc, SourceLocation TemplateKWLoc, -- ArrayRef UnqualifiedLookups, LookupResult &R, -- const TemplateArgumentListInfo *TemplateArgs) { -+ ExprResult RebuildUnresolvedMemberExpr(Expr *BaseE, QualType BaseType, -+ SourceLocation OperatorLoc, -+ bool IsArrow, -+ NestedNameSpecifierLoc QualifierLoc, -+ SourceLocation TemplateKWLoc, -+ NamedDecl *FirstQualifierInScope, -+ LookupResult &R, -+ const TemplateArgumentListInfo *TemplateArgs) { - CXXScopeSpec SS; - SS.Adopt(QualifierLoc); -- SS.setUnqualifiedLookups(UnqualifiedLookups); - -- return SemaRef.BuildMemberReferenceExpr(BaseE, BaseType, OperatorLoc, -- IsArrow, SS, TemplateKWLoc, R, -- TemplateArgs, /*S=*/nullptr); -+ return SemaRef.BuildMemberReferenceExpr(BaseE, BaseType, -+ OperatorLoc, IsArrow, -+ SS, TemplateKWLoc, -+ FirstQualifierInScope, -+ R, TemplateArgs, /*S*/nullptr); - } - - /// Build a new noexcept expression. -@@ -3817,8 +3831,10 @@ - DeclarationNameInfo NameInfo(Ivar->getDeclName(), IvarLoc); - ExprResult Result = getSema().BuildMemberReferenceExpr( - BaseArg, BaseArg->getType(), -- /*FIXME:*/ IvarLoc, IsArrow, SS, /*TemplateKWLoc=*/SourceLocation(), -- NameInfo, /*TemplateArgs=*/nullptr, /*S=*/nullptr); -+ /*FIXME:*/ IvarLoc, IsArrow, SS, SourceLocation(), -+ /*FirstQualifierInScope=*/nullptr, NameInfo, -+ /*TemplateArgs=*/nullptr, -+ /*S=*/nullptr); - if (IsFreeIvar && Result.isUsable()) - cast(Result.get())->setIsFreeIvar(IsFreeIvar); - return Result; -@@ -3833,12 +3849,14 @@ - SourceLocation PropertyLoc) { - CXXScopeSpec SS; - DeclarationNameInfo NameInfo(Property->getDeclName(), PropertyLoc); -- return getSema().BuildMemberReferenceExpr( -- BaseArg, BaseArg->getType(), -- /*FIXME:*/ PropertyLoc, -- /*IsArrow=*/false, SS, /*TemplateKWLoc=*/SourceLocation(), NameInfo, -- /*TemplateArgs=*/nullptr, -- /*S=*/nullptr); -+ return getSema().BuildMemberReferenceExpr(BaseArg, BaseArg->getType(), -+ /*FIXME:*/PropertyLoc, -+ /*IsArrow=*/false, -+ SS, SourceLocation(), -+ /*FirstQualifierInScope=*/nullptr, -+ NameInfo, -+ /*TemplateArgs=*/nullptr, -+ /*S=*/nullptr); - } - - /// Build a new Objective-C property reference expression. -@@ -3865,11 +3883,13 @@ - SourceLocation OpLoc, bool IsArrow) { - CXXScopeSpec SS; - DeclarationNameInfo NameInfo(&getSema().Context.Idents.get("isa"), IsaLoc); -- return getSema().BuildMemberReferenceExpr( -- BaseArg, BaseArg->getType(), OpLoc, IsArrow, SS, -- /*TemplateKWLoc=*/SourceLocation(), NameInfo, -- /*TemplateArgs=*/nullptr, -- /*S=*/nullptr); -+ return getSema().BuildMemberReferenceExpr(BaseArg, BaseArg->getType(), -+ OpLoc, IsArrow, -+ SS, SourceLocation(), -+ /*FirstQualifierInScope=*/nullptr, -+ NameInfo, -+ /*TemplateArgs=*/nullptr, -+ /*S=*/nullptr); - } - - /// Build a new shuffle vector expression. -@@ -4034,14 +4054,18 @@ - } - - private: -- TypeLoc TransformTypeInObjectScope(TypeLoc TL, QualType ObjectType, -+ TypeLoc TransformTypeInObjectScope(TypeLoc TL, -+ QualType ObjectType, -+ NamedDecl *FirstQualifierInScope, - CXXScopeSpec &SS); - - TypeSourceInfo *TransformTypeInObjectScope(TypeSourceInfo *TSInfo, - QualType ObjectType, -+ NamedDecl *FirstQualifierInScope, - CXXScopeSpec &SS); - - TypeSourceInfo *TransformTSIInObjectScope(TypeLoc TL, QualType ObjectType, -+ NamedDecl *FirstQualifierInScope, - CXXScopeSpec &SS); - - QualType TransformDependentNameType(TypeLocBuilder &TLB, -@@ -4360,7 +4384,7 @@ - template - NestedNameSpecifierLoc TreeTransform::TransformNestedNameSpecifierLoc( - NestedNameSpecifierLoc NNS, QualType ObjectType, -- ArrayRef UnqualifiedLookups) { -+ NamedDecl *FirstQualifierInScope) { - SmallVector Qualifiers; - - auto insertNNS = [&Qualifiers](NestedNameSpecifierLoc NNS) { -@@ -4371,8 +4395,6 @@ - insertNNS(NNS); - - CXXScopeSpec SS; -- SS.setUnqualifiedLookups(UnqualifiedLookups); -- - while (!Qualifiers.empty()) { - NestedNameSpecifierLoc Q = Qualifiers.pop_back_val(); - NestedNameSpecifier *QNNS = Q.getNestedNameSpecifier(); -@@ -4382,9 +4404,8 @@ - Sema::NestedNameSpecInfo IdInfo(QNNS->getAsIdentifier(), - Q.getLocalBeginLoc(), Q.getLocalEndLoc(), - ObjectType); -- if (SemaRef.BuildCXXNestedNameSpecifier(/*Scope=*/nullptr, IdInfo, -- /*EnteringContext=*/false, SS, -- /*ErrorRecoveryLookup=*/false)) -+ if (SemaRef.BuildCXXNestedNameSpecifier(/*Scope=*/nullptr, IdInfo, false, -+ SS, FirstQualifierInScope, false)) - return NestedNameSpecifierLoc(); - break; - } -@@ -4422,7 +4443,8 @@ - - case NestedNameSpecifier::TypeSpecWithTemplate: - case NestedNameSpecifier::TypeSpec: { -- TypeLoc TL = TransformTypeInObjectScope(Q.getTypeLoc(), ObjectType, SS); -+ TypeLoc TL = TransformTypeInObjectScope(Q.getTypeLoc(), ObjectType, -+ FirstQualifierInScope, SS); - - if (!TL) - return NestedNameSpecifierLoc(); -@@ -4455,7 +4477,7 @@ - } - - // The qualifier-in-scope and object type only apply to the leftmost entity. -- SS.setUnqualifiedLookups(std::nullopt); -+ FirstQualifierInScope = nullptr; - ObjectType = QualType(); - } - -@@ -4538,10 +4560,14 @@ - llvm_unreachable("Unknown name kind."); - } - --template --TemplateName TreeTransform::TransformTemplateName( -- CXXScopeSpec &SS, TemplateName Name, SourceLocation NameLoc, -- QualType ObjectType, bool AllowInjectedClassName, bool MayBeNNS) { -+template -+TemplateName -+TreeTransform::TransformTemplateName(CXXScopeSpec &SS, -+ TemplateName Name, -+ SourceLocation NameLoc, -+ QualType ObjectType, -+ NamedDecl *FirstQualifierInScope, -+ bool AllowInjectedClassName) { - if (QualifiedTemplateName *QTN = Name.getAsQualifiedTemplateName()) { - TemplateDecl *Template = QTN->getUnderlyingTemplate().getAsTemplateDecl(); - assert(Template && "qualified template name must refer to a template"); -@@ -4565,7 +4591,7 @@ - if (SS.getScopeRep()) { - // These apply to the scope specifier, not the template. - ObjectType = QualType(); -- SS.setUnqualifiedLookups(std::nullopt); -+ FirstQualifierInScope = nullptr; - } - - if (!getDerived().AlwaysRebuild() && -@@ -4577,9 +4603,13 @@ - SourceLocation TemplateKWLoc = NameLoc; - - if (DTN->isIdentifier()) { -- return getDerived().RebuildTemplateName( -- SS, TemplateKWLoc, *DTN->getIdentifier(), NameLoc, ObjectType, -- AllowInjectedClassName, MayBeNNS); -+ return getDerived().RebuildTemplateName(SS, -+ TemplateKWLoc, -+ *DTN->getIdentifier(), -+ NameLoc, -+ ObjectType, -+ FirstQualifierInScope, -+ AllowInjectedClassName); - } - - return getDerived().RebuildTemplateName(SS, TemplateKWLoc, -@@ -5123,31 +5153,39 @@ - return SemaRef.BuildQualifiedType(T, Loc, Quals); - } - --template --TypeLoc TreeTransform::TransformTypeInObjectScope(TypeLoc TL, -- QualType ObjectType, -- CXXScopeSpec &SS) { -+template -+TypeLoc -+TreeTransform::TransformTypeInObjectScope(TypeLoc TL, -+ QualType ObjectType, -+ NamedDecl *UnqualLookup, -+ CXXScopeSpec &SS) { - if (getDerived().AlreadyTransformed(TL.getType())) - return TL; - -- TypeSourceInfo *TSI = TransformTSIInObjectScope(TL, ObjectType, SS); -+ TypeSourceInfo *TSI = -+ TransformTSIInObjectScope(TL, ObjectType, UnqualLookup, SS); - if (TSI) - return TSI->getTypeLoc(); - return TypeLoc(); - } - --template --TypeSourceInfo *TreeTransform::TransformTypeInObjectScope( -- TypeSourceInfo *TSInfo, QualType ObjectType, CXXScopeSpec &SS) { -+template -+TypeSourceInfo * -+TreeTransform::TransformTypeInObjectScope(TypeSourceInfo *TSInfo, -+ QualType ObjectType, -+ NamedDecl *UnqualLookup, -+ CXXScopeSpec &SS) { - if (getDerived().AlreadyTransformed(TSInfo->getType())) - return TSInfo; - -- return TransformTSIInObjectScope(TSInfo->getTypeLoc(), ObjectType, SS); -+ return TransformTSIInObjectScope(TSInfo->getTypeLoc(), ObjectType, -+ UnqualLookup, SS); - } - - template - TypeSourceInfo *TreeTransform::TransformTSIInObjectScope( -- TypeLoc TL, QualType ObjectType, CXXScopeSpec &SS) { -+ TypeLoc TL, QualType ObjectType, NamedDecl *UnqualLookup, -+ CXXScopeSpec &SS) { - QualType T = TL.getType(); - assert(!getDerived().AlreadyTransformed(T)); - -@@ -5160,7 +5198,7 @@ - - TemplateName Template = getDerived().TransformTemplateName( - SS, SpecTL.getTypePtr()->getTemplateName(), SpecTL.getTemplateNameLoc(), -- ObjectType, /*AllowInjectedClassName=*/true, /*MayBeNNS=*/true); -+ ObjectType, UnqualLookup, /*AllowInjectedClassName*/true); - if (Template.isNull()) - return nullptr; - -@@ -5170,11 +5208,13 @@ - DependentTemplateSpecializationTypeLoc SpecTL = - TL.castAs(); - -- TemplateName Template = getDerived().RebuildTemplateName( -- SS, SpecTL.getTemplateKeywordLoc(), -- *SpecTL.getTypePtr()->getIdentifier(), SpecTL.getTemplateNameLoc(), -- ObjectType, -- /*AllowInjectedClassName=*/true, /*MayBeNNS=*/true); -+ TemplateName Template -+ = getDerived().RebuildTemplateName(SS, -+ SpecTL.getTemplateKeywordLoc(), -+ *SpecTL.getTypePtr()->getIdentifier(), -+ SpecTL.getTemplateNameLoc(), -+ ObjectType, UnqualLookup, -+ /*AllowInjectedClassName*/true); - if (Template.isNull()) - return nullptr; - -@@ -12318,8 +12358,7 @@ - // first-qualifier-in-scope here, just in case we had a dependent - // base (and therefore couldn't do the check) and a - // nested-name-qualifier (and therefore could do the lookup). -- ArrayRef UnqualifiedLookups; -- -+ NamedDecl *FirstQualifierInScope = nullptr; - DeclarationNameInfo MemberNameInfo = E->getMemberNameInfo(); - if (MemberNameInfo.getName()) { - MemberNameInfo = getDerived().TransformDeclarationNameInfo(MemberNameInfo); -@@ -12327,11 +12366,16 @@ - return ExprError(); - } - -- return getDerived().RebuildMemberExpr( -- Base.get(), FakeOperatorLoc, E->isArrow(), QualifierLoc, TemplateKWLoc, -- MemberNameInfo, Member, FoundDecl, -- (E->hasExplicitTemplateArgs() ? &TransArgs : nullptr), -- UnqualifiedLookups); -+ return getDerived().RebuildMemberExpr(Base.get(), FakeOperatorLoc, -+ E->isArrow(), -+ QualifierLoc, -+ TemplateKWLoc, -+ MemberNameInfo, -+ Member, -+ FoundDecl, -+ (E->hasExplicitTemplateArgs() -+ ? &TransArgs : nullptr), -+ FirstQualifierInScope); - } - - template -@@ -13458,8 +13502,9 @@ - - PseudoDestructorTypeStorage Destroyed; - if (E->getDestroyedTypeInfo()) { -- TypeSourceInfo *DestroyedTypeInfo = getDerived().TransformTypeInObjectScope( -- E->getDestroyedTypeInfo(), ObjectType, SS); -+ TypeSourceInfo *DestroyedTypeInfo -+ = getDerived().TransformTypeInObjectScope(E->getDestroyedTypeInfo(), -+ ObjectType, nullptr, SS); - if (!DestroyedTypeInfo) - return ExprError(); - Destroyed = DestroyedTypeInfo; -@@ -13485,7 +13530,7 @@ - if (E->getScopeTypeInfo()) { - CXXScopeSpec EmptySS; - ScopeTypeInfo = getDerived().TransformTypeInObjectScope( -- E->getScopeTypeInfo(), ObjectType, EmptySS); -+ E->getScopeTypeInfo(), ObjectType, nullptr, EmptySS); - if (!ScopeTypeInfo) - return ExprError(); - } -@@ -14746,17 +14791,19 @@ - ObjectType = BaseType->castAs()->getPointeeType(); - } - -- UnresolvedSet<4> UnqualifiedLookups; -- for (auto D : E->unqualified_lookups()) { -- if (NamedDecl *InstD = getDerived().TransformFirstQualifierInScope( -- D.getDecl(), E->getQualifierLoc().getBeginLoc())) -- UnqualifiedLookups.addDecl(InstD); -- } -+ // Transform the first part of the nested-name-specifier that qualifies -+ // the member name. -+ NamedDecl *FirstQualifierInScope -+ = getDerived().TransformFirstQualifierInScope( -+ E->getFirstQualifierFoundInScope(), -+ E->getQualifierLoc().getBeginLoc()); - - NestedNameSpecifierLoc QualifierLoc; - if (E->getQualifier()) { -- QualifierLoc = getDerived().TransformNestedNameSpecifierLoc( -- E->getQualifierLoc(), ObjectType, UnqualifiedLookups.pairs()); -+ QualifierLoc -+ = getDerived().TransformNestedNameSpecifierLoc(E->getQualifierLoc(), -+ ObjectType, -+ FirstQualifierInScope); - if (!QualifierLoc) - return ExprError(); - } -@@ -14775,16 +14822,23 @@ - if (!E->hasExplicitTemplateArgs()) { - // This is a reference to a member without an explicitly-specified - // template argument list. Optimize for this common case. -- if (!getDerived().AlwaysRebuild() && Base.get() == OldBase && -- BaseType == E->getBaseType() && QualifierLoc == E->getQualifierLoc() && -+ if (!getDerived().AlwaysRebuild() && -+ Base.get() == OldBase && -+ BaseType == E->getBaseType() && -+ QualifierLoc == E->getQualifierLoc() && - NameInfo.getName() == E->getMember() && -- UnqualifiedLookups.pairs() == E->unqualified_lookups()) -+ FirstQualifierInScope == E->getFirstQualifierFoundInScope()) - return E; - -- return getDerived().RebuildCXXDependentScopeMemberExpr( -- Base.get(), BaseType, E->isArrow(), E->getOperatorLoc(), QualifierLoc, -- TemplateKWLoc, UnqualifiedLookups.pairs(), NameInfo, -- /*TemplateArgs*/ nullptr); -+ return getDerived().RebuildCXXDependentScopeMemberExpr(Base.get(), -+ BaseType, -+ E->isArrow(), -+ E->getOperatorLoc(), -+ QualifierLoc, -+ TemplateKWLoc, -+ FirstQualifierInScope, -+ NameInfo, -+ /*TemplateArgs*/nullptr); - } - - TemplateArgumentListInfo TransArgs(E->getLAngleLoc(), E->getRAngleLoc()); -@@ -14793,9 +14847,15 @@ - TransArgs)) - return ExprError(); - -- return getDerived().RebuildCXXDependentScopeMemberExpr( -- Base.get(), BaseType, E->isArrow(), E->getOperatorLoc(), QualifierLoc, -- TemplateKWLoc, UnqualifiedLookups.pairs(), NameInfo, &TransArgs); -+ return getDerived().RebuildCXXDependentScopeMemberExpr(Base.get(), -+ BaseType, -+ E->isArrow(), -+ E->getOperatorLoc(), -+ QualifierLoc, -+ TemplateKWLoc, -+ FirstQualifierInScope, -+ NameInfo, -+ &TransArgs); - } - - template -@@ -14856,11 +14916,11 @@ - // first-qualifier-in-scope here, just in case we had a dependent - // base (and therefore couldn't do the check) and a - // nested-name-qualifier (and therefore could do the lookup). -- ArrayRef UnqualifiedLookups; -+ NamedDecl *FirstQualifierInScope = nullptr; - - return getDerived().RebuildUnresolvedMemberExpr( - Base.get(), BaseType, Old->getOperatorLoc(), Old->isArrow(), QualifierLoc, -- TemplateKWLoc, UnqualifiedLookups, R, -+ TemplateKWLoc, FirstQualifierInScope, R, - (Old->hasExplicitTemplateArgs() ? &TransArgs : nullptr)); - } - -@@ -16217,18 +16277,22 @@ - TemplateName(Template)); - } - --template --TemplateName TreeTransform::RebuildTemplateName( -- CXXScopeSpec &SS, SourceLocation TemplateKWLoc, const IdentifierInfo &Name, -- SourceLocation NameLoc, QualType ObjectType, bool AllowInjectedClassName, -- bool MayBeNNS) { -+template -+TemplateName -+TreeTransform::RebuildTemplateName(CXXScopeSpec &SS, -+ SourceLocation TemplateKWLoc, -+ const IdentifierInfo &Name, -+ SourceLocation NameLoc, -+ QualType ObjectType, -+ NamedDecl *FirstQualifierInScope, -+ bool AllowInjectedClassName) { - UnqualifiedId TemplateName; - TemplateName.setIdentifier(&Name, NameLoc); - Sema::TemplateTy Template; - getSema().ActOnTemplateName(/*Scope=*/nullptr, SS, TemplateKWLoc, - TemplateName, ParsedType::make(ObjectType), - /*EnteringContext=*/false, Template, -- AllowInjectedClassName, MayBeNNS); -+ AllowInjectedClassName); - return Template.get(); - } - -@@ -16376,10 +16440,13 @@ - } - - SourceLocation TemplateKWLoc; // FIXME: retrieve it from caller. -- return getSema().BuildMemberReferenceExpr( -- Base, BaseType, OperatorLoc, isArrow, SS, TemplateKWLoc, NameInfo, -- /*TemplateArgs=*/nullptr, -- /*S=*/nullptr); -+ return getSema().BuildMemberReferenceExpr(Base, BaseType, -+ OperatorLoc, isArrow, -+ SS, TemplateKWLoc, -+ /*FIXME: FirstQualifier*/ nullptr, -+ NameInfo, -+ /*TemplateArgs*/ nullptr, -+ /*S*/nullptr); - } - - template -diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp ---- a/clang/lib/Serialization/ASTReaderStmt.cpp -+++ b/clang/lib/Serialization/ASTReaderStmt.cpp -@@ -1993,43 +1993,42 @@ - CXXDependentScopeMemberExpr *E) { - VisitExpr(E); - -- CurrentUnpackingBits.emplace(Record.readInt()); -- bool HasQualifier = CurrentUnpackingBits->getNextBit(); -- bool HasTemplateInfo = CurrentUnpackingBits->getNextBit(); -- unsigned NumUnqualifiedLookups = Record.readInt(); - unsigned NumTemplateArgs = Record.readInt(); -- E->CXXDependentScopeMemberExprBits.HasQualifier = HasQualifier; -- E->CXXDependentScopeMemberExprBits.NumUnqualifiedLookups = -- NumUnqualifiedLookups; -- E->CXXDependentScopeMemberExprBits.HasTemplateKWAndArgsInfo = HasTemplateInfo; -+ CurrentUnpackingBits.emplace(Record.readInt()); -+ bool HasTemplateKWAndArgsInfo = CurrentUnpackingBits->getNextBit(); -+ bool HasFirstQualifierFoundInScope = CurrentUnpackingBits->getNextBit(); -+ -+ assert((HasTemplateKWAndArgsInfo == E->hasTemplateKWAndArgsInfo()) && -+ "Wrong HasTemplateKWAndArgsInfo!"); -+ assert( -+ (HasFirstQualifierFoundInScope == E->hasFirstQualifierFoundInScope()) && -+ "Wrong HasFirstQualifierFoundInScope!"); -+ -+ if (HasTemplateKWAndArgsInfo) -+ ReadTemplateKWAndArgsInfo( -+ *E->getTrailingObjects(), -+ E->getTrailingObjects(), NumTemplateArgs); -+ -+ assert((NumTemplateArgs == E->getNumTemplateArgs()) && -+ "Wrong NumTemplateArgs!"); - -- E->BaseType = Record.readType(); - E->CXXDependentScopeMemberExprBits.IsArrow = - CurrentUnpackingBits->getNextBit(); - -+ E->BaseType = Record.readType(); -+ E->QualifierLoc = Record.readNestedNameSpecifierLoc(); -+ // not ImplicitAccess - if (CurrentUnpackingBits->getNextBit()) - E->Base = Record.readSubExpr(); - else - E->Base = nullptr; - -- E->OperatorLoc = Record.readSourceLocation(); -- E->MemberNameInfo = Record.readDeclarationNameInfo(); -+ E->CXXDependentScopeMemberExprBits.OperatorLoc = readSourceLocation(); - -- if (HasQualifier) -- new (E->getTrailingObjects()) -- NestedNameSpecifierLoc(Record.readNestedNameSpecifierLoc()); -- -- for (unsigned I = 0; I != NumUnqualifiedLookups; ++I) { -- auto *FoundD = Record.readDeclAs(); -- auto AS = (AccessSpecifier)Record.readInt(); -- E->getTrailingObjects()[I] = -- DeclAccessPair::make(FoundD, AS); -- } -+ if (HasFirstQualifierFoundInScope) -+ *E->getTrailingObjects() = readDeclAs(); - -- if (HasTemplateInfo) -- ReadTemplateKWAndArgsInfo( -- *E->getTrailingObjects(), -- E->getTrailingObjects(), NumTemplateArgs); -+ E->MemberNameInfo = Record.readDeclarationNameInfo(); - } - - void -@@ -4076,16 +4075,16 @@ - break; - - case EXPR_CXX_DEPENDENT_SCOPE_MEMBER: { -+ unsigned NumTemplateArgs = Record[ASTStmtReader::NumExprFields]; - BitsUnpacker DependentScopeMemberBits( -- Record[ASTStmtReader::NumExprFields]); -- bool HasQualifier = DependentScopeMemberBits.getNextBit(); -- bool HasTemplateInfo = DependentScopeMemberBits.getNextBit(); -- unsigned NumUnqualifiedLookups = Record[ASTStmtReader::NumExprFields + 1]; -- unsigned NumTemplateArgs = Record[ASTStmtReader::NumExprFields + 2]; -+ Record[ASTStmtReader::NumExprFields + 1]); -+ bool HasTemplateKWAndArgsInfo = DependentScopeMemberBits.getNextBit(); - -+ bool HasFirstQualifierFoundInScope = -+ DependentScopeMemberBits.getNextBit(); - S = CXXDependentScopeMemberExpr::CreateEmpty( -- Context, HasQualifier, NumUnqualifiedLookups, HasTemplateInfo, -- NumTemplateArgs); -+ Context, HasTemplateKWAndArgsInfo, NumTemplateArgs, -+ HasFirstQualifierFoundInScope); - break; ++} // namespace targets ++} // namespace clang ++#endif // LLVM_CLANG_LIB_BASIC_TARGETS_LE64_H +diff -ruN --strip-trailing-cr a/clang/lib/Basic/Targets/OSTargets.h b/clang/lib/Basic/Targets/OSTargets.h +--- a/clang/lib/Basic/Targets/OSTargets.h ++++ b/clang/lib/Basic/Targets/OSTargets.h +@@ -841,6 +841,9 @@ + "i64:64-i128:128-n8:16:32:64-S128"); + } else if (Triple.getArch() == llvm::Triple::mipsel) { + // Handled on mips' setDataLayout. ++ } else { ++ assert(Triple.getArch() == llvm::Triple::le32); ++ this->resetDataLayout("e-p:32:32-i64:64"); + } + } + }; +diff -ruN --strip-trailing-cr a/clang/lib/Basic/Targets.cpp b/clang/lib/Basic/Targets.cpp +--- a/clang/lib/Basic/Targets.cpp ++++ b/clang/lib/Basic/Targets.cpp +@@ -23,6 +23,7 @@ + #include "Targets/DirectX.h" + #include "Targets/Hexagon.h" + #include "Targets/Lanai.h" ++#include "Targets/Le64.h" + #include "Targets/LoongArch.h" + #include "Targets/M68k.h" + #include "Targets/MSP430.h" +@@ -343,6 +344,17 @@ + return std::make_unique(Triple, Opts); } -diff -ruN --strip-trailing-cr a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp ---- a/clang/lib/Serialization/ASTWriterStmt.cpp -+++ b/clang/lib/Serialization/ASTWriterStmt.cpp -@@ -1988,41 +1988,34 @@ - CXXDependentScopeMemberExpr *E) { - VisitExpr(E); - -- bool HasQualifier = E->hasQualifier(); -- unsigned NumUnqualifiedLookups = E->getNumUnqualifiedLookups(); -- bool HasTemplateInfo = E->hasTemplateKWAndArgsInfo(); -- unsigned NumTemplateArgs = E->getNumTemplateArgs(); -- -- // Write these first for easy access when deserializing, as they affect the -- // size of the CXXDependentScopeMemberExpr. -+ // Don't emit anything here (or if you do you will have to update -+ // the corresponding deserialization function). -+ Record.push_back(E->getNumTemplateArgs()); - CurrentPackingBits.updateBits(); -- CurrentPackingBits.addBit(HasQualifier); -- CurrentPackingBits.addBit(HasTemplateInfo); -- Record.push_back(NumUnqualifiedLookups); -- Record.push_back(NumTemplateArgs); -+ CurrentPackingBits.addBit(E->hasTemplateKWAndArgsInfo()); -+ CurrentPackingBits.addBit(E->hasFirstQualifierFoundInScope()); ++ case llvm::Triple::le32: ++ switch (os) { ++ case llvm::Triple::NaCl: ++ return std::make_unique>(Triple, Opts); ++ default: ++ return nullptr; ++ } + -+ if (E->hasTemplateKWAndArgsInfo()) { -+ const ASTTemplateKWAndArgsInfo &ArgInfo = -+ *E->getTrailingObjects(); -+ AddTemplateKWAndArgsInfo(ArgInfo, -+ E->getTrailingObjects()); -+ } - -- Record.AddTypeRef(E->getBaseType()); - CurrentPackingBits.addBit(E->isArrow()); ++ case llvm::Triple::le64: ++ return std::make_unique(Triple, Opts); + -+ Record.AddTypeRef(E->getBaseType()); -+ Record.AddNestedNameSpecifierLoc(E->getQualifierLoc()); - CurrentPackingBits.addBit(!E->isImplicitAccess()); - if (!E->isImplicitAccess()) - Record.AddStmt(E->getBase()); + case llvm::Triple::ppc: + switch (os) { + case llvm::Triple::Linux: +diff -ruN --strip-trailing-cr a/clang/lib/CodeGen/CodeGenModule.cpp b/clang/lib/CodeGen/CodeGenModule.cpp +--- a/clang/lib/CodeGen/CodeGenModule.cpp ++++ b/clang/lib/CodeGen/CodeGenModule.cpp +@@ -116,6 +116,8 @@ + default: + return createDefaultTargetCodeGenInfo(CGM); - Record.AddSourceLocation(E->getOperatorLoc()); ++ case llvm::Triple::le32: ++ return createPNaClTargetCodeGenInfo(CGM); + case llvm::Triple::m68k: + return createM68kTargetCodeGenInfo(CGM); + case llvm::Triple::mips: +diff -ruN --strip-trailing-cr a/clang/lib/CodeGen/ItaniumCXXABI.cpp b/clang/lib/CodeGen/ItaniumCXXABI.cpp +--- a/clang/lib/CodeGen/ItaniumCXXABI.cpp ++++ b/clang/lib/CodeGen/ItaniumCXXABI.cpp +@@ -576,6 +576,13 @@ + return new XLCXXABI(CGM); -- Record.AddDeclarationNameInfo(E->MemberNameInfo); -- -- if (HasQualifier) -- Record.AddNestedNameSpecifierLoc(E->getQualifierLoc()); -- -- for (DeclAccessPair D : E->unqualified_lookups()) { -- Record.AddDeclRef(D.getDecl()); -- Record.push_back(D.getAccess()); -- } -- -- if (HasTemplateInfo) -- AddTemplateKWAndArgsInfo(*E->getTrailingObjects(), -- E->getTrailingObjects()); -+ if (E->hasFirstQualifierFoundInScope()) -+ Record.AddDeclRef(E->getFirstQualifierFoundInScope()); + case TargetCXXABI::GenericItanium: ++ if (CGM.getContext().getTargetInfo().getTriple().getArch() ++ == llvm::Triple::le32) { ++ // For PNaCl, use ARM-style method pointers so that PNaCl code ++ // does not assume anything about the alignment of function ++ // pointers. ++ return new ItaniumCXXABI(CGM, /*UseARMMethodPtrABI=*/true); ++ } + return new ItaniumCXXABI(CGM); -+ Record.AddDeclarationNameInfo(E->MemberNameInfo); - Code = serialization::EXPR_CXX_DEPENDENT_SCOPE_MEMBER; + case TargetCXXABI::Microsoft: +diff -ruN --strip-trailing-cr a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp +--- a/clang/lib/Driver/ToolChains/Clang.cpp ++++ b/clang/lib/Driver/ToolChains/Clang.cpp +@@ -3815,6 +3815,12 @@ + if (UseBuiltins) + A->render(Args, CmdArgs); + } ++ ++ // le32-specific flags: ++ // -fno-math-builtin: clang should not convert math builtins to intrinsics ++ // by default. ++ if (TC.getArch() == llvm::Triple::le32) ++ CmdArgs.push_back("-fno-math-builtin"); } + bool Driver::getDefaultModuleCachePath(SmallVectorImpl &Result) { diff -ruN --strip-trailing-cr a/clang/test/CodeGen/bitfield-access-pad.c b/clang/test/CodeGen/bitfield-access-pad.c --- a/clang/test/CodeGen/bitfield-access-pad.c +++ b/clang/test/CodeGen/bitfield-access-pad.c @@ -2862,540 +269,6 @@ diff -ruN --strip-trailing-cr a/clang/test/CodeGenCXX/bitfield-access-tail.cpp b // RUN: %clang_cc1 -triple=loongarch32-elf %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT,LAYOUT32 %s // RUN: %clang_cc1 -triple=nvptx-elf %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT,LAYOUT32 %s // RUN: %clang_cc1 -triple=riscv32 %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT,LAYOUT32 %s -diff -ruN --strip-trailing-cr a/clang/test/CXX/basic/basic.lookup/basic.lookup.classref/p1.cpp b/clang/test/CXX/basic/basic.lookup/basic.lookup.classref/p1.cpp ---- a/clang/test/CXX/basic/basic.lookup/basic.lookup.classref/p1.cpp -+++ b/clang/test/CXX/basic/basic.lookup/basic.lookup.classref/p1.cpp -@@ -86,19 +86,15 @@ - - template T *end(T*); - -- struct X { }; -- struct Y { -- int end; -- }; -+ class X { }; - template - void Foo2() { - T it1; -- if (it1->end < it1->end) { } -+ if (it1->end < it1->end) { -+ } - - X *x; -- if (x->end < 7) { } // expected-error{{no member named 'end' in 'PR11856::X'}} -- -- Y *y; -- if (y->end < 7) { } -+ if (x->end < 7) { // expected-error{{no member named 'end' in 'PR11856::X'}} -+ } - } - } -diff -ruN --strip-trailing-cr a/clang/test/CXX/basic/basic.lookup/basic.lookup.classref/p1-cxx11.cpp b/clang/test/CXX/basic/basic.lookup/basic.lookup.classref/p1-cxx11.cpp ---- a/clang/test/CXX/basic/basic.lookup/basic.lookup.classref/p1-cxx11.cpp -+++ b/clang/test/CXX/basic/basic.lookup/basic.lookup.classref/p1-cxx11.cpp -@@ -55,19 +55,15 @@ - - template T *end(T*); - -- struct X { }; -- struct Y { -- int end; -- }; -+ class X { }; - template - void Foo2() { - T it1; -- if (it1->end < it1->end) { } -+ if (it1->end < it1->end) { -+ } - - X *x; -- if (x->end < 7) { } // expected-error{{no member named 'end' in 'PR11856::X'}} -- -- Y *y; -- if (y->end < 7) { } -+ if (x->end < 7) { // expected-error{{no member named 'end' in 'PR11856::X'}} -+ } - } - } -diff -ruN --strip-trailing-cr a/clang/test/CXX/basic/basic.lookup/basic.lookup.qual/basic.lookup.qual.general/p3.cpp b/clang/test/CXX/basic/basic.lookup/basic.lookup.qual/basic.lookup.qual.general/p3.cpp ---- a/clang/test/CXX/basic/basic.lookup/basic.lookup.qual/basic.lookup.qual.general/p3.cpp -+++ b/clang/test/CXX/basic/basic.lookup/basic.lookup.qual/basic.lookup.qual.general/p3.cpp -@@ -1,98 +0,0 @@ --// RUN: %clang_cc1 -std=c++23 -Wno-unused %s -verify -- --namespace Unambiguous { -- struct A { -- int x; -- -- template -- using C = A; -- }; -- -- using B = A; -- -- template -- using D = A; -- -- using E = void; -- -- struct F : A { -- void non_template() { -- this->x; -- this->A::x; -- this->B::x; -- this->C::x; -- this->D::x; -- this->E::x; // expected-error {{'Unambiguous::E' (aka 'void') is not a class, namespace, or enumeration}} -- } -- }; -- -- template -- void not_instantiated(T t) { -- t.x; -- t.A::x; -- t.B::x; -- t.C::x; // expected-warning {{use 'template' keyword to treat 'C' as a dependent template name}} -- t.template C::x; -- t.D::x; // expected-warning {{use 'template' keyword to treat 'D' as a dependent template name}} -- t.template D::x; -- t.E::x; -- } -- -- template -- void instantiated_valid(T t) { -- t.x; -- t.A::x; -- t.B::x; -- t.template C::x; -- t.template D::x; -- t.E::x; -- } -- -- template -- void instantiated_invalid(T t) { -- t.x; -- t.A::x; -- t.B::x; // expected-error {{'Unambiguous::Invalid::B' (aka 'void') is not a class, namespace, or enumeration}} -- t.template C::x; -- t.template D::x; // expected-error {{'D' following the 'template' keyword does not refer to a template}} -- t.E::x; // expected-error {{'Unambiguous::E' (aka 'void') is not a class, namespace, or enumeration}} -- } -- -- struct Valid : A { -- using E = A; -- }; -- -- template void instantiated_valid(Valid); -- -- struct Invalid : A { -- using B = void; -- using D = A; // expected-note {{declared as a non-template here}} -- }; -- -- template void instantiated_invalid(Invalid); // expected-note {{in instantiation of}} --} // namespace Unambiguous -- --namespace Ambiguous { -- inline namespace N { -- struct A { }; // expected-note {{candidate found by name lookup is 'Ambiguous::N::A'}} -- } -- -- struct A { }; // expected-note {{candidate found by name lookup is 'Ambiguous::A'}} -- -- template -- void f(T t) { -- t.A::x; // expected-error {{reference to 'A' is ambiguous}} -- } -- -- struct B { -- using A = B; -- -- int x; -- }; -- -- struct C { }; -- -- template void f(B); -- template void f(C); // expected-note {{in instantiation of}} -- --} // namespace Ambiguous -diff -ruN --strip-trailing-cr a/clang/test/CXX/basic/basic.lookup/basic.lookup.qual/basic.lookup.qual.general/p3-example3.cpp b/clang/test/CXX/basic/basic.lookup/basic.lookup.qual/basic.lookup.qual.general/p3-example3.cpp ---- a/clang/test/CXX/basic/basic.lookup/basic.lookup.qual/basic.lookup.qual.general/p3-example3.cpp -+++ b/clang/test/CXX/basic/basic.lookup/basic.lookup.qual/basic.lookup.qual.general/p3-example3.cpp -@@ -1,27 +0,0 @@ --// RUN: %clang_cc1 -std=c++23 %s -verify -- --int f(); -- --struct A { -- int B, C; // expected-note {{declared as a non-template here}} -- template using D = void; -- using T = void; -- void f(); --}; -- --using B = A; --template using C = A; --template using D = A; --template using X = A; -- --template --void g(T *p) { -- p->X<0>::f(); // expected-error {{no member named 'X' in 'A'}} -- p->template X<0>::f(); -- p->B::f(); -- p->template C<0>::f(); // expected-error {{'C' following the 'template' keyword does not refer to a template}} -- p->template D<0>::f(); // expected-error {{type 'template D<0>' (aka 'void') cannot be used prior to '::' because it has no members}} -- p->T::f(); // expected-error {{'A::T' (aka 'void') is not a class, namespace, or enumeration}} --} -- --template void g(A*); // expected-note {{in instantiation of}} -diff -ruN --strip-trailing-cr a/clang/test/CXX/class.derived/class.member.lookup/p8.cpp b/clang/test/CXX/class.derived/class.member.lookup/p8.cpp ---- a/clang/test/CXX/class.derived/class.member.lookup/p8.cpp -+++ b/clang/test/CXX/class.derived/class.member.lookup/p8.cpp -@@ -47,8 +47,8 @@ - void DerivedT::Inner() { - Derived1T::Foo(); - Derived2T::Member = 42; -- this->Derived1T::Foo(); // expected-warning{{use 'template' keyword to treat 'Derived1T' as a dependent template name}} -- this->Derived2T::Member = 42; // expected-warning{{use 'template' keyword to treat 'Derived2T' as a dependent template name}} -+ this->Derived1T::Foo(); -+ this->Derived2T::Member = 42; - this->Foo(); // expected-error{{non-static member 'Foo' found in multiple base-class subobjects of type 'BaseT'}} - } - -diff -ruN --strip-trailing-cr a/clang/test/CXX/drs/cwg1xx.cpp b/clang/test/CXX/drs/cwg1xx.cpp ---- a/clang/test/CXX/drs/cwg1xx.cpp -+++ b/clang/test/CXX/drs/cwg1xx.cpp -@@ -615,8 +615,10 @@ - // cxx98-note@#cwg141-S {{lookup from the current scope refers here}} - // expected-error@#cwg141-a {{no member named 'n' in 'cwg141::A::S'; did you mean '::cwg141::S::n'?}} - // expected-note@#cwg141-S {{'::cwg141::S::n' declared here}} -+ // FIXME: we issue a useful diagnostic first, then some bogus ones. - b.f(); - // expected-error@-1 {{no member named 'f' in 'cwg141::B'}} -+ // expected-error@-2 +{{}} - (void)b.S::n; - } - template struct C { -@@ -626,12 +628,10 @@ - // expected-error@-1 {{use 'template' keyword to treat 'f' as a dependent template name}} - } - void h() { -- (void)t.S::n; -- // expected-error@-1 {{use 'template' keyword to treat 'S' as a dependent template name}} -+ (void)t.S::n; // ok - } - void i() { -- (void)t.S(); -- // expected-error@-1 {{use 'template' keyword to treat 'S' as a dependent template name}} -+ (void)t.S(); // ok! - } - }; - void h() { C().h(); } // ok -diff -ruN --strip-trailing-cr a/clang/test/CXX/temp/temp.names/p3-23.cpp b/clang/test/CXX/temp/temp.names/p3-23.cpp ---- a/clang/test/CXX/temp/temp.names/p3-23.cpp -+++ b/clang/test/CXX/temp/temp.names/p3-23.cpp -@@ -1,237 +0,0 @@ --// RUN: %clang_cc1 -std=c++23 -Wno-unused %s -verify -- --namespace FoundNothing { -- template -- void f0(T &t) { -- t.x<0; -- t.x<0>; // expected-error {{expected expression}} -- t.x<0>1; -- } -- -- template -- struct A { -- void f1() { -- this->x<0; // expected-error {{no member named 'x' in 'A'}} -- this->x<0>; // expected-error {{no member named 'x' in 'A'}} -- // expected-error@-1 {{expected expression}} -- this->x<0>1; // expected-error {{no member named 'x' in 'A'}} -- } -- }; --} // namespace FoundNothing -- --namespace FoundSingleNonTemplate { -- void f0(); -- -- struct A0; -- -- template -- void g0(T &t) { -- t.f0<0; -- t.f0<0>; // expected-error {{expected expression}} -- t.f0<0>1; -- -- t.A0<0; -- t.A0<0>; // expected-error {{expected expression}} -- t.A0<0>1; -- } -- -- template -- struct B { -- void f1(); -- -- struct A1; // expected-note 3{{member 'A1' declared here}} -- -- void g1() { -- this->f0<0; // expected-error {{no member named 'f0' in 'B'}} -- this->f0<0>; // expected-error {{no member named 'f0' in 'B'}} -- // expected-error@-1 {{expected expression}} -- this->f0<0>1; // expected-error {{no member named 'f0' in 'B'}} -- -- this->A0<0; // expected-error {{no member named 'A0' in 'B'}} -- this->A0<0>; // expected-error {{no member named 'A0' in 'B'}} -- // expected-error@-1 {{expected expression}} -- this->A0<0>1; // expected-error {{no member named 'A0' in 'B'}} -- -- this->f1<0; // expected-error {{reference to non-static member function must be called}} -- this->f1<0>; // expected-error {{reference to non-static member function must be called}} -- // expected-error@-1 {{expected expression}} -- this->f1<0>1; // expected-error {{reference to non-static member function must be called}} -- -- this->A1<0; // expected-error {{cannot refer to type member 'A1' in 'B' with '->'}} -- this->A1<0>; // expected-error {{cannot refer to type member 'A1' in 'B' with '->'}} -- // expected-error@-1 {{expected expression}} -- this->A1<0>1; // expected-error {{cannot refer to type member 'A1' in 'B' with '->'}} -- } -- }; --} // namespace FoundSingleNonTemplate -- --namespace FoundSingleTemplate { -- template -- void f0(); -- -- template -- struct A0; -- -- template -- void g0(T &t) { -- t.f0<0; -- t.f0<0>; // expected-error {{expected expression}} -- t.f0<0>1; -- -- t.A0<0; -- t.A0<0>; // expected-error {{expected expression}} -- t.A0<0>1; -- } -- -- template -- struct B { -- template -- void f1(); // expected-note 2{{possible target for call}} -- -- template -- struct A1; // expected-note 2{{member 'A1' declared here}} -- -- void g1() { -- this->f0<0; // expected-error {{no member named 'f0' in 'B'}} -- this->f0<0>; // expected-error {{no member named 'f0' in 'B'}} -- this->f0<0>1; // expected-error {{no member named 'f0' in 'B'}} -- // expected-error@-1 {{expected ';' after expression}} -- -- this->A0<0; // expected-error {{no member named 'A0' in 'B'}} -- this->A0<0>; // expected-error {{no member named 'A0' in 'B'}} -- this->A0<0>1; // expected-error {{no member named 'A0' in 'B'}} -- // expected-error@-1 {{expected ';' after expression}} -- -- -- this->f1<0; // expected-error {{expected '>'}} -- // expected-note@-1 {{to match this '<'}} -- this->f1<0>; // expected-error {{reference to non-static member function must be called}} -- this->f1<0>1; // expected-error {{reference to non-static member function must be called}} -- // expected-error@-1 {{expected ';' after expression}} -- -- this->A1<0; // expected-error {{expected '>'}} -- // expected-note@-1 {{to match this '<'}} -- this->A1<0>; // expected-error {{cannot refer to member 'A1' in 'B' with '->'}} -- this->A1<0>1; // expected-error {{cannot refer to member 'A1' in 'B' with '->'}} -- // expected-error@-1 {{expected ';' after expression}} -- } -- }; --} // namespace FoundSingleTemplate -- --namespace FoundAmbiguousNonTemplate { -- inline namespace N { -- int f0; -- -- struct A0; -- } // namespace N -- -- void f0(); -- -- struct A0; -- -- template -- void g0(T &t) { -- t.f0<0; -- t.f0<0>; // expected-error {{expected expression}} -- t.f0<0>1; -- -- t.A0<0; -- t.A0<0>; // expected-error {{expected expression}} -- t.A0<0>1; -- } -- -- template -- struct B { -- void f1(); -- -- struct A1; // expected-note 3{{member 'A1' declared here}} -- -- void g1() { -- this->f0<0; // expected-error {{no member named 'f0' in 'B'}} -- this->f0<0>; // expected-error {{no member named 'f0' in 'B'}} -- // expected-error@-1 {{expected expression}} -- this->f0<0>1; // expected-error {{no member named 'f0' in 'B'}} -- -- this->A0<0; // expected-error {{no member named 'A0' in 'B'}} -- this->A0<0>; // expected-error {{no member named 'A0' in 'B'}} -- // expected-error@-1 {{expected expression}} -- this->A0<0>1; // expected-error {{no member named 'A0' in 'B'}} -- -- this->f1<0; // expected-error {{reference to non-static member function must be called}} -- this->f1<0>; // expected-error {{reference to non-static member function must be called}} -- // expected-error@-1 {{expected expression}} -- this->f1<0>1; // expected-error {{reference to non-static member function must be called}} -- -- this->A1<0; // expected-error {{cannot refer to type member 'A1' in 'B' with '->'}} -- this->A1<0>; // expected-error {{cannot refer to type member 'A1' in 'B' with '->'}} -- // expected-error@-1 {{expected expression}} -- this->A1<0>1; // expected-error {{cannot refer to type member 'A1' in 'B' with '->'}} -- } -- }; --} // namespace FoundAmbiguousNonTemplates -- --namespace FoundAmbiguousTemplate { -- inline namespace N { -- template -- int f0; // expected-note 3{{candidate found by name lookup is 'FoundAmbiguousTemplate::N::f0'}} -- -- template -- struct A0; // expected-note 3{{candidate found by name lookup is 'FoundAmbiguousTemplate::N::A0'}} -- } // namespace N -- -- template -- void f0(); // expected-note 3{{candidate found by name lookup is 'FoundAmbiguousTemplate::f0'}} -- -- template -- struct A0; // expected-note 3{{candidate found by name lookup is 'FoundAmbiguousTemplate::A0'}} -- -- template -- void g0(T &t) { -- t.f0<0; -- t.f0<0>; // expected-error {{expected expression}} -- t.f0<0>1; -- -- t.A0<0; -- t.A0<0>; // expected-error {{expected expression}} -- t.A0<0>1; -- } -- -- template -- struct B { -- template -- void f1(); // expected-note 2{{possible target for call}} -- -- template -- struct A1; // expected-note 2{{member 'A1' declared here}} -- -- void g1() { -- this->f0<0; // expected-error {{no member named 'f0' in 'B'}} -- // expected-error@-1 {{reference to 'f0' is ambiguous}} -- this->f0<0>; // expected-error {{no member named 'f0' in 'B'}} -- // expected-error@-1 {{reference to 'f0' is ambiguous}} -- this->f0<0>1; // expected-error {{no member named 'f0' in 'B'}} -- // expected-error@-1 {{expected ';' after expression}} -- // expected-error@-2 {{reference to 'f0' is ambiguous}} -- -- this->A0<0; // expected-error {{no member named 'A0' in 'B'}} -- // expected-error@-1 {{reference to 'A0' is ambiguous}} -- this->A0<0>; // expected-error {{no member named 'A0' in 'B'}} -- // expected-error@-1 {{reference to 'A0' is ambiguous}} -- this->A0<0>1; // expected-error {{no member named 'A0' in 'B'}} -- // expected-error@-1 {{expected ';' after expression}} -- // expected-error@-2 {{reference to 'A0' is ambiguous}} -- -- this->f1<0; // expected-error {{expected '>'}} -- // expected-note@-1 {{to match this '<'}} -- this->f1<0>; // expected-error {{reference to non-static member function must be called}} -- this->f1<0>1; // expected-error {{reference to non-static member function must be called}} -- // expected-error@-1 {{expected ';' after expression}} -- -- this->A1<0; // expected-error {{expected '>'}} -- // expected-note@-1 {{to match this '<'}} -- this->A1<0>; // expected-error {{cannot refer to member 'A1' in 'B' with '->'}} -- this->A1<0>1; // expected-error {{cannot refer to member 'A1' in 'B' with '->'}} -- // expected-error@-1 {{expected ';' after expression}} -- } -- }; --} // namespace FoundAmbiguousTemplate -diff -ruN --strip-trailing-cr a/clang/test/CXX/temp/temp.res/p3.cpp b/clang/test/CXX/temp/temp.res/p3.cpp ---- a/clang/test/CXX/temp/temp.res/p3.cpp -+++ b/clang/test/CXX/temp/temp.res/p3.cpp -@@ -30,6 +30,6 @@ - template template struct A::B { - friend A::C f6(); // ok, same as 'friend T f6();' - -- friend A::C f7(); // expected-warning {{use 'template' keyword to treat 'C' as a dependent template name}} expected-warning {{missing 'typename'}} -+ friend A::C f7(); // expected-error {{use 'template' keyword to treat 'C' as a dependent template name}} expected-warning {{missing 'typename'}} - friend A::template C f8(); // expected-warning {{missing 'typename'}} - }; -diff -ruN --strip-trailing-cr a/clang/test/FixIt/fixit.cpp b/clang/test/FixIt/fixit.cpp ---- a/clang/test/FixIt/fixit.cpp -+++ b/clang/test/FixIt/fixit.cpp -@@ -158,12 +158,12 @@ - - template - class F2 { -- typename F1:: /*template*/ Iterator<0> Mypos; // expected-warning {{use 'template' keyword to treat 'Iterator' as a dependent template name}} -+ typename F1:: /*template*/ Iterator<0> Mypos; // expected-error {{use 'template' keyword to treat 'Iterator' as a dependent template name}} - }; - - template - void f(){ -- typename F1:: /*template*/ Iterator<0> Mypos; // expected-warning {{use 'template' keyword to treat 'Iterator' as a dependent template name}} -+ typename F1:: /*template*/ Iterator<0> Mypos; // expected-error {{use 'template' keyword to treat 'Iterator' as a dependent template name}} - } - - // Tests for &/* fixits -diff -ruN --strip-trailing-cr a/clang/test/Misc/warning-flags.c b/clang/test/Misc/warning-flags.c ---- a/clang/test/Misc/warning-flags.c -+++ b/clang/test/Misc/warning-flags.c -@@ -18,7 +18,7 @@ - - The list of warnings below should NEVER grow. It should gradually shrink to 0. - --CHECK: Warnings without flags (64): -+CHECK: Warnings without flags (65): - - CHECK-NEXT: ext_expected_semi_decl_list - CHECK-NEXT: ext_missing_whitespace_after_macro_name -@@ -61,6 +61,7 @@ - CHECK-NEXT: warn_maynot_respond - CHECK-NEXT: warn_method_param_redefinition - CHECK-NEXT: warn_missing_case_for_condition -+CHECK-NEXT: warn_missing_dependent_template_keyword - CHECK-NEXT: warn_missing_whitespace_after_macro_name - CHECK-NEXT: warn_mt_message - CHECK-NEXT: warn_no_constructor_for_refconst -diff -ruN --strip-trailing-cr a/clang/test/Parser/cxx2a-concepts-requires-expr.cpp b/clang/test/Parser/cxx2a-concepts-requires-expr.cpp ---- a/clang/test/Parser/cxx2a-concepts-requires-expr.cpp -+++ b/clang/test/Parser/cxx2a-concepts-requires-expr.cpp -@@ -78,7 +78,7 @@ - - template - bool r23 = requires { typename identity::temp; }; --// expected-warning@-1 {{use 'template' keyword to treat 'temp' as a dependent template name}} -+// expected-error@-1 {{use 'template' keyword to treat 'temp' as a dependent template name}} - - template - bool r24 = requires { diff -ruN --strip-trailing-cr a/clang/test/Preprocessor/predefined-macros-no-warnings.c b/clang/test/Preprocessor/predefined-macros-no-warnings.c --- a/clang/test/Preprocessor/predefined-macros-no-warnings.c +++ b/clang/test/Preprocessor/predefined-macros-no-warnings.c @@ -3408,288 +281,67 @@ diff -ruN --strip-trailing-cr a/clang/test/Preprocessor/predefined-macros-no-war // RUN: %clang_cc1 %s -Eonly -Wsystem-headers -Werror -triple ppc // RUN: %clang_cc1 %s -Eonly -Wsystem-headers -Werror -triple ppc-freebsd // RUN: %clang_cc1 %s -Eonly -Wsystem-headers -Werror -triple ppc-netbsd -diff -ruN --strip-trailing-cr a/clang/test/SemaCXX/cxx0x-noexcept-expression.cpp b/clang/test/SemaCXX/cxx0x-noexcept-expression.cpp ---- a/clang/test/SemaCXX/cxx0x-noexcept-expression.cpp -+++ b/clang/test/SemaCXX/cxx0x-noexcept-expression.cpp -@@ -127,7 +127,7 @@ - // `dependent` should be type-dependent because the noexcept-expression should be value-dependent - // (it is true if T is int*, false if T is Polymorphic* for example) - dependent.f(); // This should need to be `.template f` to parse as a template -- // expected-warning@-1 {{use 'template' keyword to treat 'f' as a dependent template name}} -+ // expected-error@-1 {{use 'template' keyword to treat 'f' as a dependent template name}} - } - template - void f2() { -@@ -135,14 +135,14 @@ - // X when T...[0] is a type with some operator&& which returns int* - // X when sizeof...(T) == 0 - dependent.f(); -- // expected-warning@-1 {{use 'template' keyword to treat 'f' as a dependent template name}} -+ // expected-error@-1 {{use 'template' keyword to treat 'f' as a dependent template name}} - } - template - void f3() { - X(nullptr)))> dependent; - // X when T is int, X when T is Polymorphic - dependent.f(); -- // expected-warning@-1 {{use 'template' keyword to treat 'f' as a dependent template name}} -+ // expected-error@-1 {{use 'template' keyword to treat 'f' as a dependent template name}} +diff -ruN --strip-trailing-cr a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h +--- a/llvm/include/llvm/IR/PatternMatch.h ++++ b/llvm/include/llvm/IR/PatternMatch.h +@@ -1550,27 +1550,23 @@ + template + struct CmpClass_match { +- PredicateTy *Predicate; ++ PredicateTy &Predicate; + LHS_t L; + RHS_t R; + + // The evaluation order is always stable, regardless of Commutability. + // The LHS is always matched first. + CmpClass_match(PredicateTy &Pred, const LHS_t &LHS, const RHS_t &RHS) +- : Predicate(&Pred), L(LHS), R(RHS) {} +- CmpClass_match(const LHS_t &LHS, const RHS_t &RHS) +- : Predicate(nullptr), L(LHS), R(RHS) {} ++ : Predicate(Pred), L(LHS), R(RHS) {} + + template bool match(OpTy *V) { + if (auto *I = dyn_cast(V)) { + if (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) { +- if (Predicate) +- *Predicate = I->getPredicate(); ++ Predicate = I->getPredicate(); + return true; + } else if (Commutable && L.match(I->getOperand(1)) && + R.match(I->getOperand(0))) { +- if (Predicate) +- *Predicate = I->getSwappedPredicate(); ++ Predicate = I->getSwappedPredicate(); + return true; + } + } +@@ -1599,19 +1595,22 @@ + template + inline CmpClass_match + m_Cmp(const LHS &L, const RHS &R) { +- return CmpClass_match(L, R); ++ CmpInst::Predicate Unused; ++ return CmpClass_match(Unused, L, R); } - template - void f4() { -diff -ruN --strip-trailing-cr a/clang/test/SemaCXX/pseudo-destructors.cpp b/clang/test/SemaCXX/pseudo-destructors.cpp ---- a/clang/test/SemaCXX/pseudo-destructors.cpp -+++ b/clang/test/SemaCXX/pseudo-destructors.cpp -@@ -22,21 +22,21 @@ - void f(A* a, Foo *f, int *i, double *d, int ii) { - a->~A(); - a->A::~A(); -- -+ - a->~foo(); // expected-error{{undeclared identifier 'foo' in destructor name}} -- -+ - a->~Bar(); // expected-error{{destructor type 'Bar' (aka 'Foo') in object destruction expression does not match the type 'A' of the object being destroyed}} -- -+ - f->~Bar(); - f->~Foo(); - i->~Bar(); // expected-error{{does not match}} -- -+ - g().~Bar(); // expected-error{{non-scalar}} -- -+ - f->::~Bar(); // expected-error {{not a structure or union}} - f->::Bar::~Bar(); - f->N::~Wibble(); // expected-error{{'N' does not refer to a type}} expected-error{{'Wibble' does not refer to a type}} -- -+ - f->Bar::~Bar(17, 42); // expected-error{{cannot have any arguments}} - - i->~Integer(); -@@ -148,12 +148,12 @@ - namespace Template { - template struct Y {}; - template using G = Y; -- template void f(T *p) { p->~G(); } // expected-error {{no member named 'G'}} -+ template void f(T *p) { p->~G(); } // expected-error {{no member named '~Y'}} - void h1(Y *p) { p->~G(); } -- void h2(Y *p) { f(p); } // expected-note {{instantiation of}} -+ void h2(Y *p) { f(p); } - namespace N { template struct G {}; } - void h3(N::G *p) { p->~G(); } -- void h4(N::G *p) { f(p); } -+ void h4(N::G *p) { f(p); } // expected-note {{instantiation of}} - } - - namespace TemplateUndeclared { -diff -ruN --strip-trailing-cr a/clang/test/SemaCXX/static-assert-cxx17.cpp b/clang/test/SemaCXX/static-assert-cxx17.cpp ---- a/clang/test/SemaCXX/static-assert-cxx17.cpp -+++ b/clang/test/SemaCXX/static-assert-cxx17.cpp -@@ -96,7 +96,7 @@ - // expected-error@-1{{static assertion failed due to requirement 'static_cast *>(nullptr)'}} - static_assert((const X[]){} == nullptr); - // expected-error@-1{{static assertion failed due to requirement '(const X[0]){} == nullptr'}} -- static_assert(sizeof(X().template X::~X())>) == 0); -+ static_assert(sizeof(X().X::~X())>) == 0); - // expected-error@-1{{static assertion failed due to requirement 'sizeof(X) == 0'}} \ - // expected-note@-1 {{evaluates to '8 == 0'}} - static_assert(constexpr_return_false()); -diff -ruN --strip-trailing-cr a/clang/test/SemaTemplate/dependent-base-classes.cpp b/clang/test/SemaTemplate/dependent-base-classes.cpp ---- a/clang/test/SemaTemplate/dependent-base-classes.cpp -+++ b/clang/test/SemaTemplate/dependent-base-classes.cpp -@@ -1,12 +1,12 @@ - // RUN: %clang_cc1 -fsyntax-only -verify %s - - template --struct X0 : T::template apply { -+struct X0 : T::template apply { - X0(U u) : T::template apply(u) { } - }; - template --struct X1 : T::apply { }; // expected-warning{{use 'template' keyword to treat 'apply' as a dependent template name}} -+struct X1 : T::apply { }; // expected-error{{use 'template' keyword to treat 'apply' as a dependent template name}} - - template - struct X2 : vector { }; // expected-error{{no template named 'vector'}} -@@ -85,7 +85,7 @@ - struct A { }; - - template -- class B : public A -+ class B : public A - { - public: - template< class X > -@@ -109,9 +109,9 @@ - - namespace PR6413 { - template class Base_A { }; -- -+ - class Base_B { }; -- -+ - template - class Derived - : public virtual Base_A -@@ -120,12 +120,12 @@ + template + inline CmpClass_match + m_ICmp(const LHS &L, const RHS &R) { +- return CmpClass_match(L, R); ++ ICmpInst::Predicate Unused; ++ return CmpClass_match(Unused, L, R); } - namespace PR5812 { -- template struct Base { -- Base* p; -- }; -+ template struct Base { -+ Base* p; -+ }; - -- template struct Derived: public Base { -- typename Derived::Base* p; // meaning Derived::Base -+ template struct Derived: public Base { -+ typename Derived::Base* p; // meaning Derived::Base - }; - - Derived di; -diff -ruN --strip-trailing-cr a/clang/test/SemaTemplate/dependent-template-recover.cpp b/clang/test/SemaTemplate/dependent-template-recover.cpp ---- a/clang/test/SemaTemplate/dependent-template-recover.cpp -+++ b/clang/test/SemaTemplate/dependent-template-recover.cpp -@@ -2,15 +2,15 @@ - template - struct X { - void f(T* t) { -- t->f0(); // expected-warning{{use 'template' keyword to treat 'f0' as a dependent template name}} -- t->f0(); // expected-warning{{use 'template' keyword to treat 'f0' as a dependent template name}} -+ t->f0(); // expected-error{{use 'template' keyword to treat 'f0' as a dependent template name}} -+ t->f0(); // expected-error{{use 'template' keyword to treat 'f0' as a dependent template name}} - -- t->operator+(1); // expected-warning{{use 'template' keyword to treat 'operator +' as a dependent template name}} -- t->f1(1); // expected-warning{{use 'template' keyword to treat 'f1' as a dependent template name}} -+ t->operator+(1); // expected-error{{use 'template' keyword to treat 'operator +' as a dependent template name}} -+ t->f1(1); // expected-error{{use 'template' keyword to treat 'f1' as a dependent template name}} - t->f1<3, int const>(1); // expected-error{{missing 'template' keyword prior to dependent template name 'f1'}} - -- T::getAs(); // expected-warning{{use 'template' keyword to treat 'getAs' as a dependent template name}} -- t->T::getAs(); // expected-warning{{use 'template' keyword to treat 'getAs' as a dependent template name}} -+ T::getAs(); // expected-error{{use 'template' keyword to treat 'getAs' as a dependent template name}} -+ t->T::getAs(); // expected-error{{use 'template' keyword to treat 'getAs' as a dependent template name}} - - (*t).f2(); // expected-error{{missing 'template' keyword prior to dependent template name 'f2'}} - (*t).f2<0>(); // expected-error{{missing 'template' keyword prior to dependent template name 'f2'}} -diff -ruN --strip-trailing-cr a/clang/test/SemaTemplate/temp_arg_nontype_cxx20.cpp b/clang/test/SemaTemplate/temp_arg_nontype_cxx20.cpp ---- a/clang/test/SemaTemplate/temp_arg_nontype_cxx20.cpp -+++ b/clang/test/SemaTemplate/temp_arg_nontype_cxx20.cpp -@@ -115,7 +115,7 @@ - static_assert(f(X()) == 0); - - template struct Y { void f(); }; -- template void g(Y y) { y.template Y::f(); } -+ template void g(Y y) { y.Y::f(); } - void h() { constexpr A a; g(Y{}); } - - template struct Z { -diff -ruN --strip-trailing-cr a/clang/test/SemaTemplate/template-id-expr.cpp b/clang/test/SemaTemplate/template-id-expr.cpp ---- a/clang/test/SemaTemplate/template-id-expr.cpp -+++ b/clang/test/SemaTemplate/template-id-expr.cpp -@@ -19,7 +19,7 @@ - struct X0 { - template - void f1(); -- -+ - template - void f2(U) { - f1(); -@@ -39,9 +39,9 @@ - template - struct X { - X(int, int); -- void f() { -- Y >(X(0, 0)); -- Y >(::X(0, 0)); -+ void f() { -+ Y >(X(0, 0)); -+ Y >(::X(0, 0)); - } - }; - -@@ -149,11 +149,11 @@ - - int x; - x = Y1::f4(0); -- x = Y1::f4(0); // expected-warning {{use 'template'}} expected-error {{assigning to 'int' from incompatible type 'void'}} -+ x = Y1::f4(0); // expected-error {{use 'template'}} expected-error {{assigning to 'int' from incompatible type 'void'}} - x = Y1::template f4(0); // expected-error {{assigning to 'int' from incompatible type 'void'}} expected-error {{a template argument list is expected after a name prefixed by the template keyword}} - - x = p->f4(0); -- x = p->f4(0); // expected-error {{assigning to 'int' from incompatible type 'void'}} expected-warning {{use 'template'}} -+ x = p->f4(0); // expected-error {{assigning to 'int' from incompatible type 'void'}} expected-error {{use 'template'}} - x = p->template f4(0); // expected-error {{assigning to 'int' from incompatible type 'void'}} expected-error {{a template argument list is expected after a name prefixed by the template keyword}} - } - }; -@@ -184,7 +184,7 @@ - #if __cplusplus <= 199711L - // expected-warning@+2 {{extension}} - #endif --template using D = int; // expected-note {{declared here}} -+template using D = int; // expected-note {{declared here}} - E ed; // expected-note {{instantiation of}} - - namespace non_functions { -diff -ruN --strip-trailing-cr a/clang/test/SemaTemplate/typename-specifier-3.cpp b/clang/test/SemaTemplate/typename-specifier-3.cpp ---- a/clang/test/SemaTemplate/typename-specifier-3.cpp -+++ b/clang/test/SemaTemplate/typename-specifier-3.cpp -@@ -46,7 +46,7 @@ - typedef int arg; - }; - struct C { -- typedef typename B::X x; // expected-warning {{use 'template'}} expected-error {{refers to non-type}} -+ typedef typename B::X x; // expected-error {{use 'template'}} expected-error {{refers to non-type}} - }; - }; - -diff -ruN --strip-trailing-cr a/libc/src/__support/macros/config.h b/libc/src/__support/macros/config.h ---- a/libc/src/__support/macros/config.h -+++ b/libc/src/__support/macros/config.h -@@ -15,7 +15,6 @@ - - // Workaround for compilers that do not support builtin detection. - // FIXME: This is only required for the GPU portion which should be moved. --#include "src/__support/macros/config.h" - #ifndef __has_builtin - #define __has_builtin(b) 0 - #endif -diff -ruN --strip-trailing-cr a/libcxx/include/regex b/libcxx/include/regex ---- a/libcxx/include/regex -+++ b/libcxx/include/regex -@@ -4214,7 +4214,7 @@ - _LIBCPP_HIDE_FROM_ABI int compare(const value_type* __s) const { return str().compare(__s); } - - _LIBCPP_HIDE_FROM_ABI void swap(sub_match& __s) _NOEXCEPT_(__is_nothrow_swappable_v<_BidirectionalIterator>) { -- this->template pair<_BidirectionalIterator, _BidirectionalIterator>::swap(__s); -+ this->pair<_BidirectionalIterator, _BidirectionalIterator>::swap(__s); - std::swap(matched, __s.matched); - } - }; -diff -ruN --strip-trailing-cr a/llvm/include/llvm/ADT/ArrayRef.h b/llvm/include/llvm/ADT/ArrayRef.h ---- a/llvm/include/llvm/ADT/ArrayRef.h -+++ b/llvm/include/llvm/ADT/ArrayRef.h -@@ -460,11 +460,8 @@ - - OwningArrayRef &operator=(OwningArrayRef &&Other) { - delete[] this->data(); -- using Base = MutableArrayRef; -- // GCC versions prior to 11.1 incorrectly reject if the 'template' keyword -- // is used prior to the nested-name-specifier here. -- this->Base::operator=(Other); -- Other.Base::operator=(Base()); -+ this->MutableArrayRef::operator=(Other); -+ Other.MutableArrayRef::operator=(MutableArrayRef()); - return *this; - } + template + inline CmpClass_match + m_FCmp(const LHS &L, const RHS &R) { +- return CmpClass_match(L, R); ++ FCmpInst::Predicate Unused; ++ return CmpClass_match(Unused, L, R); + } + // Same as CmpClass, but instead of saving Pred as out output variable, match a diff -ruN --strip-trailing-cr a/llvm/include/llvm/TargetParser/Triple.h b/llvm/include/llvm/TargetParser/Triple.h --- a/llvm/include/llvm/TargetParser/Triple.h +++ b/llvm/include/llvm/TargetParser/Triple.h @@ -3702,18 +354,6 @@ diff -ruN --strip-trailing-cr a/llvm/include/llvm/TargetParser/Triple.h b/llvm/i amdil, // AMDIL amdil64, // AMDIL with 64-bit pointers hsail, // AMD HSAIL -diff -ruN --strip-trailing-cr a/llvm/lib/CodeGen/MachineSink.cpp b/llvm/lib/CodeGen/MachineSink.cpp ---- a/llvm/lib/CodeGen/MachineSink.cpp -+++ b/llvm/lib/CodeGen/MachineSink.cpp -@@ -961,7 +961,7 @@ - MachineBasicBlock *ToBB, - bool BreakPHIEdge) { - // Avoid breaking back edge. From == To means backedge for single BB cycle. -- if (!SplitEdges || FromBB == ToBB) -+ if (!SplitEdges || FromBB == ToBB || !FromBB->isSuccessor(ToBB)) - return false; - - MachineCycle *FromCycle = CI->getCycle(FromBB); diff -ruN --strip-trailing-cr a/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp b/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp --- a/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp +++ b/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp @@ -4008,86 +648,6 @@ diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AArch64/aarch64-mull-masks.ll ; CHECK-NEXT: ret entry: %ext64 = load i32, ptr %x0 -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/X86/MachineSink-Issue98477.ll b/llvm/test/CodeGen/X86/MachineSink-Issue98477.ll ---- a/llvm/test/CodeGen/X86/MachineSink-Issue98477.ll -+++ b/llvm/test/CodeGen/X86/MachineSink-Issue98477.ll -@@ -0,0 +1,76 @@ -+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 -+; RUN: llc < %s | FileCheck %s -+ -+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128" -+target triple = "x86_64-unknown-linux-gnu" -+ -+define i32 @main(i1 %tobool.not, i32 %0) { -+; CHECK-LABEL: main: -+; CHECK: # %bb.0: # %entry -+; CHECK-NEXT: movl $1, %r8d -+; CHECK-NEXT: testb $1, %dil -+; CHECK-NEXT: jne .LBB0_8 -+; CHECK-NEXT: .LBB0_1: # %j.preheader -+; CHECK-NEXT: xorl %r9d, %r9d -+; CHECK-NEXT: jmp .LBB0_2 -+; CHECK-NEXT: .p2align 4, 0x90 -+; CHECK-NEXT: .LBB0_5: # %if.then4 -+; CHECK-NEXT: # in Loop: Header=BB0_2 Depth=1 -+; CHECK-NEXT: movl $1, %eax -+; CHECK-NEXT: xorl %edx, %edx -+; CHECK-NEXT: divl %r8d -+; CHECK-NEXT: testb $1, %dil -+; CHECK-NEXT: jne .LBB0_6 -+; CHECK-NEXT: .LBB0_2: # %j -+; CHECK-NEXT: # =>This Inner Loop Header: Depth=1 -+; CHECK-NEXT: movl $1, %eax -+; CHECK-NEXT: xorl %edx, %edx -+; CHECK-NEXT: idivl %esi -+; CHECK-NEXT: movl %edx, %ecx -+; CHECK-NEXT: testb %r9b, %r9b -+; CHECK-NEXT: jne .LBB0_5 -+; CHECK-NEXT: # %bb.3: # %j -+; CHECK-NEXT: # in Loop: Header=BB0_2 Depth=1 -+; CHECK-NEXT: testl %r9d, %r9d -+; CHECK-NEXT: js .LBB0_5 -+; CHECK-NEXT: # %bb.4: -+; CHECK-NEXT: movl %r9d, %edx -+; CHECK-NEXT: .LBB0_6: # %if.end9 -+; CHECK-NEXT: testl %edx, %edx -+; CHECK-NEXT: jne .LBB0_7 -+; CHECK-NEXT: .LBB0_8: # %if.end13 -+; CHECK-NEXT: xorl %r8d, %r8d -+; CHECK-NEXT: jmp .LBB0_1 -+; CHECK-NEXT: .LBB0_7: # %while.body.lr.ph -+; CHECK-NEXT: movl %ecx, %eax -+; CHECK-NEXT: retq -+entry: -+ br i1 %tobool.not, label %if.end13, label %j.preheader -+ -+ j.preheader: ; preds = %if.end13, %entry -+ %h.0.ph = phi i32 [ 1, %entry ], [ 0, %if.end13 ] -+ br label %j -+ -+ j: ; preds = %if.then4, %j.preheader -+ %1 = phi i32 [ %div2, %if.then4 ], [ 0, %j.preheader ] -+ %rem1 = srem i32 1, %0 -+ %cmp = icmp slt i32 %1, 0 -+ %or.cond = select i1 false, i1 true, i1 %cmp -+ br i1 %or.cond, label %if.then4, label %if.end9 -+ -+ if.then4: ; preds = %j -+ %div2 = sdiv i32 1, 0 -+ %rem5 = srem i32 1, %h.0.ph -+ br i1 %tobool.not, label %if.end9, label %j -+ -+ if.end9: ; preds = %if.then4, %j -+ %2 = phi i32 [ 0, %j ], [ %rem5, %if.then4 ] -+ %tobool10.not = icmp eq i32 %2, 0 -+ br i1 %tobool10.not, label %if.end13, label %while.body.lr.ph -+ -+ while.body.lr.ph: ; preds = %if.end9 -+ ret i32 %rem1 -+ -+ if.end13: ; preds = %if.end9, %entry -+ br label %j.preheader -+} diff -ruN --strip-trailing-cr a/llvm/test/Transforms/LoopVectorize/X86/interleave-cost.ll b/llvm/test/Transforms/LoopVectorize/X86/interleave-cost.ll --- a/llvm/test/Transforms/LoopVectorize/X86/interleave-cost.ll +++ b/llvm/test/Transforms/LoopVectorize/X86/interleave-cost.ll @@ -4524,6 +1084,27 @@ diff -ruN --strip-trailing-cr a/llvm/test/Transforms/LoopVectorize/X86/interleav +; CHECK: [[LOOP9]] = distinct !{[[LOOP9]], [[META1]], [[META2]]} +; CHECK: [[LOOP10]] = distinct !{[[LOOP10]], [[META1]]} ;. +diff -ruN --strip-trailing-cr a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp +--- a/llvm/unittests/IR/PatternMatch.cpp ++++ b/llvm/unittests/IR/PatternMatch.cpp +@@ -2235,7 +2235,7 @@ + MutableConstTestTypes; + TYPED_TEST_SUITE(MutableConstTest, MutableConstTestTypes, ); + +-TYPED_TEST(MutableConstTest, ICmp) { ++TYPED_TEST(MutableConstTest, /* FIXME: UAR bug */ DISABLED_ICmp) { + auto &IRB = PatternMatchTest::IRB; + + typedef std::tuple_element_t<0, TypeParam> ValueType; +@@ -2319,7 +2319,7 @@ + .match((InstructionType)IRB.CreateICmp(Pred, L, R))); + } + +-TYPED_TEST(MutableConstTest, FCmp) { ++TYPED_TEST(MutableConstTest, /* FIXME: UAR bug */ DISABLED_FCmp) { + auto &IRB = PatternMatchTest::IRB; + + typedef std::tuple_element_t<0, TypeParam> ValueType; diff -ruN --strip-trailing-cr a/llvm/utils/gn/secondary/clang/lib/Basic/BUILD.gn b/llvm/utils/gn/secondary/clang/lib/Basic/BUILD.gn --- a/llvm/utils/gn/secondary/clang/lib/Basic/BUILD.gn +++ b/llvm/utils/gn/secondary/clang/lib/Basic/BUILD.gn diff --git a/third_party/tsl/third_party/llvm/workspace.bzl b/third_party/tsl/third_party/llvm/workspace.bzl index b93276bcc4c566..6c8da928bb4d1a 100644 --- a/third_party/tsl/third_party/llvm/workspace.bzl +++ b/third_party/tsl/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "5ff3ff33ff930e4ec49da7910612d8a41eb068cb" - LLVM_SHA256 = "15fd6dcf22fdf549831d8d490970f66965988f1116dcc4ac04ab2570d9399aba" + LLVM_COMMIT = "dd7d81ea49bf39e1d69bbb84bd3f31bd95519369" + LLVM_SHA256 = "fbd43ef20f4209b0619e209e48c431f76008917714a8c5336063e1ff51d8d084" tf_http_archive( name = name, From 760da5221e3134d095751133ce4cc0a30a7ddc36 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Thu, 18 Jul 2024 23:45:49 -0700 Subject: [PATCH 004/376] Remove affine fuzz test for now. This can't be built right now because the grammar bzl is broken in the version of fuzztest we're using. PiperOrigin-RevId: 653893594 --- build_tools/build.py | 2 - xla/service/gpu/model/fuzztest/BUILD | 34 ------ xla/service/gpu/model/fuzztest/affine_fuzz.g4 | 40 ------ .../fuzztest/affine_simplifier_fuzz_test.cc | 114 ------------------ 4 files changed, 190 deletions(-) delete mode 100644 xla/service/gpu/model/fuzztest/BUILD delete mode 100644 xla/service/gpu/model/fuzztest/affine_fuzz.g4 delete mode 100644 xla/service/gpu/model/fuzztest/affine_simplifier_fuzz_test.cc diff --git a/build_tools/build.py b/build_tools/build.py index 739722b30f93d8..2d10390af1ca6b 100755 --- a/build_tools/build.py +++ b/build_tools/build.py @@ -247,7 +247,6 @@ def nvidia_gpu_build_with_compute_capability( configs=("warnings", "nonccl", "rbe_linux_cpu"), target_patterns=_XLA_DEFAULT_TARGET_PATTERNS + ( - "-//xla/service/gpu/model/fuzztest/...", "-//xla/service/gpu/fusions/triton:triton_support_test", ), build_tag_filters=cpu_x86_tag_filter, @@ -269,7 +268,6 @@ def nvidia_gpu_build_with_compute_capability( configs=("warnings", "rbe_cross_compile_linux_arm64_xla", "nonccl"), target_patterns=_XLA_DEFAULT_TARGET_PATTERNS + ( - "-//xla/service/gpu/model/fuzztest/...", "-//xla/service/gpu/fusions/triton:triton_support_test", ), options={**_DEFAULT_BAZEL_OPTIONS, "build_tests_only": True}, diff --git a/xla/service/gpu/model/fuzztest/BUILD b/xla/service/gpu/model/fuzztest/BUILD deleted file mode 100644 index 7cd745587a7eed..00000000000000 --- a/xla/service/gpu/model/fuzztest/BUILD +++ /dev/null @@ -1,34 +0,0 @@ -load( - "@com_google_fuzztest//build_defs:cc_fuzztest_grammar_library.bzl", - "cc_fuzztest_grammar_library", -) - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - licenses = ["notice"], -) - -cc_fuzztest_grammar_library( - name = "affine_grammar", - srcs = ["affine_fuzz.g4"], - top_level_rule = "affine", -) - -cc_test( - name = "affine_simplifier_fuzz_test", - srcs = ["affine_simplifier_fuzz_test.cc"], - deps = [ - ":affine_grammar", - "//xla/service/gpu/model:affine_map_printer", - "//xla/service/gpu/model:indexing_analysis", - "//xla/service/gpu/model:indexing_test_utils", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@com_google_fuzztest//fuzztest", - "@com_google_googletest//:gtest_main", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", - ], -) diff --git a/xla/service/gpu/model/fuzztest/affine_fuzz.g4 b/xla/service/gpu/model/fuzztest/affine_fuzz.g4 deleted file mode 100644 index 76651df8e90884..00000000000000 --- a/xla/service/gpu/model/fuzztest/affine_fuzz.g4 +++ /dev/null @@ -1,40 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Grammar for generating random affine expressions. NOTE: this is not a -// complete grammar for affine expressions! We do not consider expressions that -// do not occur in the indexing maps used by MLIR codegen: -// - ceildiv -// - non-constant RHS for mod, floordiv and mul -// We also don't consider expressions with more than two dimensions or symbols. -// This gives us up to four variables in total, which should be enough. - -grammar AFFINE_FUZZ; - -affine: '(d0, d1)[s0, s1] -> (' expr ')'; -floordiv: expr ' floordiv ' NONZERO; -mod: expr ' mod ' POSITIVE; -mul: expr ' * ' INTEGER; -sum: expr ' + ' expr; -expr: INTEGER | SYM | DIM | '(' floordiv ')' | '(' sum ')' | '(' mul ')' | '(' mod ')'; - -SYM : 's' [01]; -DIM : 'd' [01]; -ONETONINE : [1-9]; -DIGITS : (DIGIT | DIGIT DIGIT)?; -DIGIT : '0' | ONETONINE; -NONZERO : '-'? ONETONINE DIGITS; -POSITIVE: ONETONINE DIGITS; -INTEGER: NONZERO | '0'; diff --git a/xla/service/gpu/model/fuzztest/affine_simplifier_fuzz_test.cc b/xla/service/gpu/model/fuzztest/affine_simplifier_fuzz_test.cc deleted file mode 100644 index a6e9a7723f9884..00000000000000 --- a/xla/service/gpu/model/fuzztest/affine_simplifier_fuzz_test.cc +++ /dev/null @@ -1,114 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include -#include - -#include -#include "fuzztest/fuzztest.h" -#include "absl/log/check.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/MLIRContext.h" -#include "xla/service/gpu/model/affine_map_printer.h" -#include "xla/service/gpu/model/fuzztest/affine_grammar.h" -#include "xla/service/gpu/model/indexing_map.h" -#include "xla/service/gpu/model/indexing_test_utils.h" - -namespace xla { -namespace gpu { -namespace { - -IndexingMap GetMap(std::string input, int64_t d0_min, int64_t d0_size, - int64_t d1_min, int64_t d1_size, int64_t s0_min, - int64_t s0_size, int64_t s1_min, int64_t s1_size) { - static mlir::MLIRContext* context = new mlir::MLIRContext(); - mlir::AffineMap affine_map = xla::gpu::ParseAffineMap(input, context); - CHECK_EQ(affine_map.getNumResults(), 1); - - // Set the sizes of unused variables to 1. - if (!affine_map.isFunctionOfSymbol(0)) { - s0_size = 1; - } - if (!affine_map.isFunctionOfSymbol(1)) { - s1_size = 1; - } - if (!affine_map.isFunctionOfDim(0)) { - d0_size = 1; - } - if (!affine_map.isFunctionOfDim(1)) { - d1_size = 1; - } - - Interval s0_interval = {s0_min, s0_min + s0_size - 1}; - Interval s1_interval = {s1_min, s1_min + s1_size - 1}; - Interval d0_interval = {d0_min, d0_min + d0_size - 1}; - Interval d1_interval = {d1_min, d1_min + d1_size - 1}; - - return IndexingMap(affine_map, {{d0_interval}, {d1_interval}}, - {{s0_interval}, {s1_interval}}, {}); -} - -void TestCorrectness(std::string input, int64_t d0_min, int64_t d0_size, - int64_t d1_min, int64_t d1_size, int64_t s0_min, - int64_t s0_size, int64_t s1_min, int64_t s1_size) { - // Verifies that the simplified map produces the same results as the original - // map at every point in its domain. - IndexingMap map = GetMap(input, d0_min, d0_size, d1_min, d1_size, s0_min, - s0_size, s1_min, s1_size); - IndexingMap map_simplified = map; - map_simplified.Simplify(); - - mlir::AffineExpr original = map.GetAffineMap().getResult(0); - mlir::AffineExpr simplified = map_simplified.GetAffineMap().getResult(0); - - EXPECT_OK(VerifyExprsAreIdentical( - original, simplified, map.GetDimensionBounds(), map.GetSymbolBounds())) - << "original: " << AffineMapPrinter().ToString(original) - << ", simplified: " << AffineMapPrinter().ToString(simplified); -} - -void TestIdempotency(std::string input, int64_t d0_min, int64_t d0_size, - int64_t d1_min, int64_t d1_size, int64_t s0_min, - int64_t s0_size, int64_t s1_min, int64_t s1_size) { - // Verifies that Simplify(Simplify(map)) == Simplify(map). - IndexingMap map = GetMap(input, d0_min, d0_size, d1_min, d1_size, s0_min, - s0_size, s1_min, s1_size); - if (map.Simplify()) { - auto before_simplification = map.GetAffineMap(); - EXPECT_FALSE(map.Simplify()); - EXPECT_EQ(before_simplification, map.GetAffineMap()) - << AffineMapPrinter().ToString(before_simplification); - } -} - -auto AffineDomain() { - // The ranges are chosen to include entirely negative, entirely positive and - // mixed domains (but mostly positive ones). - return fuzztest::TupleOf( - fuzztest::InAffineGrammar(), fuzztest::InRange(-10, 100), - fuzztest::InRange(0, 10), fuzztest::InRange(-10, 100), - fuzztest::InRange(0, 10), fuzztest::InRange(-10, 100), - fuzztest::InRange(0, 10), fuzztest::InRange(-10, 100), - fuzztest::InRange(0, 10)); -} - -FUZZ_TEST(AffineSimplifierFuzzTest, TestCorrectness) - .WithDomains(AffineDomain()); -FUZZ_TEST(AffineSimplifierFuzzTest, TestIdempotency) - .WithDomains(AffineDomain()); - -} // namespace -} // namespace gpu -} // namespace xla From 5744ac5eefd69fdb871f9d7fe5af774f02d97973 Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Fri, 19 Jul 2024 01:35:32 -0700 Subject: [PATCH 005/376] Reverts 6695d15d8e4b58b84174317c9ebcadc4f701cc5b PiperOrigin-RevId: 653920956 --- xla/service/gpu/BUILD | 47 +++ xla/service/gpu/gpu_compiler.cc | 5 +- xla/service/gpu/gpu_compiler.h | 2 +- xla/service/gpu/nvptx_compiler.cc | 131 ++++--- xla/service/gpu/nvptx_compiler.h | 17 +- xla/service/gpu/ptx_compilation_test.cc | 330 ++++++++++++++++++ xla/stream_executor/cuda/BUILD | 13 + xla/stream_executor/cuda/cuda_asm_compiler.cc | 38 +- xla/stream_executor/cuda/cuda_asm_compiler.h | 5 +- .../cuda/ptx_compilation_method.h | 49 +++ xla/stream_executor/cuda/ptx_linking_method.h | 53 +++ 11 files changed, 616 insertions(+), 74 deletions(-) create mode 100644 xla/service/gpu/ptx_compilation_test.cc create mode 100644 xla/stream_executor/cuda/ptx_compilation_method.h create mode 100644 xla/stream_executor/cuda/ptx_linking_method.h diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 094661167f6a10..4be7457e43992a 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -3477,8 +3477,10 @@ cc_library( "//xla/stream_executor/cuda:cuda_asm_compiler", "//xla/stream_executor/cuda:cuda_diagnostics", "//xla/stream_executor/cuda:cuda_platform_id", + "//xla/stream_executor/cuda:ptx_compilation_method", "//xla/stream_executor/cuda:ptx_compiler", "//xla/stream_executor/cuda:ptx_compiler_support", + "//xla/stream_executor/cuda:ptx_linking_method", "//xla/stream_executor/gpu:gpu_asm_opts", "//xla/stream_executor/gpu:gpu_driver_header", "//xla/stream_executor/gpu:gpu_executor_header", @@ -3497,6 +3499,7 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:IRReader", "@llvm-project//llvm:Support", "@llvm-project//llvm:ir_headers", @@ -3551,6 +3554,50 @@ xla_test( ], ) +xla_test( + name = "ptx_compilation_test", + srcs = [ + "ptx_compilation_test.cc", + ], + backends = [ + "gpu", + ], + tags = [ + "gpu", + "no_rocm", + "nomsan", # Pulls in precompiled NVIDIA libraries which cause false positives in msan. + ], + deps = [ + ":gpu_executable", + ":nvptx_compiler_impl", + "//xla:xla_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:executable", + "//xla/service:hlo_module_config", + "//xla/stream_executor:device_description", + "//xla/stream_executor/cuda:ptx_compilation_method", + "//xla/stream_executor/cuda:ptx_compiler_support", + "//xla/stream_executor/cuda:ptx_linking_method", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", + "@llvm-project//llvm:Object", + "@llvm-project//llvm:Support", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:path", + "@tsl//tsl/platform:status_matchers", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", + ], +) + xla_cc_test( name = "gpu_aot_compilation_test", srcs = if_gpu_is_configured([ diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index 5a9531b4748c12..d0ba0777e8776c 100644 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -2030,8 +2030,9 @@ absl::StatusOr GpuCompiler::CompileAndLink( } } - auto maybe_backend_result = LinkModules( - stream_exec, std::move(binaries_to_link), module_config.debug_options()); + auto maybe_backend_result = + LinkModules(gpu_version, stream_exec, std::move(binaries_to_link), + module_config.debug_options()); if (!maybe_backend_result.ok()) { LOG(ERROR) << "The CUDA linking API did not work. Please use XLA_FLAGS=" "--xla_gpu_enable_llvm_module_compilation_parallelism=false " diff --git a/xla/service/gpu/gpu_compiler.h b/xla/service/gpu/gpu_compiler.h index 580636d878fa9e..27a434f5a5d035 100644 --- a/xla/service/gpu/gpu_compiler.h +++ b/xla/service/gpu/gpu_compiler.h @@ -235,7 +235,7 @@ class GpuCompiler : public LLVMCompiler { absl::Status PrepareHloModuleForIrEmitting(HloModule* hlo_module); virtual absl::StatusOr> LinkModules( - se::StreamExecutor* stream_exec, + se::GpuComputeCapability cc, se::StreamExecutor* stream_exec, std::vector> modules, const DebugOptions& debug_options) { return Unimplemented("LinkModules is not implemented."); diff --git a/xla/service/gpu/nvptx_compiler.cc b/xla/service/gpu/nvptx_compiler.cc index 1da1da11c21841..a962da924d6a54 100644 --- a/xla/service/gpu/nvptx_compiler.cc +++ b/xla/service/gpu/nvptx_compiler.cc @@ -34,8 +34,10 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "absl/types/span.h" #include "third_party/gpus/cuda/include/cuda.h" #include "llvm/IRReader/IRReader.h" #include "llvm/Support/SourceMgr.h" @@ -92,8 +94,10 @@ limitations under the License. #include "xla/stream_executor/cuda/cuda_asm_compiler.h" #include "xla/stream_executor/cuda/cuda_diagnostics.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" +#include "xla/stream_executor/cuda/ptx_compilation_method.h" #include "xla/stream_executor/cuda/ptx_compiler.h" #include "xla/stream_executor/cuda/ptx_compiler_support.h" +#include "xla/stream_executor/cuda/ptx_linking_method.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/dnn.h" @@ -575,6 +579,50 @@ NVPTXCompiler::CompileTargetBinary(const HloModuleConfig& module_config, return BackendCompileResult{std::move(ptx), std::move(maybe_cubin.value())}; } +using stream_executor::PtxCompilationMethod; + +// Returns the supported compilation methods in the order of priority. +std::vector GetSupportedCompilationMethods() { + std::vector methods; + if (se::IsLibNvPtxCompilerSupported()) { + methods.emplace_back(PtxCompilationMethod::kNvPtxCompiler); + } + methods.emplace_back(PtxCompilationMethod::kPtxas); + return methods; +} + +absl::StatusOr ChooseCompilationMethod( + absl::Span available_compilation_methods, + const DebugOptions& debug_options, bool relocatable) { + std::vector compilation_methods( + available_compilation_methods.begin(), + available_compilation_methods.end()); + VLOG(2) << "Available compilation methods: " + << absl::StrJoin(compilation_methods, ", "); + + auto remove_compilation_method = [&](PtxCompilationMethod method) { + auto it = absl::c_find(compilation_methods, method); + if (it != compilation_methods.end()) { + compilation_methods.erase(it); + } + }; + + if (!debug_options.xla_gpu_enable_libnvptxcompiler()) { + VLOG(3) << "Discarding NvPtxCompiler since it is disabled."; + remove_compilation_method(PtxCompilationMethod::kNvPtxCompiler); + } + + VLOG(2) << "Considered compilation methods: " + << absl::StrJoin(compilation_methods, ", "); + + if (compilation_methods.empty()) { + return absl::UnavailableError( + "No supported compilation method is available."); + } + + return compilation_methods.front(); +} + static absl::StatusOr> AssembleOptionsAndCompile( const std::string& ptx, se::CudaComputeCapability cc, const HloModuleConfig& hlo_module_config, @@ -595,15 +643,24 @@ static absl::StatusOr> AssembleOptionsAndCompile( .xla_gpu_filter_kernels_spilling_registers_on_autotuning() && options.is_autotuning_compilation; + std::vector supported_compilation_methods = + GetSupportedCompilationMethods(); + TF_ASSIGN_OR_RETURN( + PtxCompilationMethod compilation_method, + ChooseCompilationMethod(supported_compilation_methods, + hlo_module_config.debug_options(), relocatable)); + + VLOG(2) << "Using compilation method: " << compilation_method; + absl::StatusOr> maybe_cubin = [&] { - if (hlo_module_config.debug_options().xla_gpu_enable_libnvptxcompiler() && - se::IsLibNvPtxCompilerSupported()) { - return se::CompileGpuAsmUsingLibNvPtxCompiler( - cc.major, cc.minor, ptx.c_str(), ptxas_config, cancel_if_reg_spill); + switch (compilation_method) { + case PtxCompilationMethod::kNvPtxCompiler: + return se::CompileGpuAsmUsingLibNvPtxCompiler( + cc.major, cc.minor, ptx.c_str(), ptxas_config, cancel_if_reg_spill); + case PtxCompilationMethod::kPtxas: + return se::CompileGpuAsmUsingPtxAs(cc.major, cc.minor, ptx.c_str(), + ptxas_config, cancel_if_reg_spill); } - - return se::CompileGpuAsmUsingPtxAs(cc.major, cc.minor, ptx.c_str(), - ptxas_config, cancel_if_reg_spill); }(); if (maybe_cubin.ok()) { @@ -751,9 +808,12 @@ static absl::StatusOr GetAsmCompilerVersion( return se::GetAsmCompilerVersion(preferred_cuda_dir); } -absl::StatusOr ChooseLinkingMethodImpl( - const DebugOptions& debug_options, const std::string& preferred_cuda_dir) { - using LinkingMethod = NVPTXCompiler::LinkingMethod; +absl::StatusOr NVPTXCompiler::ChooseLinkingMethod( + const DebugOptions& debug_options) { + se::GpuAsmOpts ptxas_config = PtxOptsFromDebugOptions(debug_options); + std::string& preferred_cuda_dir = ptxas_config.preferred_cuda_dir; + + using LinkingMethod = se::PtxLinkingMethod; TF_ASSIGN_OR_RETURN(auto asm_compiler_version, GetAsmCompilerVersion(debug_options, preferred_cuda_dir)); @@ -785,47 +845,28 @@ absl::StatusOr ChooseLinkingMethodImpl( "You should update your NVIDIA driver or use the NVIDIA-provided " "CUDA forward compatibility packages."; - return LinkingMethod::kNone; -} - -absl::StatusOr NVPTXCompiler::ChooseLinkingMethod( - const DebugOptions& debug_options) { - se::GpuAsmOpts ptxas_config = PtxOptsFromDebugOptions(debug_options); - std::string& preferred_cuda_dir = ptxas_config.preferred_cuda_dir; - - { - absl::MutexLock lock(&mutex_); - auto it = linking_methods_.find(preferred_cuda_dir); - if (it != linking_methods_.end()) { - return it->second; - } - } - - // This wrapper only handles caching. The actual choice happens in this call: - TF_ASSIGN_OR_RETURN( - LinkingMethod linking_method, - ChooseLinkingMethodImpl(debug_options, preferred_cuda_dir)); - - { - absl::MutexLock lock(&mutex_); - linking_methods_[preferred_cuda_dir] = linking_method; - } - return linking_method; + return se::PtxLinkingMethod::kNone; } absl::StatusOr NVPTXCompiler::CanUseLinkModules( const HloModuleConfig& hlo_module_config) { // TODO(phawkins): rather than comparing version numbers, it might be more // robust if we simply tried to link something the first time we compile. - TF_ASSIGN_OR_RETURN(LinkingMethod linking_method, + TF_ASSIGN_OR_RETURN(se::PtxLinkingMethod linking_method, ChooseLinkingMethod(hlo_module_config.debug_options())); - return linking_method != LinkingMethod::kNone; + return linking_method != se::PtxLinkingMethod::kNone; } absl::StatusOr> NVPTXCompiler::LinkModules( - se::StreamExecutor* stream_exec, std::vector> modules, + se::GpuComputeCapability cc, se::StreamExecutor* stream_exec, + std::vector> modules, const DebugOptions& debug_options) { - auto ptxas_config = PtxOptsFromDebugOptions(debug_options); + if (modules.empty()) return std::vector{}; + + TF_ASSIGN_OR_RETURN(se::PtxLinkingMethod linking_method, + ChooseLinkingMethod(debug_options)); + VLOG(1) << "Linking " << modules.size() + << " modules with linking method: " << linking_method; std::vector images; images.reserve(modules.size()); @@ -833,14 +874,12 @@ absl::StatusOr> NVPTXCompiler::LinkModules( images.push_back({"", std::move(module)}); } auto context = se::gpu::ExtractGpuExecutor(stream_exec)->gpu_context(); - - TF_ASSIGN_OR_RETURN(LinkingMethod linking_method, - ChooseLinkingMethod(debug_options)); - if (linking_method == LinkingMethod::kNvLink) { - return LinkUsingNvlink(debug_options.xla_gpu_cuda_data_dir(), context, + if (linking_method == se::PtxLinkingMethod::kNvLink) { + return LinkUsingNvlink(std::get(cc), + debug_options.xla_gpu_cuda_data_dir(), context, images); } - return LinkGpuAsm(context, images); + return LinkGpuAsm(std::get(cc), context, images); } } // namespace gpu diff --git a/xla/service/gpu/nvptx_compiler.h b/xla/service/gpu/nvptx_compiler.h index 45a138c1442768..25fa268226107d 100644 --- a/xla/service/gpu/nvptx_compiler.h +++ b/xla/service/gpu/nvptx_compiler.h @@ -38,6 +38,7 @@ limitations under the License. #include "xla/service/hlo_dataflow_analysis.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_pass_pipeline.h" +#include "xla/stream_executor/cuda/ptx_linking_method.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/dnn.h" @@ -94,27 +95,16 @@ class NVPTXCompiler : public GpuCompiler { se::GpuComputeCapability gpu_version, bool relocatable, const HloModule* debug_module, const CompileOptions& options) override; - enum class LinkingMethod { - kNone, - kNvLink, - kDriver, - }; - absl::StatusOr CanUseLinkModules( const HloModuleConfig& module_config) override; private: absl::StatusOr> LinkModules( - se::StreamExecutor* stream_exec, + se::GpuComputeCapability cc, se::StreamExecutor* stream_exec, std::vector> modules, const DebugOptions& debug_options) override; - absl::Mutex mutex_; - - absl::flat_hash_map linking_methods_ - ABSL_GUARDED_BY(mutex_); - - absl::StatusOr ChooseLinkingMethod( + absl::StatusOr ChooseLinkingMethod( const DebugOptions& debug_options); // Tries to compile the given ptx string to cubin. Returns a vector with the @@ -191,6 +181,7 @@ class NVPTXCompiler : public GpuCompiler { // Don't even think about switching this to flat_hash_map; iterator stability // is critical here. + absl::Mutex mutex_; absl::node_hash_map compilation_cache_ ABSL_GUARDED_BY(mutex_); diff --git a/xla/service/gpu/ptx_compilation_test.cc b/xla/service/gpu/ptx_compilation_test.cc new file mode 100644 index 00000000000000..b8496b36e52f95 --- /dev/null +++ b/xla/service/gpu/ptx_compilation_test.cc @@ -0,0 +1,330 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include "absl/container/btree_map.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/types/span.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Object/ELFObjectFile.h" +#include "llvm/Object/ObjectFile.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/MemoryBuffer.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/executable.h" +#include "xla/service/gpu/gpu_executable.h" +#include "xla/service/gpu/nvptx_compiler.h" +#include "xla/service/hlo_module_config.h" +#include "xla/stream_executor/cuda/ptx_compilation_method.h" +#include "xla/stream_executor/cuda/ptx_compiler_support.h" +#include "xla/stream_executor/cuda/ptx_linking_method.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/xla.pb.h" +#include "tsl/platform/env.h" +#include "tsl/platform/path.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla::gpu { +namespace { + +constexpr std::string_view kSimpleHlo = R"( +HloModule simple + +ENTRY main { + p = f32[10]{0} parameter(0) + ROOT neg = f32[10]{0} negate(p) +} +)"; +constexpr std::string_view kParallelCompilationHlo = R"( +HloModule parallel_compilation + +ENTRY main { + p1 = f32[10,20,30] parameter(0) + p2 = f32[40,30,10] parameter(1) + // With the new MLIR emitters, each indexing change leads to a new function. + // So adding 2 transposes and a concatenate will results in 3 LLVM IR + // functions that can be compiled in parallel. + t1 = f32[20,10,30] transpose(p1), dimensions={1,0,2} + t2 = f32[40,10,30] transpose(p2), dimensions={0,2,1} + ROOT c = f32[60,10,30] concatenate(t1, t2), dimensions={0} +} +)"; + +constexpr std::string_view kSM90AHlo = R"( +gemm_fusion_dot { + %p0 = f16[64,1024]{1,0} parameter(0) + %p1 = f16[1024,32,32]{2,1,0} parameter(1) + %bitcast.74246 = f16[1024,1024]{0,1} bitcast(f16[1024,32,32]{2,1,0} %p1) + ROOT %dot.1302 = f16[64,1024]{1,0} dot(f16[64,1024]{1,0} %p0, f16[1024,1024]{0,1} %bitcast.74246), lhs_contracting_dims={1}, rhs_contracting_dims={0}, frontend_attributes={grad_x="false",grad_y="false"} +} + +ENTRY e { + p0 = f16[64,1024]{1,0} parameter(0) + p1 = f16[1024,32,32]{2,1,0} parameter(1) + // This Triton fusion generates a wgmma instruction which allows us to test + // whether we properly enable SM 9.0A in all compilation and linking paths. + ROOT triton_gemm_fusion_dot = f16[64,1024]{1,0} fusion(p0, p1), kind=kCustom, + calls=gemm_fusion_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":64,"block_n":32,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":4, + "num_ctas":1}}} +})"; + +std::string_view GetHlo(std::string_view name) { + static const absl::flat_hash_map* const + kHloMap = new absl::flat_hash_map( + {{"simple", kSimpleHlo}, + {"parallel_compilation", kParallelCompilationHlo}, + {"requires_sm90a", kSM90AHlo}}); + return kHloMap->at(name); +} + +void DumpArtifactIfEnabled(std::string_view name, + absl::Span data) { + if (std::string output_dir; + tsl::io::GetTestUndeclaredOutputsDir(&output_dir)) { + (void)tsl::WriteStringToFile( + tsl::Env::Default(), tsl::io::JoinPath(output_dir, name), + std::string_view(reinterpret_cast(data.data()), + data.size())); + } +} + +using stream_executor::PtxCompilationMethod; +using stream_executor::PtxLinkingMethod; + +std::string GenerateParametrizedTestname( + std::string_view name, PtxCompilationMethod compilation_method, + PtxLinkingMethod linking_method) { + return absl::StrFormat("%v_CompilationMethod_%v_LinkingMethod_%v", name, + compilation_method, linking_method); +} + +class NVPTXCompilationTests + : public HloTestBase, + public ::testing::WithParamInterface> { + public: + void SkipTestIfUnsupported(std::string_view name, + PtxCompilationMethod compilation_method, + PtxLinkingMethod linking_method) { + using CudaComputeCapability = stream_executor::CudaComputeCapability; + if (!::testing::Value(backend() + .default_stream_executor() + ->GetDeviceDescription() + .gpu_compute_capability(), + ::testing::VariantWith( + CudaComputeCapability{9, 0})) && + name == "requires_sm90a") { + GTEST_SKIP() << "This test requires SM 9.0a"; + } + + if (!stream_executor::IsLibNvPtxCompilerSupported() && + compilation_method == PtxCompilationMethod::kNvPtxCompiler) { + // Compiled without libnvptxcompiler support + GTEST_SKIP() << "libnvptxcompiler is not supported in this build."; + } + } + + void SetDebugOptionsFromPtxSettings(DebugOptions* debug_options, + PtxCompilationMethod compilation_method, + PtxLinkingMethod linking_method) { + debug_options->set_xla_gpu_enable_libnvptxcompiler( + compilation_method == PtxCompilationMethod::kNvPtxCompiler); + + debug_options->set_xla_gpu_enable_llvm_module_compilation_parallelism( + linking_method != PtxLinkingMethod::kNone); + debug_options->set_xla_gpu_force_compilation_parallelism(12); + + if (linking_method == PtxLinkingMethod::kDriver) { + debug_options->set_xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found( + true); + debug_options->set_xla_gpu_cuda_data_dir("/does/not/exist"); + } + + tsl::setenv("TF_USE_NVLINK_FOR_PARALLEL_COMPILATION", + linking_method == PtxLinkingMethod::kNvLink ? "true" : "false", + 1); + + // We need individual functions to test parallel compilation. + debug_options->set_xla_llvm_force_inline_before_split(false); + } + + void SetUp() override { + HloTestBase::SetUp(); + std::string_view name = std::get<0>(GetParam()); + PtxCompilationMethod compilation_method = std::get<1>(GetParam()); + PtxLinkingMethod linking_method = std::get<2>(GetParam()); + SkipTestIfUnsupported(name, compilation_method, linking_method); + } + + absl::StatusOr> CompileExecutable( + std::unique_ptr module) { + NVPTXCompiler compiler{}; + + return compiler.RunBackend(std::move(module), + backend().default_stream_executor(), + {/*device_allocator=*/nullptr, + /*thread_pool=*/nullptr, + /*layout_canonicalization_callback=*/{}, + /*is_autotuning_compilation=*/false}); + } +}; + +TEST_P(NVPTXCompilationTests, CompileProgram) { + std::string_view name = std::get<0>(GetParam()); + std::string_view hlo_text = GetHlo(name); + auto module = ParseAndReturnVerifiedModule(hlo_text).value(); + + HloModuleConfig hlo_module_config = module->config(); + DebugOptions debug_options = hlo_module_config.debug_options(); + PtxCompilationMethod compilation_method = std::get<1>(GetParam()); + PtxLinkingMethod linking_method = std::get<2>(GetParam()); + SetDebugOptionsFromPtxSettings(&debug_options, compilation_method, + linking_method); + hlo_module_config.set_debug_options(debug_options); + module->set_config(hlo_module_config); + + EXPECT_THAT(CompileExecutable(std::move(module)), + tsl::testing::IsOkAndHolds(::testing::NotNull())); +} + +TEST_P(NVPTXCompilationTests, CompareBinaryOutput) { + std::string_view name = std::get<0>(GetParam()); + std::string_view hlo_text = GetHlo(name); + auto compile = [&](PtxCompilationMethod compilation_method, + PtxLinkingMethod linking_method) { + auto module = ParseAndReturnVerifiedModule(hlo_text).value(); + + HloModuleConfig hlo_module_config = module->config(); + DebugOptions debug_options = hlo_module_config.debug_options(); + SetDebugOptionsFromPtxSettings(&debug_options, compilation_method, + linking_method); + hlo_module_config.set_debug_options(debug_options); + module->set_config(hlo_module_config); + + return CompileExecutable(std::move(module)); + }; + + PtxCompilationMethod compilation_method = std::get<1>(GetParam()); + PtxLinkingMethod linking_method = std::get<2>(GetParam()); + absl::StatusOr> executable = + compile(compilation_method, linking_method); + + // Non parallel compilation (PtxLinkingMethod::kNone) generates slightly + // different code (different register assignment, different instruction + // ordering). Ideally we would do a fuzzy match, but for now let's just not + // compare between parallel and non-parallel compilation. + const PtxLinkingMethod reference_linking_method = + linking_method == PtxLinkingMethod::kNone ? PtxLinkingMethod::kNone + : PtxLinkingMethod::kNvLink; + absl::StatusOr> reference = + compile(PtxCompilationMethod::kPtxas, reference_linking_method); + + EXPECT_THAT(executable, tsl::testing::IsOkAndHolds(::testing::NotNull())); + EXPECT_THAT(reference, tsl::testing::IsOkAndHolds(::testing::NotNull())); + + absl::Span executable_binary = + static_cast(executable.value().get())->binary(); + absl::Span reference_binary = + static_cast(reference.value().get())->binary(); + + if (executable_binary != reference_binary) { + std::string test_name = + GenerateParametrizedTestname(name, compilation_method, linking_method); + DumpArtifactIfEnabled(absl::StrCat(test_name, "_executable.bin"), + executable_binary); + DumpArtifactIfEnabled(absl::StrCat(test_name, "_reference.bin"), + reference_binary); + } + + auto get_text_sections = [&](absl::Span binary) + -> absl::StatusOr> { + auto buffer = llvm::MemoryBuffer::getMemBuffer( + llvm::StringRef(reinterpret_cast(binary.data()), + binary.size()), + /*BufferName=*/"", /*RequiresNullTerminator=*/false); + auto object_file = + llvm::object::ObjectFile::createObjectFile(buffer->getMemBufferRef()); + + if (!object_file) { + return absl::InternalError(llvm::toString(object_file.takeError())); + } + + auto executable_elf_object_file = + llvm::dyn_cast( + object_file.get().get()); + + if (!executable_elf_object_file) { + return absl::InternalError( + "Generated executable binary is not a 64bit ELF file."); + } + + absl::btree_map text_sections; + + for (const auto& section : executable_elf_object_file->sections()) { + if (absl::StartsWith(section.getName().get().str(), ".text")) { + text_sections[section.getName().get().str()] = + section.getContents().get().str(); + } + } + + return text_sections; + }; + + TF_ASSERT_OK_AND_ASSIGN(auto executable_text_sections, + get_text_sections(executable_binary)); + TF_ASSERT_OK_AND_ASSIGN(auto reference_text_sections, + get_text_sections(reference_binary)); + + EXPECT_THAT(executable_text_sections, ::testing::Eq(reference_text_sections)); +} + +INSTANTIATE_TEST_SUITE_P( + NVPTXCompilationTest, NVPTXCompilationTests, + ::testing::Combine( + ::testing::Values("simple", "parallel_compilation", "requires_sm90a"), + ::testing::Values(PtxCompilationMethod::kNvPtxCompiler, + PtxCompilationMethod::kPtxas), + ::testing::Values(PtxLinkingMethod::kNone, PtxLinkingMethod::kNvLink, + PtxLinkingMethod::kDriver)), + [](const ::testing::TestParamInfo>& info) { + return GenerateParametrizedTestname(std::get<0>(info.param), + std::get<1>(info.param), + std::get<2>(info.param)); + }); + +} // namespace +} // namespace xla::gpu diff --git a/xla/stream_executor/cuda/BUILD b/xla/stream_executor/cuda/BUILD index c5bb43d587b584..917785c848ecc4 100644 --- a/xla/stream_executor/cuda/BUILD +++ b/xla/stream_executor/cuda/BUILD @@ -714,6 +714,7 @@ cuda_only_cc_library( ":ptx_compiler_support", "//xla:status_macros", "//xla:util", + "//xla/stream_executor:device_description", "//xla/stream_executor/gpu:gpu_asm_opts", "//xla/stream_executor/gpu:gpu_driver_header", "//xla/stream_executor/gpu:gpu_types_header", @@ -852,3 +853,15 @@ cc_library( srcs = ["cudnn_frontend_helpers.cc"], hdrs = ["cudnn_frontend_helpers.h"], ) + +cc_library( + name = "ptx_compilation_method", + hdrs = ["ptx_compilation_method.h"], + deps = ["@com_google_absl//absl/strings"], +) + +cc_library( + name = "ptx_linking_method", + hdrs = ["ptx_linking_method.h"], + deps = ["@com_google_absl//absl/strings"], +) diff --git a/xla/stream_executor/cuda/cuda_asm_compiler.cc b/xla/stream_executor/cuda/cuda_asm_compiler.cc index 489fedd9142c6a..caea22789a151d 100644 --- a/xla/stream_executor/cuda/cuda_asm_compiler.cc +++ b/xla/stream_executor/cuda/cuda_asm_compiler.cc @@ -50,6 +50,7 @@ limitations under the License. #include "xla/status_macros.h" #include "xla/stream_executor/cuda/ptx_compiler.h" #include "xla/stream_executor/cuda/ptx_compiler_support.h" +#include "xla/stream_executor/device_description.h" #include "xla/stream_executor/gpu/gpu_asm_opts.h" #include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_types.h" @@ -528,6 +529,7 @@ absl::StatusOr GetNvLinkVersion( } absl::StatusOr> LinkUsingNvlink( + stream_executor::CudaComputeCapability cc, std::string_view preferred_cuda_dir, gpu::GpuContext* context, std::vector images) { LOG_FIRST_N(INFO, 1) << "Using nvlink for parallel linking"; @@ -562,17 +564,10 @@ absl::StatusOr> LinkUsingNvlink( // produce TF error. tsl::Env::Default()->DeleteFile(output_path).IgnoreError(); }; - int cc_major; - int cc_minor; - { - TF_ASSIGN_OR_RETURN(auto cu_device, - gpu::GpuDriver::DeviceFromContext(context)); - TF_RETURN_IF_ERROR( - gpu::GpuDriver::GetComputeCapability(&cc_major, &cc_minor, cu_device)); - } std::vector args; args.push_back(bin_path); - args.push_back(absl::StrCat("-arch=sm_", cc_major, cc_minor)); + std::string_view extension = (cc.major == 9 && cc.minor == 0) ? "a" : ""; + args.push_back(absl::StrCat("-arch=sm_", cc.major, cc.minor, extension)); for (int i = 0; i < images.size(); i++) { args.push_back(temp_files[i]); } @@ -611,11 +606,32 @@ absl::StatusOr> LinkUsingNvlink( } absl::StatusOr> LinkGpuAsm( - gpu::GpuContext* context, std::vector images) { + stream_executor::CudaComputeCapability cc, gpu::GpuContext* context, + std::vector images) { gpu::ScopedActivateContext activation(context); CUlinkState link_state; - RETURN_IF_CUDA_ERROR(cuLinkCreate(0, nullptr, nullptr, &link_state)); + CUjit_option options[] = {CU_JIT_TARGET}; + CUjit_target target = static_cast(cc.major * 10 + cc.minor); +#if CUDA_VERSION >= 12000 + // Even though CUDA 11.8 has Hopper support, SM 9.0a and most Hopper features + // (WGMMA, TMA, and more) are only supported in CUDA 12+. + if (cc.major == 9 && cc.minor == 0) { + target = + static_cast(target + CU_COMPUTE_ACCELERATED_TARGET_BASE); + } +#endif + void* option_values[] = { + // We first cast to an integer type the same size as a pointer, and then + // we reinterpret that integer as a pointer. + reinterpret_cast(static_cast(target))}; + + // Both arrays must have the same number of elements. + static_assert(sizeof(options) / sizeof(options[0]) == + sizeof(option_values) / sizeof(option_values[0])); + + RETURN_IF_CUDA_ERROR(cuLinkCreate(sizeof(options) / sizeof(options[0]), + options, option_values, &link_state)); for (auto& image : images) { auto status = cuLinkAddData(link_state, CU_JIT_INPUT_CUBIN, static_cast(image.bytes.data()), diff --git a/xla/stream_executor/cuda/cuda_asm_compiler.h b/xla/stream_executor/cuda/cuda_asm_compiler.h index d906f927f99bea..1669d04231295a 100644 --- a/xla/stream_executor/cuda/cuda_asm_compiler.h +++ b/xla/stream_executor/cuda/cuda_asm_compiler.h @@ -25,6 +25,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/stream_executor/cuda/cuda_driver.h" +#include "xla/stream_executor/device_description.h" #include "xla/stream_executor/gpu/gpu_asm_opts.h" namespace stream_executor { @@ -72,9 +73,11 @@ absl::StatusOr> BundleGpuAsm( // Links multiple relocatable GPU images (e.g. results of ptxas -c) into a // single image. absl::StatusOr> LinkGpuAsm( - gpu::GpuContext* context, std::vector images); + stream_executor::CudaComputeCapability cc, gpu::GpuContext* context, + std::vector images); absl::StatusOr> LinkUsingNvlink( + stream_executor::CudaComputeCapability cc, std::string_view preferred_cuda_dir, gpu::GpuContext* context, std::vector images); diff --git a/xla/stream_executor/cuda/ptx_compilation_method.h b/xla/stream_executor/cuda/ptx_compilation_method.h new file mode 100644 index 00000000000000..d6e28e96b69d67 --- /dev/null +++ b/xla/stream_executor/cuda/ptx_compilation_method.h @@ -0,0 +1,49 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_CUDA_PTX_COMPILATION_METHOD_H_ +#define XLA_STREAM_EXECUTOR_CUDA_PTX_COMPILATION_METHOD_H_ + +#include + +#include "absl/strings/str_cat.h" +namespace stream_executor { + +enum class PtxCompilationMethod { + kNvPtxCompiler, + kPtxas, +}; + +template +static void AbslStringify(Sink& sink, + const PtxCompilationMethod& compilation_method) { + switch (compilation_method) { + case PtxCompilationMethod::kNvPtxCompiler: + sink.Append("NvPtxCompiler"); + break; + case PtxCompilationMethod::kPtxas: + sink.Append("Ptxas"); + break; + } +} + +inline std::ostream& operator<<(std::ostream& os, + const PtxCompilationMethod& method) { + return os << absl::StrCat(method); +} + +} // namespace stream_executor + +#endif // XLA_STREAM_EXECUTOR_CUDA_PTX_COMPILATION_METHOD_H_ diff --git a/xla/stream_executor/cuda/ptx_linking_method.h b/xla/stream_executor/cuda/ptx_linking_method.h new file mode 100644 index 00000000000000..56dcdf1fa53d54 --- /dev/null +++ b/xla/stream_executor/cuda/ptx_linking_method.h @@ -0,0 +1,53 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_CUDA_PTX_LINKING_METHOD_H_ +#define XLA_STREAM_EXECUTOR_CUDA_PTX_LINKING_METHOD_H_ + +#include + +#include "absl/strings/str_cat.h" + +namespace stream_executor { + +enum class PtxLinkingMethod { + kNone, + kNvLink, + kDriver, +}; + +template +void AbslStringify(Sink& sink, const PtxLinkingMethod& method) { + switch (method) { + case PtxLinkingMethod::kNvLink: + sink.Append("NvLink"); + break; + case PtxLinkingMethod::kDriver: + sink.Append("Driver"); + break; + case PtxLinkingMethod::kNone: + sink.Append("None"); + break; + } +} + +inline std::ostream& operator<<(std::ostream& os, + const PtxLinkingMethod& method) { + return os << absl::StrCat(method); +} + +} // namespace stream_executor + +#endif // XLA_STREAM_EXECUTOR_CUDA_PTX_LINKING_METHOD_H_ From 990138e5919cad96f508382bc9b28f346a1599dc Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 19 Jul 2024 01:46:13 -0700 Subject: [PATCH 006/376] Automated Code Change PiperOrigin-RevId: 653923788 --- xla/service/BUILD | 70 ++++++++++++++++++- xla/service/human_readable_profile_builder.cc | 2 + xla/service/indexed_array_analysis.cc | 13 ++++ xla/service/indexed_array_analysis.h | 11 +++ xla/service/indexed_array_analysis_test.cc | 5 +- xla/service/instruction_fusion.h | 6 ++ xla/service/instruction_fusion_test.cc | 8 +++ xla/service/instruction_hoister.cc | 12 ++++ xla/service/instruction_hoister.h | 3 + xla/service/latency_hiding_scheduler.cc | 7 ++ xla/service/latency_hiding_scheduler.h | 12 ++++ xla/service/latency_hiding_scheduler_test.cc | 11 +++ xla/service/layout_assignment.h | 1 + xla/service/layout_assignment_test.cc | 10 ++- xla/service/layout_normalization.cc | 13 +++- xla/service/layout_normalization.h | 2 + xla/service/llvm_compiler.cc | 6 ++ xla/service/llvm_compiler.h | 5 ++ 18 files changed, 191 insertions(+), 6 deletions(-) diff --git a/xla/service/BUILD b/xla/service/BUILD index 90a77260115595..f154fbf2e98f17 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -1396,6 +1396,9 @@ cc_library( ":hlo_buffer", ":hlo_cost_analysis", ":hlo_pass", + ":hlo_value", + "//xla:debug_options_flags", + "//xla:shape_util", "//xla:util", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", @@ -1411,6 +1414,8 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", ], ) @@ -1419,11 +1424,21 @@ xla_cc_test( srcs = ["latency_hiding_scheduler_test.cc"], deps = [ ":async_collective_creator", + ":hlo_cost_analysis", ":latency_hiding_scheduler", + "//xla:shape_util", + "//xla:util", "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:statusor", ], ) @@ -1719,9 +1734,17 @@ cc_library( local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ ":compiler", + ":executable", + ":stream_pool", + "//xla:executable_run_options", + "//xla/hlo/ir:hlo", + "//xla/hlo/ir:hlo_module_group", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", "@llvm-project//llvm:Core", "@tsl//tsl/platform:denormal", + "@tsl//tsl/platform:statusor", "@tsl//tsl/profiler/lib:scoped_annotation", ], ) @@ -2165,12 +2188,14 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:macros", "@tsl//tsl/platform:status", ] + if_google(["@com_google_absl//absl/types:source_location"]), ) @@ -2181,9 +2206,13 @@ xla_cc_test( deps = [ ":hlo_parser", ":instruction_fusion", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest_main", ], ) @@ -4283,6 +4312,8 @@ cc_library( "//xla:metric_table_report", "//xla:types", "//xla:util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@tsl//tsl/platform:logging", @@ -4933,6 +4964,7 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -5327,9 +5359,11 @@ xla_cc_test( ":computation_layout", ":hlo_parser", ":layout_assignment", + ":logical_buffer", ":pattern_matcher", ":pattern_matcher_gmock", "//xla:literal", + "//xla:literal_util", "//xla:shape_layout", "//xla:shape_util", "//xla:test", @@ -5338,11 +5372,15 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", - "//xla/tests:test_utils", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:errors", "@tsl//tsl/platform:status", + "@tsl//tsl/platform:statusor", ], ) @@ -6746,14 +6784,24 @@ cc_library( hdrs = ["indexed_array_analysis.h"], deps = [ ":hlo_pass", + "//xla:literal", + "//xla:shape_util", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", ], ) @@ -6762,10 +6810,13 @@ xla_cc_test( srcs = ["indexed_array_analysis_test.cc"], deps = [ ":indexed_array_analysis", + "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", - "//xla/tests:test_utils", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/log", "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:statusor", ], ) @@ -7668,15 +7719,22 @@ cc_library( ":hlo_creation_utils", ":hlo_pass", ":shape_inference", + "//xla:literal", "//xla:permutation_util", "//xla:shape_util", + "//xla:status_macros", "//xla:util", - "//xla:window_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", ], ) @@ -7687,6 +7745,12 @@ cc_library( deps = [ ":hlo_pass", "//xla/hlo/ir:hlo", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:status", ], ) diff --git a/xla/service/human_readable_profile_builder.cc b/xla/service/human_readable_profile_builder.cc index 3e46a3e775415f..ff812c58597441 100644 --- a/xla/service/human_readable_profile_builder.cc +++ b/xla/service/human_readable_profile_builder.cc @@ -15,6 +15,8 @@ limitations under the License. #include "xla/service/human_readable_profile_builder.h" +#include "absl/algorithm/container.h" +#include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "xla/metric_table_report.h" diff --git a/xla/service/indexed_array_analysis.cc b/xla/service/indexed_array_analysis.cc index 0211baf8d65654..a9a5004b011ecc 100644 --- a/xla/service/indexed_array_analysis.cc +++ b/xla/service/indexed_array_analysis.cc @@ -25,11 +25,24 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/hlo/evaluator/hlo_evaluator.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal.h" #include "xla/map_util.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { diff --git a/xla/service/indexed_array_analysis.h b/xla/service/indexed_array_analysis.h index ce56c6b287246c..634c24f2068396 100644 --- a/xla/service/indexed_array_analysis.h +++ b/xla/service/indexed_array_analysis.h @@ -19,9 +19,20 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal.h" #include "xla/service/hlo_pass_interface.h" +#include "xla/shape.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" namespace xla { diff --git a/xla/service/indexed_array_analysis_test.cc b/xla/service/indexed_array_analysis_test.cc index d711c8f1fdc6bf..7438ac5de0bee0 100644 --- a/xla/service/indexed_array_analysis_test.cc +++ b/xla/service/indexed_array_analysis_test.cc @@ -15,9 +15,12 @@ limitations under the License. #include "xla/service/indexed_array_analysis.h" +#include +#include "absl/log/log.h" #include "absl/strings/ascii.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/test_utils.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { diff --git a/xla/service/instruction_fusion.h b/xla/service/instruction_fusion.h index 8ace349f141db4..571a9f998df093 100644 --- a/xla/service/instruction_fusion.h +++ b/xla/service/instruction_fusion.h @@ -23,6 +23,12 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/service/hlo_module_config.h" +#include "tsl/platform/macros.h" // The source_location.h is not available in open source. #if defined(PLATFORM_GOOGLE) #include "absl/types/source_location.h" diff --git a/xla/service/instruction_fusion_test.cc b/xla/service/instruction_fusion_test.cc index 98ce0c307ff40e..db6c3244c3932f 100644 --- a/xla/service/instruction_fusion_test.cc +++ b/xla/service/instruction_fusion_test.cc @@ -15,9 +15,17 @@ limitations under the License. #include "xla/service/instruction_fusion.h" +#include +#include +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/service/hlo_parser.h" +#include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" +#include "xla/xla_data.pb.h" namespace xla { diff --git a/xla/service/instruction_hoister.cc b/xla/service/instruction_hoister.cc index 58e27f3dbf1183..d706a873429b55 100644 --- a/xla/service/instruction_hoister.cc +++ b/xla/service/instruction_hoister.cc @@ -15,6 +15,18 @@ limitations under the License. #include "xla/service/instruction_hoister.h" +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_schedule.h" +#include "tsl/platform/status.h" + namespace xla { namespace { diff --git a/xla/service/instruction_hoister.h b/xla/service/instruction_hoister.h index 0f0f1683e314da..e3db5f37d9e87c 100644 --- a/xla/service/instruction_hoister.h +++ b/xla/service/instruction_hoister.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef XLA_SERVICE_INSTRUCTION_HOISTER_H_ #define XLA_SERVICE_INSTRUCTION_HOISTER_H_ +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" diff --git a/xla/service/latency_hiding_scheduler.cc b/xla/service/latency_hiding_scheduler.cc index 33eeece389ba31..8b7b43322e1dc3 100644 --- a/xla/service/latency_hiding_scheduler.cc +++ b/xla/service/latency_hiding_scheduler.cc @@ -39,6 +39,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -49,8 +50,14 @@ limitations under the License. #include "xla/service/dump.h" #include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_buffer.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/service/hlo_value.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/util.h" #include "xla/xla.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { diff --git a/xla/service/latency_hiding_scheduler.h b/xla/service/latency_hiding_scheduler.h index 64903272d957af..b04c746280833e 100644 --- a/xla/service/latency_hiding_scheduler.h +++ b/xla/service/latency_hiding_scheduler.h @@ -26,14 +26,26 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_schedule.h" +#include "xla/map_util.h" #include "xla/service/hlo_alias_analysis.h" +#include "xla/service/hlo_buffer.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_pass_interface.h" +#include "xla/service/hlo_value.h" +#include "xla/shape_util.h" #include "xla/xla.pb.h" namespace xla { diff --git a/xla/service/latency_hiding_scheduler_test.cc b/xla/service/latency_hiding_scheduler_test.cc index ca228544fcbdd3..f73f6470519510 100644 --- a/xla/service/latency_hiding_scheduler_test.cc +++ b/xla/service/latency_hiding_scheduler_test.cc @@ -26,12 +26,23 @@ limitations under the License. #include #include +#include #include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/service/async_collective_creator.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" +#include "xla/util.h" +#include "tsl/platform/statusor.h" namespace xla { diff --git a/xla/service/layout_assignment.h b/xla/service/layout_assignment.h index 693c80c817bda8..ba12a2a325bc99 100644 --- a/xla/service/layout_assignment.h +++ b/xla/service/layout_assignment.h @@ -29,6 +29,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/container/node_hash_map.h" +#include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" diff --git a/xla/service/layout_assignment_test.cc b/xla/service/layout_assignment_test.cc index 551e69aa29551a..139124bd6c09bb 100644 --- a/xla/service/layout_assignment_test.cc +++ b/xla/service/layout_assignment_test.cc @@ -20,28 +20,36 @@ limitations under the License. #include #include +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/layout.h" #include "xla/layout_util.h" #include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/service/algebraic_simplifier.h" #include "xla/service/computation_layout.h" #include "xla/service/hlo_parser.h" +#include "xla/service/logical_buffer.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" +#include "xla/shape.h" #include "xla/shape_layout.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/test_utils.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/errors.h" #include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { diff --git a/xla/service/layout_normalization.cc b/xla/service/layout_normalization.cc index 65dfaeef4517d6..2dce620c81b267 100644 --- a/xla/service/layout_normalization.cc +++ b/xla/service/layout_normalization.cc @@ -22,20 +22,31 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/layout.h" #include "xla/layout_util.h" +#include "xla/literal.h" #include "xla/permutation_util.h" #include "xla/service/hlo_creation_utils.h" #include "xla/service/shape_inference.h" #include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" #include "xla/util.h" -#include "xla/window_util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { diff --git a/xla/service/layout_normalization.h b/xla/service/layout_normalization.h index 770c1657b5732e..b847e41e0e94cc 100644 --- a/xla/service/layout_normalization.h +++ b/xla/service/layout_normalization.h @@ -20,8 +20,10 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" diff --git a/xla/service/llvm_compiler.cc b/xla/service/llvm_compiler.cc index 02afa91c5d9ea5..fac84fbbfff2f3 100644 --- a/xla/service/llvm_compiler.cc +++ b/xla/service/llvm_compiler.cc @@ -20,7 +20,13 @@ limitations under the License. #include #include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_module_group.h" +#include "xla/service/executable.h" +#include "xla/service/stream_pool.h" #include "tsl/platform/denormal.h" +#include "tsl/platform/statusor.h" #include "tsl/profiler/lib/scoped_annotation.h" #ifdef __FAST_MATH__ diff --git a/xla/service/llvm_compiler.h b/xla/service/llvm_compiler.h index 4bd0e8d7c9d24f..ceebd48965cd19 100644 --- a/xla/service/llvm_compiler.h +++ b/xla/service/llvm_compiler.h @@ -16,8 +16,13 @@ limitations under the License. #ifndef XLA_SERVICE_LLVM_COMPILER_H_ #define XLA_SERVICE_LLVM_COMPILER_H_ +#include "absl/log/check.h" +#include "absl/status/statusor.h" #include "llvm/IR/Module.h" +#include "xla/executable_run_options.h" +#include "xla/hlo/ir/hlo_module_group.h" #include "xla/service/compiler.h" +#include "xla/service/executable.h" namespace xla { From cf312aaccf2870e43082769a9caa7823faa398b2 Mon Sep 17 00:00:00 2001 From: pemeliya <141146080+pemeliya@users.noreply.github.com> Date: Fri, 19 Jul 2024 02:23:56 -0700 Subject: [PATCH 007/376] PR #13425: [ROCM] gemm precision settings for autotuner Imported from GitHub PR https://github.com/openxla/xla/pull/13425 Here we add a new flag **xla_gpu_autotune_gemm_rtol** which controls the relative precision used by the BufferComparator (defaults to **0.1**). Also I added one more "paranoid" level 5 for **xla_gpu_autotune_level** which forces the autotuner to discard solutions with accuracy problems. Long time I was under impression that the autotuner already does it, however this is not the case as outlined [here](https://github.com/ROCm/xla/blob/6301f04c50c7637a65b3e0c6f40be628aa00947f/xla/service/gpu/stream_executor_util.cc#L640). BufferComparator just prints out the error message but **keeps wrong solutions** as possible candidates which could lead to a great confusion. So, the autotune level 5 is supposed to discard solutions with accuracy problems. Besides, I also did some small refactoring on BufferComparator to simplify the source code and added **verbose** flag in order to mute error messages if needed. @xla-rotation: could you please have a look? Copybara import of the project: -- cab53b672f9546fe4b811d09cf998b105dc8be01 by Pavel Emeliyanenko : added precision settings for autotuner and buffer_comparator small refactoring, added verbose flag Merging this change closes #13425 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/13425 from ROCm:ci_gemm_autotuner_precision_check cab53b672f9546fe4b811d09cf998b105dc8be01 PiperOrigin-RevId: 653935717 --- xla/debug_options_flags.cc | 15 +++- xla/service/gpu/autotuner_util.h | 1 + xla/service/gpu/buffer_comparator.cc | 13 ++-- xla/service/gpu/buffer_comparator.h | 8 +-- xla/service/gpu/buffer_comparator_test.cc | 5 +- xla/service/gpu/conv_algorithm_picker.cc | 7 +- xla/service/gpu/gemm_algorithm_picker.cc | 68 +++++++++++++------ xla/service/gpu/gemm_algorithm_picker.h | 5 ++ xla/service/gpu/gemm_algorithm_picker_test.cc | 62 +++++++++++++++++ xla/service/gpu/gemm_fusion_autotuner.cc | 2 +- .../gpu/triton_fusion_numerics_verifier.cc | 3 +- xla/xla.proto | 5 +- 12 files changed, 156 insertions(+), 38 deletions(-) diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc index d49eb4b7b110e2..05a8c216e613bb 100644 --- a/xla/debug_options_flags.cc +++ b/xla/debug_options_flags.cc @@ -279,6 +279,8 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_per_fusion_autotune_cache_dir(""); + opts.set_xla_gpu_autotune_gemm_rtol(0.1f); + return opts; } @@ -835,13 +837,24 @@ void MakeDebugOptionsFlags(std::vector* flag_list, int32_setter_for(&DebugOptions::set_xla_gpu_autotune_level), debug_options->xla_gpu_autotune_level(), "Set GEMM and Convolution auto-tuning level. 0 = off; 1 = on; 2 = " - "on+init; 3 = on+init+reinit; 4 = on+init+reinit+check.")); + "on+init; 3 = on+init+reinit; 4 = on+init+reinit+check; " + "5 = on+init+reinit+check and skip WRONG_RESULT solutions. See also " + "the related flag xla_gpu_autotune_gemm_rtol. Remark that, setting the " + "level to 5 only makes sense if you are sure that the reference (first " + "in the list) solution is numerically CORRECT. Otherwise, the autotuner " + "might discard many other correct solutions based on the failed " + "BufferComparator test.")); flag_list->push_back(tsl::Flag( "xla_gpu_autotune_max_solutions", int64_setter_for(&DebugOptions::set_xla_gpu_autotune_max_solutions), debug_options->xla_gpu_autotune_max_solutions(), "Maximal number of GEMM solutions to consider for autotuning: 0 means " "consider all solutions returned by the GEMM library.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_autotune_gemm_rtol", + float_setter_for(&DebugOptions::set_xla_gpu_autotune_gemm_rtol), + debug_options->xla_gpu_autotune_gemm_rtol(), + "Relative precision for comparing GEMM solutions vs the reference one")); flag_list->push_back(tsl::Flag( "xla_force_host_platform_device_count", int32_setter_for(&DebugOptions::set_xla_force_host_platform_device_count), diff --git a/xla/service/gpu/autotuner_util.h b/xla/service/gpu/autotuner_util.h index bba517d16566f0..2e8e2c8f03a344 100644 --- a/xla/service/gpu/autotuner_util.h +++ b/xla/service/gpu/autotuner_util.h @@ -101,6 +101,7 @@ class AutotuneConfig { bool should_init_buffers() const { return autotune_level_ >= 2; } bool should_reinit_output_buffer() const { return autotune_level_ >= 3; } bool should_check_correctness() const { return autotune_level_ >= 4; } + bool should_skip_wrong_results() const { return autotune_level_ >= 5; } bool should_crash_on_check_failure() const { return should_crash_on_check_failure_; } diff --git a/xla/service/gpu/buffer_comparator.cc b/xla/service/gpu/buffer_comparator.cc index 74abaec95a0fb4..8d8a7cac0595e5 100644 --- a/xla/service/gpu/buffer_comparator.cc +++ b/xla/service/gpu/buffer_comparator.cc @@ -49,6 +49,7 @@ using ComparisonKernelT = struct ComparisonParams { double relative_tol = 0.1; + bool verbose = true; const Shape* shape = nullptr; se::Stream* stream = nullptr; se::DeviceMemoryBase current{}; @@ -129,6 +130,7 @@ static absl::StatusOr HostCompare(const ComparisonParams& params) { return a; }; int differences_seen = 0; + for (int64_t i = 0; i < n && differences_seen < 10; ++i) { auto current_value = static_cast(host_current[i]); auto expected_value = static_cast(host_expected[i]); @@ -150,6 +152,7 @@ static absl::StatusOr HostCompare(const ComparisonParams& params) { std::abs(expected_value_canonical)) + 1) < params.relative_tol)) { + if (!params.verbose) return false; // Return immediately if not verbose. ++differences_seen; LOG(ERROR) << "Difference at " << i << ": " << current_value << ", expected " << expected_value; @@ -180,7 +183,8 @@ static absl::StatusOr CompareEqualParameterized( absl::StatusOr BufferComparator::CompareEqual( se::Stream* stream, se::DeviceMemoryBase current, se::DeviceMemoryBase expected) const { - ComparisonParams params{relative_tol_, &shape_, stream, current, expected}; + ComparisonParams params{relative_tol_, verbose_, &shape_, + stream, current, expected}; switch (shape_.element_type()) { #if GOOGLE_CUDA // not available for ROCm yet.. @@ -226,10 +230,9 @@ absl::StatusOr BufferComparator::CompareEqual( } } -BufferComparator::BufferComparator(const Shape& shape, - const HloModuleConfig& config, - double tolerance) - : shape_(shape), config_(config), relative_tol_(tolerance) { +BufferComparator::BufferComparator(const Shape& shape, double tolerance, + bool verbose) + : shape_(shape), relative_tol_(tolerance), verbose_(verbose) { // Normalize complex shapes: since we treat the passed array as a contiguous // storage it does not matter which dimension are we doubling. auto double_dim_size = [&]() { diff --git a/xla/service/gpu/buffer_comparator.h b/xla/service/gpu/buffer_comparator.h index b275524b6914ca..8e1019285e161e 100644 --- a/xla/service/gpu/buffer_comparator.h +++ b/xla/service/gpu/buffer_comparator.h @@ -34,8 +34,8 @@ class BufferComparator { BufferComparator(const BufferComparator&) = delete; BufferComparator(BufferComparator&&) = default; - BufferComparator(const Shape& shape, const HloModuleConfig& config, - double tolerance = 0.1); + explicit BufferComparator(const Shape& shape, double tolerance = 0.1, + bool verbose = true); // Returns true if the two buffers compare equal. The definition of "equal" // is: @@ -51,8 +51,8 @@ class BufferComparator { se::DeviceMemoryBase expected) const; private: Shape shape_; - HloModuleConfig config_; - double relative_tol_; + double relative_tol_; // relative tolerance for comparison + bool verbose_; // whether to print out error message on mismatch }; namespace buffer_comparator { diff --git a/xla/service/gpu/buffer_comparator_test.cc b/xla/service/gpu/buffer_comparator_test.cc index 05ab7e1009b396..cbef669abab7b3 100644 --- a/xla/service/gpu/buffer_comparator_test.cc +++ b/xla/service/gpu/buffer_comparator_test.cc @@ -75,7 +75,7 @@ class BufferComparatorTest : public testing::Test { ShapeUtil::MakeShape( primitive_util::NativeToPrimitiveType(), {static_cast(current.size())}), - HloModuleConfig(), tolerance); + tolerance); return comparator .CompareEqual(stream.get(), current_buffer.memory(), expected_buffer.memory()) @@ -394,8 +394,7 @@ TEST_F(BufferComparatorTest, BF16) { stream_exec_->AllocateArray(element_count)); InitializeBuffer(stream.get(), BF16, &rng_state, rhs.memory()); - BufferComparator comparator(ShapeUtil::MakeShape(BF16, {element_count}), - HloModuleConfig()); + BufferComparator comparator(ShapeUtil::MakeShape(BF16, {element_count})); EXPECT_FALSE(comparator.CompareEqual(stream.get(), lhs.memory(), rhs.memory()) .value()); } diff --git a/xla/service/gpu/conv_algorithm_picker.cc b/xla/service/gpu/conv_algorithm_picker.cc index 5fc5746b32564c..bf6681f65de35e 100644 --- a/xla/service/gpu/conv_algorithm_picker.cc +++ b/xla/service/gpu/conv_algorithm_picker.cc @@ -732,8 +732,11 @@ absl::StatusOr GpuConvAlgorithmPicker::AutotuneOneConvRunner( if (reference_result->has_value()) { XLA_SCOPED_LOGGING_TIMER_LEVEL("BufferComparator::CompareEqual", 2); + + const DebugOptions& debug_options = + runtime_arguments.hlo_module_config.debug_options(); BufferComparator comparator(runtime_arguments.rz_buffers.output_shape(), - runtime_arguments.hlo_module_config); + debug_options.xla_gpu_autotune_gemm_rtol()); for (int i = 0; i < result_buffers.size(); ++i) { absl::StatusOr compare_result = comparator.CompareEqual( stream, (*reference_result)->buffers[i], result_buffers[i]); @@ -747,8 +750,6 @@ absl::StatusOr GpuConvAlgorithmPicker::AutotuneOneConvRunner( // Possibly OOM. Propagate the error. return compare_result.status(); } - const DebugOptions& debug_options = - runtime_arguments.hlo_module_config.debug_options(); CHECK(!debug_options.xla_gpu_crash_on_verification_failures()); } else if (!compare_result.value()) { LOG(ERROR) diff --git a/xla/service/gpu/gemm_algorithm_picker.cc b/xla/service/gpu/gemm_algorithm_picker.cc index 07a83fda57bbbd..aac4aa2f21eb20 100644 --- a/xla/service/gpu/gemm_algorithm_picker.cc +++ b/xla/service/gpu/gemm_algorithm_picker.cc @@ -92,13 +92,17 @@ class GemmAutotuner { se::Stream* stream_ = nullptr; bool deterministic_ops_ = false; size_t solutions_limit_ = 0; + size_t num_algorithms_left_ = 0; public: explicit GemmAutotuner(const AutotuneConfig& autotune_config) : autotune_config_(autotune_config) {} + size_t num_algorithms_left() const { return num_algorithms_left_; } + absl::StatusOr operator()(const HloInstruction* gemm, const AutotuneCacheKey& key) { + num_algorithms_left_ = 0; if (autotune_config_.IsDeviceless()) { // Return empty result, will tune at runtime. return AutotuneResult{}; @@ -274,7 +278,11 @@ class GemmAutotuner { ShapeUtil::ByteSizeOf(output_shape))); } - BufferComparator comparator(output_shape, hlo_module_config); + // Do not print error messages if should_skip_wrong_results() is ON. + BufferComparator comparator( + output_shape, + hlo_module_config.debug_options().xla_gpu_autotune_gemm_rtol(), + /* verbose */ !autotune_config_.should_skip_wrong_results()); std::vector results; results.reserve(algorithms.size()); std::optional reference_algorithm; @@ -307,6 +315,7 @@ class GemmAutotuner { absl::Milliseconds(profile_result.elapsed_time_in_ms())); if (!autotune_config_.should_check_correctness()) { + num_algorithms_left_++; continue; } TF_ASSIGN_OR_RETURN( @@ -322,25 +331,35 @@ class GemmAutotuner { continue; } + num_algorithms_left_++; if (!reference_algorithm) { TF_RETURN_IF_ERROR(stream_->Memcpy(&reference_buffer, OutputBuffer(), OutputBuffer().size())); reference_algorithm = profile_result.algorithm(); - } else { - // Perform the comparison. - TF_ASSIGN_OR_RETURN( - bool outputs_match, - comparator.CompareEqual(stream_, /*current=*/OutputBuffer(), - /*expected=*/reference_buffer)); - if (!outputs_match) { - LOG(ERROR) << "Results mismatch between different GEMM algorithms. " - << "This is likely a bug/unexpected loss of precision."; - CHECK(!autotune_config_.should_crash_on_check_failure()); - - result.mutable_failure()->set_kind(AutotuneResult::WRONG_RESULT); - result.mutable_failure()->mutable_reference_gemm()->set_algorithm( - *reference_algorithm); + continue; + } + // Perform the comparison versus the reference algorithm. + TF_ASSIGN_OR_RETURN( + bool outputs_match, + comparator.CompareEqual(stream_, /*current=*/OutputBuffer(), + /*expected=*/reference_buffer)); + if (!outputs_match) { + LOG(ERROR) << "Results mismatch between different GEMM algorithms. " + << "This is likely a bug/unexpected loss of precision."; + CHECK(!autotune_config_.should_crash_on_check_failure()); + + // By default, autotuner does NOT really skip wrong results, but + // merely prints out the above error message: this may lead to a + // great confusion. When should_skip_wrong_results() is set to true, + // solutions with accuracy problems will be disqualified. + auto kind = AutotuneResult::WRONG_RESULT; + if (autotune_config_.should_skip_wrong_results()) { + kind = AutotuneResult::DISQUALIFIED; + num_algorithms_left_--; // Decrement again since we disqualified it. } + result.mutable_failure()->set_kind(kind); + result.mutable_failure()->mutable_reference_gemm()->set_algorithm( + *reference_algorithm); } } // for algorithms @@ -373,13 +392,15 @@ class GemmAutotuner { // Do Gemm Autotune without stream executor. Use results from autotune cache // only. absl::StatusOr RunOnInstruction(HloInstruction* gemm, - const AutotuneConfig& config) { + const AutotuneConfig& config, + size_t* num_algorithms_left) { VLOG(3) << "Loading the autotune result of GemmThunk " << gemm->ToString(); GpuBackendConfig gpu_config = gemm->backend_config().value(); GemmBackendConfig& backend_config = *gpu_config.mutable_gemm_backend_config(); + *num_algorithms_left = 0; // Degenerate gemms replaced with memzero operation, no need to auto tune it. if (backend_config.alpha_real() == 0.0 && backend_config.alpha_imag() == 0.0 && backend_config.beta() == 0.0) { @@ -393,6 +414,7 @@ absl::StatusOr RunOnInstruction(HloInstruction* gemm, AutotunerUtil::Autotune( gemm, config, [&] { return autotuner(gemm, key); })); + *num_algorithms_left = autotuner.num_algorithms_left(); auto old_algorithm = backend_config.selected_algorithm(); bool update_algorithm = IsCublasLtMatmulF8(*gemm) || @@ -434,11 +456,17 @@ absl::StatusOr RunOnInstruction(HloInstruction* gemm, } absl::StatusOr RunOnComputation(HloComputation* computation, - AutotuneConfig config) { + AutotuneConfig config, + size_t* num_algorithms_left) { bool changed = false; + for (HloInstruction* instr : computation->instructions()) { if (IsCublasGemm(*instr)) { - TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(instr, config)); + size_t num_left; + TF_ASSIGN_OR_RETURN(bool result, + RunOnInstruction(instr, config, &num_left)); + // Gathering statistics on the algorithms left after tuning (for testing) + *num_algorithms_left = std::max(*num_algorithms_left, num_left); changed |= result; } } @@ -453,6 +481,7 @@ absl::StatusOr GemmAlgorithmPicker::Run( XLA_SCOPED_LOGGING_TIMER( absl::StrCat("GemmAlgorithmPicker for ", module->name())); + num_algorithms_left_ = 0; if (module->config().debug_options().xla_gpu_autotune_level() == 0) { VLOG(2) << "GEMM auto-tuning disabled, GemmAlgorithmPicker returning early"; return false; @@ -461,7 +490,8 @@ absl::StatusOr GemmAlgorithmPicker::Run( bool changed = false; for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { - TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation, config_)); + TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation, config_, + &num_algorithms_left_)); changed |= result; } return changed; diff --git a/xla/service/gpu/gemm_algorithm_picker.h b/xla/service/gpu/gemm_algorithm_picker.h index cc70e0d16a21af..be2686ddc93e86 100644 --- a/xla/service/gpu/gemm_algorithm_picker.h +++ b/xla/service/gpu/gemm_algorithm_picker.h @@ -50,6 +50,8 @@ class GemmAlgorithmPicker : public HloModulePass { absl::string_view name() const override { return "gemm-algorithm-picker"; } + size_t num_algorithms_left() const { return num_algorithms_left_; } + using HloPassInterface::Run; absl::StatusOr Run( HloModule* module, @@ -57,6 +59,9 @@ class GemmAlgorithmPicker : public HloModulePass { private: AutotuneConfig config_; + // The number of valid algorithms used for autotuning (from the last call), + // to be used for testing purposes. + size_t num_algorithms_left_ = 0; }; } // namespace gpu diff --git a/xla/service/gpu/gemm_algorithm_picker_test.cc b/xla/service/gpu/gemm_algorithm_picker_test.cc index a9a1f12bd55d94..b16d9422ba472c 100644 --- a/xla/service/gpu/gemm_algorithm_picker_test.cc +++ b/xla/service/gpu/gemm_algorithm_picker_test.cc @@ -111,6 +111,68 @@ TEST_P(GemmAlgorithmPickerTest, BlasGetVersion) { ASSERT_TRUE(!version.empty()); } +TEST_P(GemmAlgorithmPickerTest, SkipAlgorithmsWithAccuracyCheck) { + constexpr absl::string_view kHlo = R"( +HloModule module + +ENTRY main { + %arg0 = f32[100,100]{1,0} parameter(0) + %arg1 = f32[100,100]{1,0} parameter(1) + ROOT %dot = f32[100,100]{1,0} dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + auto module_cfg = GetModuleConfigForTest(); + auto debug_opts = module_cfg.debug_options(); + size_t num_left1 = 0, num_left2 = 0; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kHlo, module_cfg)); + + { + // Run first with default settings (autotune level = 4), keep the number of + // algorithms left after autotuning + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + RunHloPass(GemmRewriter(gpu_comp(), /*toolkit_version=*/12040), + module.get())); + + AutotuneConfig cfg{DeviceConfig{stream_exec(), nullptr}, debug_opts}; + GemmAlgorithmPicker gpicker(cfg); + // Note that, we do not care if the algorithm index has been changed: + // the thing matters is the # of algorithms left after sorting out. + TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(gpicker, module.get())); + num_left1 = gpicker.num_algorithms_left(); + if (num_left1 < 2) { + GTEST_SKIP() << "Too few algorithms left after the first step"; + } + } + + // Clear cache before the second run! + AutotunerUtil::ClearAutotuneResults(); + { + // Run once again but now with autotune level 5 and embarassingly tight + // rtol which shall disqualify most of the algorithms. + + // Note that, we have "two sources of truth" for GemmAlgorithmPicker: i.e., + // debug_options are used to initialize both 'HloModuleConfig' and also + // 'AutotuneConfig'. + debug_opts.set_xla_gpu_autotune_gemm_rtol(1e-12); + debug_opts.set_xla_gpu_autotune_level(5); + module->mutable_config().set_debug_options(debug_opts); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + RunHloPass(GemmRewriter(gpu_comp(), /*toolkit_version=*/12040), + module.get())); + + AutotuneConfig cfg{DeviceConfig{stream_exec(), nullptr}, debug_opts}; + GemmAlgorithmPicker gpicker(cfg); + TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(gpicker, module.get())); + num_left2 = gpicker.num_algorithms_left(); + } + // Assert that we have fewer algorithms left after the second run. + ASSERT_TRUE(num_left1 > num_left2); +} + TEST_P(GemmAlgorithmPickerTest, SetAlgorithm) { constexpr absl::string_view kHlo = R"( HloModule module diff --git a/xla/service/gpu/gemm_fusion_autotuner.cc b/xla/service/gpu/gemm_fusion_autotuner.cc index 6359e5f28dbcb5..179f7387b9b83b 100644 --- a/xla/service/gpu/gemm_fusion_autotuner.cc +++ b/xla/service/gpu/gemm_fusion_autotuner.cc @@ -887,7 +887,7 @@ absl::StatusOr> GemmFusionAutotunerImpl::Profile( const HloInstruction& root = *fusion_computation->root_instruction(); BufferComparator comparator(root.shape(), - fusion_computation->parent()->config()); + debug_options_.xla_gpu_autotune_gemm_rtol()); TF_ASSIGN_OR_RETURN(auto rz_buffers, RedzoneBuffers::FromInstruction( diff --git a/xla/service/gpu/triton_fusion_numerics_verifier.cc b/xla/service/gpu/triton_fusion_numerics_verifier.cc index 11aa1e8a966013..75c43feadd605c 100644 --- a/xla/service/gpu/triton_fusion_numerics_verifier.cc +++ b/xla/service/gpu/triton_fusion_numerics_verifier.cc @@ -114,7 +114,8 @@ absl::Status CompareBuffers(const ScopedShapedBuffer& current, const ScopedShapedBuffer& expected, const Shape& shape, const HloModuleConfig& config, se::Stream* stream) { - BufferComparator comparator(shape, config); + BufferComparator comparator( + shape, config.debug_options().xla_gpu_autotune_gemm_rtol()); TF_ASSIGN_OR_RETURN(bool outputs_match, comparator.CompareEqual(stream, current.root_buffer(), expected.root_buffer())); diff --git a/xla/xla.proto b/xla/xla.proto index 231c4127528b5d..dc232e2941edf2 100644 --- a/xla/xla.proto +++ b/xla/xla.proto @@ -870,7 +870,10 @@ message DebugOptions { // all-to-all-done = f32[1,4,8]{1,0,2} all-to-all-done(all-to-all-start) bool xla_syntax_sugar_async_ops = 315; - // Next id: 316 + // Relative precision for comparing different GEMM solutions + float xla_gpu_autotune_gemm_rtol = 316; + + // Next id: 317 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. From 7e033c9e827edacdda722f96e088a9b18c9d1bdb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Bana=C5=9B?= Date: Fri, 19 Jul 2024 02:44:18 -0700 Subject: [PATCH 008/376] [XLA:CPU] Support `SliceToDynamic` custom call thunk Additionally turn on the `set_dimension_size` test for thunk runtime (`set_dimension_size` op is rewritten as `SliceToDynamic`, that's why it has been turned off so far). PiperOrigin-RevId: 653940994 --- xla/service/cpu/ir_emitter.cc | 31 ++++++++++++++++++++++--------- xla/service/cpu/ir_emitter.h | 6 ++++++ xla/service/cpu/ir_emitter2.cc | 19 +++++++++++++++++++ xla/service/cpu/ir_emitter2.h | 4 ++++ xla/service/cpu/thunk_emitter.cc | 14 +++++++++++++- xla/service/cpu/thunk_emitter.h | 3 +++ xla/tests/BUILD | 2 +- 7 files changed, 68 insertions(+), 11 deletions(-) diff --git a/xla/service/cpu/ir_emitter.cc b/xla/service/cpu/ir_emitter.cc index a3bbc3b76ef4e2..183479ba0fa991 100644 --- a/xla/service/cpu/ir_emitter.cc +++ b/xla/service/cpu/ir_emitter.cc @@ -2451,15 +2451,16 @@ absl::Status IrEmitter::HandleCall(HloInstruction* call) { return absl::OkStatus(); } -absl::Status IrEmitter::HandleSliceToDynamic(HloInstruction* hlo) { - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo)); +absl::Status IrEmitter::EmitSliceToDynamic( + const HloInstruction* hlo, absl::Span source_arrays, + const llvm_ir::IrArray& target_array) { std::vector dynamic_dims; int32_t raw_data_size = ShapeUtil::ByteSizeOf(ShapeUtil::MakeStaticShape(hlo->shape())); - llvm::Value* dest_buffer = GetEmittedValueFor(hlo); + llvm::Value* dest_buffer = target_array.GetBasePointer(); for (int64_t i = 1; i < hlo->operand_count(); ++i) { const int64_t dim_index = i - 1; - llvm::Value* source_buffer = GetEmittedValueFor(hlo->operand(i)); + llvm::Value* source_buffer = source_arrays[i].GetBasePointer(); llvm::LoadInst* dyn_dim_size = Load(IrShapeType(hlo->operand(i)->shape()), source_buffer, "dyn_dim_size"); @@ -2472,7 +2473,6 @@ absl::Status IrEmitter::HandleSliceToDynamic(HloInstruction* hlo) { "i64_dyn_dim_size")); } - llvm_ir::IrArray data_array = GetIrArrayFor(hlo); // Pseudo code for sliceToDynamic: // // for (index i in dynamic_dim) @@ -2481,19 +2481,32 @@ absl::Status IrEmitter::HandleSliceToDynamic(HloInstruction* hlo) { auto loop_body_emitter = [&](const llvm_ir::IrArray::Index& array_index) -> absl::Status { llvm::Value* source_element = - GetIrArrayFor(hlo->operand(0)).EmitReadArrayElement(array_index, b()); + source_arrays[0].EmitReadArrayElement(array_index, b()); llvm::Value* linear_index = array_index.Linearize(dynamic_dims, b()); // Delinearize the index based on the static shape. - llvm_ir::IrArray::Index dest_index(linear_index, data_array.GetShape(), + llvm_ir::IrArray::Index dest_index(linear_index, target_array.GetShape(), b()); - data_array.EmitWriteArrayElement(dest_index, source_element, b()); + target_array.EmitWriteArrayElement(dest_index, source_element, b()); return absl::OkStatus(); }; - return llvm_ir::LoopEmitter(loop_body_emitter, data_array.GetShape(), + return llvm_ir::LoopEmitter(loop_body_emitter, target_array.GetShape(), dynamic_dims, b()) .EmitLoop(IrName(hlo)); } +absl::Status IrEmitter::HandleSliceToDynamic(HloInstruction* hlo) { + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo)); + llvm_ir::IrArray target_array = GetIrArrayFor(hlo); + + std::vector source_arrays; + source_arrays.reserve(hlo->operand_count()); + for (auto operand : hlo->operands()) { + source_arrays.push_back(GetIrArrayFor(operand)); + } + + return EmitSliceToDynamic(hlo, source_arrays, target_array); +} + absl::Status IrEmitter::HandlePadToStatic(HloInstruction* hlo) { TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo)); diff --git a/xla/service/cpu/ir_emitter.h b/xla/service/cpu/ir_emitter.h index a5db74875f6840..37da729c88be4e 100644 --- a/xla/service/cpu/ir_emitter.h +++ b/xla/service/cpu/ir_emitter.h @@ -511,6 +511,12 @@ class IrEmitter : public DfsHloVisitorWithDefault, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& source_array); + // Emit slice-to-dynamic. + absl::Status EmitSliceToDynamic( + const HloInstruction* hlo, + absl::Span source_arrays, + const llvm_ir::IrArray& target_array); + // Emits printing during the execution. llvm::Value* EmitPrintf(absl::string_view fmt, absl::Span arguments); diff --git a/xla/service/cpu/ir_emitter2.cc b/xla/service/cpu/ir_emitter2.cc index ed2a7d10adf299..86363fdc2a2dd4 100644 --- a/xla/service/cpu/ir_emitter2.cc +++ b/xla/service/cpu/ir_emitter2.cc @@ -449,6 +449,25 @@ absl::StatusOr IrEmitter2::EmitDotFusionHostKernel( se::ThreadDim()}); } +absl::StatusOr IrEmitter2::EmitSliceToDynamicHostKernel( + const HloInstruction* instr) { + VLOG(2) << "Emit slice-to-dynamic host kernel: " << instr->name(); + + TF_ASSIGN_OR_RETURN(KernelPrototype kernel_prototype, + EmitKernelPrototype(instr)); + llvm::IRBuilder<> ir_builder(module_->getContext()); + ir_builder.SetInsertPoint( + kernel_prototype.function->getEntryBlock().getTerminator()); + + llvm_ir::IrArray output_array = kernel_prototype.results[0]; + auto guard = nested_ir_emitter_->WithBuilder(ir_builder); + TF_RETURN_IF_ERROR(nested_ir_emitter_->EmitSliceToDynamic( + instr, kernel_prototype.arguments, output_array)); + return kernels_.emplace_back( + KernelInfo{kernel_prototype.function->getName().str(), se::BlockDim(), + se::ThreadDim()}); +} + absl::StatusOr IrEmitter2::EmitSelectAndScatterHostKernel(const HloInstruction* instr) { TF_ASSIGN_OR_RETURN(KernelPrototype kernel_prototype, diff --git a/xla/service/cpu/ir_emitter2.h b/xla/service/cpu/ir_emitter2.h index c210a11aca1774..65f07836e04ca0 100644 --- a/xla/service/cpu/ir_emitter2.h +++ b/xla/service/cpu/ir_emitter2.h @@ -144,6 +144,10 @@ class IrEmitter2 { absl::StatusOr EmitDotFusionHostKernel( const HloFusionInstruction* fusion); + // Emits a host kernel for the given slice-to-dynamic instruction. + absl::StatusOr EmitSliceToDynamicHostKernel( + const HloInstruction* instr); + // Emits a host kernel for the given select-and-scatter instruction. absl::StatusOr EmitSelectAndScatterHostKernel( const HloInstruction* instr); diff --git a/xla/service/cpu/thunk_emitter.cc b/xla/service/cpu/thunk_emitter.cc index a39f31aada31f5..2d6f7f68a71d8e 100644 --- a/xla/service/cpu/thunk_emitter.cc +++ b/xla/service/cpu/thunk_emitter.cc @@ -897,7 +897,6 @@ absl::StatusOr ThunkEmitter::EmitCustomCallThunk( // TODO(penporn): Support these existing targets. auto custom_call_target = custom_call->custom_call_target(); if (custom_call_target == "PadToStatic" || - custom_call_target == "SliceToDynamic" || custom_call_target == "__onednn$matmul" || custom_call_target == "__onednn$softmax" || custom_call_target == "__onednn$layernorm" || @@ -907,6 +906,8 @@ absl::StatusOr ThunkEmitter::EmitCustomCallThunk( } if (custom_call_target == "TopK") { return EmitTopKThunk(custom_call); + } else if (custom_call_target == "SliceToDynamic") { + return EmitSliceToDynamicThunk(instruction); } // Check the API version. @@ -927,6 +928,17 @@ absl::StatusOr ThunkEmitter::EmitCustomCallThunk( backend_config, version); } +absl::StatusOr ThunkEmitter::EmitSliceToDynamicThunk( + const HloInstruction* instruction) { + TF_ASSIGN_OR_RETURN(auto kernel, + ir_emitter_.EmitSliceToDynamicHostKernel(instruction)); + TF_ASSIGN_OR_RETURN(auto buffers, GetHostKernelAllocationSlices(instruction)); + + return ThunkSequence::Of( + ThunkInfo(instruction), buffers.arguments, buffers.results, kernel.name, + kernel.thread_dims, /*min_alignment=*/cpu_function_runtime::MinAlign()); +} + absl::StatusOr ThunkEmitter::EmitSelectAndScatterThunk( const HloInstruction* instruction) { TF_ASSIGN_OR_RETURN(auto kernel, diff --git a/xla/service/cpu/thunk_emitter.h b/xla/service/cpu/thunk_emitter.h index de73bdc26f1066..a829b4c159328e 100644 --- a/xla/service/cpu/thunk_emitter.h +++ b/xla/service/cpu/thunk_emitter.h @@ -150,6 +150,9 @@ class ThunkEmitter { absl::StatusOr EmitCustomCallThunk( const HloInstruction* instruction); + absl::StatusOr EmitSliceToDynamicThunk( + const HloInstruction* instruction); + absl::StatusOr EmitSelectAndScatterThunk( const HloInstruction* instruction); diff --git a/xla/tests/BUILD b/xla/tests/BUILD index 9f5f379e5b158a..5e07cdb79001ff 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -2937,7 +2937,7 @@ xla_test( xla_test( name = "set_dimension_size_test", srcs = ["set_dimension_size_test.cc"], - # TODO(abanas): add test_xla_cpu_thunks tag when 'SliceToDynamic' custom call is supported. + tags = ["test_xla_cpu_thunks"], deps = [ ":hlo_test_base", ":xla_internal_test_main", # fixdeps: keep From 10cca5edcdc3b64c1c178ac345b228f9f654340c Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Fri, 19 Jul 2024 03:20:13 -0700 Subject: [PATCH 009/376] Fix constraint check in the shared memory phase of row reductions. The problem is that indexing map creation shrinks the size of the d0 dimension (thread ID), which in the case of the test I added, results in only the first 20 threads producing a value for the shuffle. But then we do an unmasked shuffle with the entire warp, which means we get some undefined values. PiperOrigin-RevId: 653948826 --- xla/service/gpu/fusions/reduction_mlir.cc | 10 ++++++++-- .../gpu/fusions/reduction_mlir_test.cc | 19 +++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/xla/service/gpu/fusions/reduction_mlir.cc b/xla/service/gpu/fusions/reduction_mlir.cc index a6bc7f566e48a1..a8697f259dbc63 100644 --- a/xla/service/gpu/fusions/reduction_mlir.cc +++ b/xla/service/gpu/fusions/reduction_mlir.cc @@ -517,9 +517,15 @@ mlir::ValueRange MlirReductionFusion::EmitterState::ReduceViaSharedMemory( owner.GetSharedMemoryReductionReadMap(builder.getContext()); auto loop_indexing = read_indexing; // All threads must participate in the shuffle, so we clear the constraints - // for the iteration. Otherwise, some threads might not be part of the loop. - // The constraints are still checked inside the loop. + // for the iteration. Otherwise, some threads might not be part of the loop, + // resulting in incorrect results for the warp shuffle. + // The constraints are still checked inside the loop in the + // PredicatedExtractOp. loop_indexing.ClearConstraints(); + // The constraints may have reduced the upper bound of the dimension. If + // that's the case, we reset it to a multiple of the warp size. + auto& bound = loop_indexing.GetMutableDimensionBound(0); + bound.upper = RoundUpTo(bound.upper + 1, WarpSize()) - 1; auto tiles = WriteToSharedMemory(reductions, per_thread.reduction_scalars); return mlir_converter::EmitLoopNest( diff --git a/xla/service/gpu/fusions/reduction_mlir_test.cc b/xla/service/gpu/fusions/reduction_mlir_test.cc index 8a1f880dc64a68..9fac8cc5f34d24 100644 --- a/xla/service/gpu/fusions/reduction_mlir_test.cc +++ b/xla/service/gpu/fusions/reduction_mlir_test.cc @@ -860,6 +860,25 @@ TEST_F(MlirMultiRowReductionTest, VectorizedX4Correctness) { RunAndCompareNoHloPasses(kMultiRowReductionX2VectorX4, ErrorSpec{1e-3})); } +TEST_F(MlirRowReductionTest, LargeToUnit) { + // Regression test for a bug where not all threads in the warp produced a + // valid value for the final warp shuffle. + constexpr auto kHloString = R"( + and { + p0 = pred[] parameter(0) + p1 = pred[] parameter(1) + ROOT and = pred[] and(p0, p1) + } + + %fused_reduce { + c1 = pred[] constant(true) + p0 = pred[10000] broadcast(c1), dimensions={} + ROOT reduce = pred[] reduce(p0, c1), dimensions={0}, to_apply=and + } + )"; + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + } // namespace } // namespace gpu } // namespace xla From 033e7404811c9f1a34774ab6934ef6baa80a4e09 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Fri, 19 Jul 2024 03:33:18 -0700 Subject: [PATCH 010/376] [XLA:GPU] Change the format of the Autotuning key. PiperOrigin-RevId: 653951913 --- docs/tools.md | 4 +- xla/service/gpu/BUILD | 9 ++++ xla/service/gpu/autotuner_util.cc | 44 +++++++++++++++++++ xla/service/gpu/autotuner_util.h | 30 +++++++------ xla/service/gpu/autotuner_util_test.cc | 34 +++++++++++++- xla/service/gpu/conv_algorithm_picker.cc | 3 +- xla/service/gpu/gemm_algorithm_picker_test.cc | 2 +- xla/service/gpu/gemm_fusion_autotuner_test.cc | 20 ++++----- xla/service/gpu/gpu_compiler.cc | 5 +-- xla/service/gpu/gpu_compiler_test.cc | 2 - .../gpu_compiler_test_autotune_db.textproto | 20 ++------- .../gpu/tests/test_autotune_cache.textproto | 2 +- ...aot_compile_test_autotune_results.prototxt | 4 +- 13 files changed, 126 insertions(+), 53 deletions(-) diff --git a/docs/tools.md b/docs/tools.md index f6d4a3d793f2d4..48a8099d67fd5d 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -129,9 +129,9 @@ The autotune file is text serialization of `autotune_results.proto`, with example looking like: ``` -version: 2 +version: 3 results { - device: "sm_8.0 with 42331013120B RAM, 108 cores, 1410000KHz clock, 1215000KHz mem clock, 41943040B L2$" + device: "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: 1555 GB/s, L2 cache: 40 MB" hlo: "{\n tmp_0 = f16[1,16,17,3]{3,2,1,0} parameter(0)\n tmp_1 = f16[16,51]{1,0} bitcast(f16[1,16,17,3]{3,2,1,0} tmp_0)\n tmp_2 = s8[16,17,3]{2,1,0} parameter(1)\n tmp_3 = s8[51,16]{0,1} bitcast(s8[16,17,3]{2,1,0} tmp_2)\n tmp_4 = f16[51,16]{0,1} convert(s8[51,16]{0,1} tmp_3)\n tmp_5 = f16[16,16]{1,0} dot(f16[16,51]{1,0} tmp_1, f16[51,16]{0,1} tmp_4), lhs_contracting_dims={1}, rhs_contracting_dims={0}\n ROOT tmp_6 = f16[1,16,16]{2,1,0} bitcast(f16[16,16]{1,0} tmp_5)\n}" result { run_time { diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 4be7457e43992a..6c9b9f34b2ffd6 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -587,6 +587,7 @@ xla_test( "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/stream_executor:device_description", + "//xla/stream_executor:device_description_proto_cc", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:test_utils", @@ -5563,6 +5564,11 @@ xla_cc_test( xla_cc_test( name = "autotuner_util_test", srcs = if_cuda_is_configured(["autotuner_util_test.cc"]), + data = [ + "//xla/tools/hlo_opt:gpu_specs/a100_sxm_40.txtpb", + "//xla/tools/hlo_opt:gpu_specs/a100_sxm_80.txtpb", + "//xla/tools/hlo_opt:gpu_specs/mi200.txtpb", + ], deps = if_cuda_is_configured([ # keep sorted ":autotuner_util", @@ -5571,6 +5577,8 @@ xla_cc_test( "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", + "//xla/stream_executor:device_description", + "//xla/stream_executor:device_description_proto_cc", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", "//xla/stream_executor/host:host_platform", @@ -5592,6 +5600,7 @@ xla_cc_test( "@tsl//tsl/platform:status", "@tsl//tsl/platform:status_matchers", "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", ]) + [ "//xla/tests:xla_internal_test_main", # Keep outside GPU guard ], diff --git a/xla/service/gpu/autotuner_util.cc b/xla/service/gpu/autotuner_util.cc index 56fc9b3b33030d..93c946f57b6462 100644 --- a/xla/service/gpu/autotuner_util.cc +++ b/xla/service/gpu/autotuner_util.cc @@ -17,11 +17,13 @@ limitations under the License. #include #include +#include #include #include #include #include #include +#include #include "absl/base/const_init.h" #include "absl/base/thread_annotations.h" @@ -46,6 +48,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/redzone_allocator.h" #include "xla/stream_executor/stream.h" @@ -311,6 +314,47 @@ AutotuneCacheKey::AutotuneCacheKey(absl::string_view model_str, const HloInstruction& instr) : AutotuneCacheKey(model_str, ToCanonicalString(&instr)) {} +/*static*/ std::string AutotuneCacheKey::DeviceDescriptionToCacheKey( + const se::DeviceDescription& device_description) { + std::string compute_capability; + if (auto* ccc = std::get_if( + &device_description.gpu_compute_capability())) { + compute_capability = absl::StrCat("CUDA: ", ccc->major, ".", ccc->minor); + } else { + auto* rcc = std::get_if( + &device_description.gpu_compute_capability()); + CHECK(rcc != nullptr) << "Unknown compute capability type"; + compute_capability = absl::StrCat("ROCM: ", rcc->gfx_version()); + } + + // The string below should include only as much information as is needed to + // make it a valid key. Information that should not be included is: + // - specs that are directly derivable from the compute capability, e.g. + // shared memory size. For NVIDIA GPUs, you can see what is derivable from + // the SM version here: + // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability + // - specs that are irrelevant for autotuning. E.g. the total available memory + // on a device is not relevant, because by itself, it does not affect the + // performance of single kernels. + // + // See b/344573710 for some discussion. + + double memory_bandwidth = device_description.memory_bandwidth() / 1e9; + // Round the memory bandwidth to make the final string nicer to read. + // This will also cause minute differences in bandwidth to yield the same + // cache key, but that's fine, since the difference is inconsequential. + memory_bandwidth = std::round(memory_bandwidth); + + constexpr double kBytesPerMegabyte = 1 << 20; + double l2_cache_size = device_description.l2_cache_size() / kBytesPerMegabyte; + + return absl::StrCat(compute_capability, + ", Cores: ", device_description.core_count(), + ", GPU clock: ", device_description.clock_rate_ghz(), + " GHz, Memory bandwidth: ", memory_bandwidth, + " GB/s, L2 cache: ", l2_cache_size, " MB"); +} + namespace { absl::StatusOr> TryFindInCache( const AutotuneCacheKey& key, absl::string_view cache_dir) diff --git a/xla/service/gpu/autotuner_util.h b/xla/service/gpu/autotuner_util.h index 2e8e2c8f03a344..4634fc21b44fa4 100644 --- a/xla/service/gpu/autotuner_util.h +++ b/xla/service/gpu/autotuner_util.h @@ -53,19 +53,17 @@ struct DeviceConfig { }; struct DevicelessConfig { - // The human-readable description of the device. It can be found by using - // stream_exec->GetDeviceDescription().model_str() when the stream executor - // is available. - std::string model_str; - - // A field to determine the architecture of the device. We only pick an - // algorithm for non-Ampere architectures. - se::GpuComputeCapability gpu_compute_capability{ - se::CudaComputeCapability{0, 0}}; + // The device description of the target device. + se::DeviceDescription device_description; }; class AutotuneCacheKey { public: + AutotuneCacheKey(const se::DeviceDescription& device_description, + const HloInstruction& instruction) + : AutotuneCacheKey(DeviceDescriptionToCacheKey(device_description), + instruction.ToString()) {} + AutotuneCacheKey(absl::string_view model_str, const HloInstruction& instruction); @@ -91,6 +89,9 @@ class AutotuneCacheKey { hlo_canonical_); } + static std::string DeviceDescriptionToCacheKey( + const se::DeviceDescription& device_description); + private: std::string model_str_; std::string hlo_canonical_; @@ -133,13 +134,15 @@ class AutotuneConfig { autotune_cache_dir_( debug_options.xla_gpu_per_fusion_autotune_cache_dir()) {} - absl::string_view GetModelStr() const { + std::string GetModelStr() const { if (auto deviceless_config = std::get_if(&config_)) { - return deviceless_config->model_str; + return AutotuneCacheKey::DeviceDescriptionToCacheKey( + deviceless_config->device_description); } const auto& device_config = std::get(config_); - return device_config.stream_exec->GetDeviceDescription().model_str(); + return AutotuneCacheKey::DeviceDescriptionToCacheKey( + device_config.stream_exec->GetDeviceDescription()); } se::StreamExecutor* GetExecutor() const { @@ -169,7 +172,8 @@ class AutotuneConfig { if (auto c = std::get_if(&config_)) { return c->stream_exec->GetDeviceDescription().gpu_compute_capability(); } - return std::get(config_).gpu_compute_capability; + return std::get(config_) + .device_description.gpu_compute_capability(); } bool IsDeviceless() const { diff --git a/xla/service/gpu/autotuner_util_test.cc b/xla/service/gpu/autotuner_util_test.cc index 508fa42bb00afc..37fb56ed67fb83 100644 --- a/xla/service/gpu/autotuner_util_test.cc +++ b/xla/service/gpu/autotuner_util_test.cc @@ -46,6 +46,7 @@ limitations under the License. #include "tsl/platform/status.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" namespace xla { namespace gpu { @@ -74,7 +75,7 @@ ENTRY e { static constexpr absl::string_view kResultText = R"( version: 3 results { - device: "sm_8.0 with 42331013120B RAM, 108 cores, 1410000KHz clock, 1215000KHz mem clock, 41943040B L2$" + device: "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: 1555 GB/s, L2 cache: 40 MB" hlo: "{\n tmp_0 = f16[1,16,17,3]{3,2,1,0} parameter(0)\n tmp_1 = f16[16,51]{1,0} bitcast(f16[1,16,17,3]{3,2,1,0} tmp_0)\n tmp_2 = s8[16,17,3]{2,1,0} parameter(1)\n tmp_3 = s8[51,16]{0,1} bitcast(s8[16,17,3]{2,1,0} tmp_2)\n tmp_4 = f16[51,16]{0,1} convert(s8[51,16]{0,1} tmp_3)\n tmp_5 = f16[16,16]{1,0} dot(f16[16,51]{1,0} tmp_1, f16[51,16]{0,1} tmp_4), lhs_contracting_dims={1}, rhs_contracting_dims={0}\n ROOT tmp_6 = f16[1,16,16]{2,1,0} bitcast(f16[16,16]{1,0} tmp_5)\n}" result { run_time { @@ -422,6 +423,37 @@ TEST_F(FileBasedCacheTest, RepeatedAddResultDoesNotWriteTheFileAgain) { EXPECT_EQ(Read(cache_file_path_), kPlaceholderContent); } +TEST(AutotuneCacheKeyTest, DeviceDescriptionToCacheKey) { + auto device_description = + [](absl::string_view spec_file_name) -> se::DeviceDescription { + se::GpuTargetConfigProto proto; + std::string spec_string; + CHECK_OK(tsl::ReadFileToString( + tsl::Env::Default(), + tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "tools", "hlo_opt", + "gpu_specs", spec_file_name), + &spec_string)); + EXPECT_TRUE( + tsl::protobuf::TextFormat::ParseFromString(spec_string, &proto)); + return se::DeviceDescription(proto.gpu_device_info()); + }; + + EXPECT_EQ(AutotuneCacheKey::DeviceDescriptionToCacheKey( + device_description("a100_sxm_40.txtpb")), + "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: " + "1555 GB/s, L2 cache: 40 MB"); + + EXPECT_EQ(AutotuneCacheKey::DeviceDescriptionToCacheKey( + device_description("a100_sxm_80.txtpb")), + "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: " + "2039 GB/s, L2 cache: 40 MB"); + + EXPECT_EQ(AutotuneCacheKey::DeviceDescriptionToCacheKey( + device_description("mi200.txtpb")), + "ROCM: gfx90a, Cores: 110, GPU clock: 1.7 GHz, Memory bandwidth: " + "1638 GB/s, L2 cache: 8 MB"); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/conv_algorithm_picker.cc b/xla/service/gpu/conv_algorithm_picker.cc index bf6681f65de35e..fc93de8b744698 100644 --- a/xla/service/gpu/conv_algorithm_picker.cc +++ b/xla/service/gpu/conv_algorithm_picker.cc @@ -446,8 +446,7 @@ GpuConvAlgorithmPicker::AutotuneRuntimeArguments::FromInstruction( // Get canonical HLO. std::string canonical_hlo( - AutotuneCacheKey(config.GetExecutor()->GetDeviceDescription().model_str(), - *instr) + AutotuneCacheKey(config.GetExecutor()->GetDeviceDescription(), *instr) .GetHlo()); TF_ASSIGN_OR_RETURN(GpuConvConfig gpu_conv_config, GetGpuConvConfig(instr)); diff --git a/xla/service/gpu/gemm_algorithm_picker_test.cc b/xla/service/gpu/gemm_algorithm_picker_test.cc index b16d9422ba472c..3017af4bef9528 100644 --- a/xla/service/gpu/gemm_algorithm_picker_test.cc +++ b/xla/service/gpu/gemm_algorithm_picker_test.cc @@ -274,7 +274,7 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(m, ParseAndReturnVerifiedModule(kHlo, module_cfg)); changed = false; - DevicelessConfig deviceless_config{gpu_device_desc().model_str(), gpu_comp()}; + DevicelessConfig deviceless_config{gpu_device_desc()}; AutotuneConfig deviceless_cfg{deviceless_config, opts}; TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(GemmRewriter(gpu_comp(), diff --git a/xla/service/gpu/gemm_fusion_autotuner_test.cc b/xla/service/gpu/gemm_fusion_autotuner_test.cc index 8dce9de60e7778..8cb7e8dc87e229 100644 --- a/xla/service/gpu/gemm_fusion_autotuner_test.cc +++ b/xla/service/gpu/gemm_fusion_autotuner_test.cc @@ -50,6 +50,7 @@ limitations under the License. #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/device_description.pb.h" #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_utils.h" @@ -244,7 +245,11 @@ absl::StatusOr> GetPossibleMatmulAutotuneConfigs( const HloDotInstruction& dot, const se::CudaComputeCapability& compute_capability, const int32_t toolkit_version, const DebugOptions& debug_options) { - DevicelessConfig test_config{/*model_str=*/"", compute_capability}; + se::GpuDeviceInfoProto deviceless_proto; + auto ccc = deviceless_proto.mutable_cuda_compute_capability(); + ccc->set_major(compute_capability.major); + ccc->set_minor(compute_capability.minor); + DevicelessConfig test_config{se::DeviceDescription{deviceless_proto}}; AutotuneConfig autotune_config{test_config, debug_options}; GemmFusionAutotunerImpl autotuner(autotune_config, toolkit_version, debug_options, nullptr); @@ -776,15 +781,10 @@ ENTRY e { DebugOptions opts; MultiProcessKeyValueStore key_value_store; pipeline.AddPass( - AutotuneConfig{DevicelessConfig{backend() - .default_stream_executor() - ->GetDeviceDescription() - .model_str(), - backend() - .default_stream_executor() - ->GetDeviceDescription() - .cuda_compute_capability()}, - opts}, + AutotuneConfig{ + DevicelessConfig{ + backend().default_stream_executor()->GetDeviceDescription()}, + opts}, GetToolkitVersion(), &thread_pool, key_value_store); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index d0ba0777e8776c..550476b74201bf 100644 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -299,9 +299,8 @@ absl::StatusOr GetAutotuneConfig( return AutotuneConfig{DeviceConfig{stream_exec, options.device_allocator}, debug_options}; } - return AutotuneConfig{ - DevicelessConfig{gpu_target_config.device_description_str}, - debug_options}; + return AutotuneConfig{DevicelessConfig{gpu_target_config.device_description}, + debug_options}; } se::GpuComputeCapability GetGpuVersion(const se::StreamExecutor* stream_exec) { diff --git a/xla/service/gpu/gpu_compiler_test.cc b/xla/service/gpu/gpu_compiler_test.cc index 287813b394c745..ff766823c8fe2c 100644 --- a/xla/service/gpu/gpu_compiler_test.cc +++ b/xla/service/gpu/gpu_compiler_test.cc @@ -391,8 +391,6 @@ ENTRY main { TEST_F(GpuCompilerTest, GemmFusionIsNoOpWhenGemmFusionAutotunerFallsBackToCublas) { - GTEST_SKIP() << "TODO(b/344573710): this test is flaky, disable it " - << " until flakiness is fixed."; auto cc = backend() .default_stream_executor() ->GetDeviceDescription() diff --git a/xla/service/gpu/gpu_compiler_test_autotune_db.textproto b/xla/service/gpu/gpu_compiler_test_autotune_db.textproto index 39874deb65565e..ecdc8e089ca80a 100644 --- a/xla/service/gpu/gpu_compiler_test_autotune_db.textproto +++ b/xla/service/gpu/gpu_compiler_test_autotune_db.textproto @@ -1,6 +1,6 @@ version: 3 results { - device: "sm_8.0 with 42296475648B RAM, 108 cores, 1410000KHz clock, 1215000KHz mem clock, 41943040B L2$" + device: "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: 1555 GB/s, L2 cache: 40 MB" hlo: "{\n tmp_0 = bf16[1,4,32,1024,1024]{4,3,2,1,0} parameter(0)\n tmp_1 = f32[1,4,32,1024,1024]{4,3,2,1,0} convert(bf16[1,4,32,1024,1024]{4,3,2,1,0} tmp_0)\n tmp_2 = bf16[] constant({...})\n tmp_3 = bf16[1,4,32,1024,1024]{4,3,2,1,0} broadcast(bf16[] tmp_2), dimensions={}\n tmp_4 = f32[1,4,32,1024,1024]{4,3,2,1,0} convert(bf16[1,4,32,1024,1024]{4,3,2,1,0} tmp_3)\n tmp_5 = f32[1,4,32,1024,1024]{4,3,2,1,0} multiply(f32[1,4,32,1024,1024]{4,3,2,1,0} tmp_1, f32[1,4,32,1024,1024]{4,3,2,1,0} tmp_4)\n tmp_6 = bf16[1,4,32,1024,1024]{4,3,2,1,0} convert(f32[1,4,32,1024,1024]{4,3,2,1,0} tmp_5)\n tmp_7 = bf16[4,32,1024,1024]{3,2,1,0} bitcast(bf16[1,4,32,1024,1024]{4,3,2,1,0} tmp_6)\n tmp_8 = bf16[4,32,1024,1024]{3,2,1,0} transpose(bf16[4,32,1024,1024]{3,2,1,0} tmp_7), dimensions={0,1,3,2}\n tmp_9 = bf16[128,1024,1024]{2,1,0} bitcast(bf16[4,32,1024,1024]{3,2,1,0} tmp_8)\n tmp_10 = bf16[1,4,32,1024,1024]{4,3,2,1,0} parameter(1)\n tmp_11 = bf16[128,1024,1024]{2,1,0} bitcast(bf16[1,4,32,1024,1024]{4,3,2,1,0} tmp_10)\n tmp_12 = bf16[128,1024,1024]{2,1,0} dot(bf16[128,1024,1024]{2,1,0} tmp_9, bf16[128,1024,1024]{2,1,0} tmp_11), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}\n ROOT tmp_13 = bf16[4,32,1024,1024]{3,2,1,0} bitcast(bf16[128,1024,1024]{2,1,0} tmp_12)\n}" result { gemm { @@ -12,19 +12,7 @@ results { } } results { - device: "sm_8.0 with 42298834944B RAM, 108 cores, 1410000KHz clock, 1215000KHz mem clock, 41943040B L2$" - hlo: "{\n tmp_0 = bf16[1,4,32,1024,1024]{4,3,2,1,0} parameter(0)\n tmp_1 = f32[1,4,32,1024,1024]{4,3,2,1,0} convert(bf16[1,4,32,1024,1024]{4,3,2,1,0} tmp_0)\n tmp_2 = bf16[] constant({...})\n tmp_3 = bf16[1,4,32,1024,1024]{4,3,2,1,0} broadcast(bf16[] tmp_2), dimensions={}\n tmp_4 = f32[1,4,32,1024,1024]{4,3,2,1,0} convert(bf16[1,4,32,1024,1024]{4,3,2,1,0} tmp_3)\n tmp_5 = f32[1,4,32,1024,1024]{4,3,2,1,0} multiply(f32[1,4,32,1024,1024]{4,3,2,1,0} tmp_1, f32[1,4,32,1024,1024]{4,3,2,1,0} tmp_4)\n tmp_6 = bf16[1,4,32,1024,1024]{4,3,2,1,0} convert(f32[1,4,32,1024,1024]{4,3,2,1,0} tmp_5)\n tmp_7 = bf16[4,32,1024,1024]{3,2,1,0} bitcast(bf16[1,4,32,1024,1024]{4,3,2,1,0} tmp_6)\n tmp_8 = bf16[4,32,1024,1024]{3,2,1,0} transpose(bf16[4,32,1024,1024]{3,2,1,0} tmp_7), dimensions={0,1,3,2}\n tmp_9 = bf16[128,1024,1024]{2,1,0} bitcast(bf16[4,32,1024,1024]{3,2,1,0} tmp_8)\n tmp_10 = bf16[1,4,32,1024,1024]{4,3,2,1,0} parameter(1)\n tmp_11 = bf16[128,1024,1024]{2,1,0} bitcast(bf16[1,4,32,1024,1024]{4,3,2,1,0} tmp_10)\n tmp_12 = bf16[128,1024,1024]{2,1,0} dot(bf16[128,1024,1024]{2,1,0} tmp_9, bf16[128,1024,1024]{2,1,0} tmp_11), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}\n ROOT tmp_13 = bf16[4,32,1024,1024]{3,2,1,0} bitcast(bf16[128,1024,1024]{2,1,0} tmp_12)\n}" - result { - run_time { - nanos: 1 - } - gemm { - algorithm: -1 - } - } -} -results { - device: "sm_8.0 with 42298834944B RAM, 108 cores, 1410000KHz clock, 1215000KHz mem clock, 41943040B L2$" + device: "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: 1555 GB/s, L2 cache: 40 MB" hlo: "(bf16[128,1024,1024]{2,1,0}, s8[4194304]{0}) custom-call(bf16[128,1024,1024]{2,1,0}, bf16[128,1024,1024]{2,1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[],\"gemm_backend_config\":{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"2\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[\"0\"]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"],\"algorithm\":\"ALG_UNSET\"},\"epilogue\":\"DEFAULT\",\"lhs_stride\":\"1048576\",\"rhs_stride\":\"1048576\",\"grad_x\":false,\"grad_y\":false,\"damax_output\":false},\"force_earliest_schedule\":false}" result { run_time { @@ -36,7 +24,7 @@ results { } } results { - device: "sm_9.0 with 84942979072B RAM, 132 cores, 1980000KHz clock, 2619000KHz mem clock, 52428800B L2$" + device: "CUDA: 9.0, Cores: 132, GPU clock: 1.98 GHz, Memory bandwidth: 3352 GB/s, L2 cache: 50 MB" hlo: "{\n tmp_0 = bf16[1,4,32,1024,1024]{4,3,2,1,0} parameter(0)\n tmp_1 = bf16[] constant({...})\n tmp_2 = bf16[1,4,32,1024,1024]{4,3,2,1,0} broadcast(bf16[] tmp_1), dimensions={}\n tmp_3 = bf16[1,4,32,1024,1024]{4,3,2,1,0} multiply(bf16[1,4,32,1024,1024]{4,3,2,1,0} tmp_0, bf16[1,4,32,1024,1024]{4,3,2,1,0} tmp_2)\n tmp_4 = bf16[4,32,1024,1024]{3,2,1,0} bitcast(bf16[1,4,32,1024,1024]{4,3,2,1,0} tmp_3)\n tmp_5 = bf16[4,32,1024,1024]{3,2,1,0} transpose(bf16[4,32,1024,1024]{3,2,1,0} tmp_4), dimensions={0,1,3,2}\n tmp_6 = bf16[128,1024,1024]{2,1,0} bitcast(bf16[4,32,1024,1024]{3,2,1,0} tmp_5)\n tmp_7 = bf16[1,4,32,1024,1024]{4,3,2,1,0} parameter(1)\n tmp_8 = bf16[128,1024,1024]{2,1,0} bitcast(bf16[1,4,32,1024,1024]{4,3,2,1,0} tmp_7)\n tmp_9 = bf16[128,1024,1024]{2,1,0} dot(bf16[128,1024,1024]{2,1,0} tmp_6, bf16[128,1024,1024]{2,1,0} tmp_8), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}\n ROOT tmp_10 = bf16[4,32,1024,1024]{3,2,1,0} bitcast(bf16[128,1024,1024]{2,1,0} tmp_9)\n}" result { gemm { @@ -48,7 +36,7 @@ results { } } results { - device: "sm_9.0 with 84942979072B RAM, 132 cores, 1980000KHz clock, 2619000KHz mem clock, 52428800B L2$" + device: "CUDA: 9.0, Cores: 132, GPU clock: 1.98 GHz, Memory bandwidth: 3352 GB/s, L2 cache: 50 MB" hlo: "(bf16[128,1024,1024]{2,1,0}, s8[33554432]{0}) custom-call(bf16[128,1024,1024]{2,1,0}, bf16[128,1024,1024]{2,1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[],\"gemm_backend_config\":{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"2\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[\"0\"]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"],\"algorithm\":\"ALG_UNSET\"},\"epilogue\":\"DEFAULT\",\"lhs_stride\":\"1048576\",\"rhs_stride\":\"1048576\",\"grad_x\":false,\"grad_y\":false,\"damax_output\":false},\"force_earliest_schedule\":false}" result { gemm { diff --git a/xla/service/gpu/tests/test_autotune_cache.textproto b/xla/service/gpu/tests/test_autotune_cache.textproto index b20a9d20ece50a..a1936e3b6644b1 100644 --- a/xla/service/gpu/tests/test_autotune_cache.textproto +++ b/xla/service/gpu/tests/test_autotune_cache.textproto @@ -14,7 +14,7 @@ version: 3 results { - device: "sm_8.0 with 42331013120B RAM, 108 cores, 1410000KHz clock, 1215000KHz mem clock, 41943040B L2$" + device: "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: 1555 GB/s, L2 cache: 40 MB" hlo: "{\n tmp_0 = f16[1,16,17,3]{3,2,1,0} parameter(0)\n tmp_1 = f16[16,51]{1,0} bitcast(f16[1,16,17,3]{3,2,1,0} tmp_0)\n tmp_2 = s8[16,17,3]{2,1,0} parameter(1)\n tmp_3 = s8[51,16]{0,1} bitcast(s8[16,17,3]{2,1,0} tmp_2)\n tmp_4 = f16[51,16]{0,1} convert(s8[51,16]{0,1} tmp_3)\n tmp_5 = f16[16,16]{1,0} dot(f16[16,51]{1,0} tmp_1, f16[51,16]{0,1} tmp_4), lhs_contracting_dims={1}, rhs_contracting_dims={0}\n ROOT tmp_6 = f16[1,16,16]{2,1,0} bitcast(f16[16,16]{1,0} tmp_5)\n}" result { run_time { diff --git a/xla/service/xla_aot_compile_test_autotune_results.prototxt b/xla/service/xla_aot_compile_test_autotune_results.prototxt index 592ea2a9e185fb..1901cf2eecebd2 100644 --- a/xla/service/xla_aot_compile_test_autotune_results.prototxt +++ b/xla/service/xla_aot_compile_test_autotune_results.prototxt @@ -14,7 +14,7 @@ version: 3 results { - device: "sm_6.0 with 17071734784B RAM, 56 cores, 1480500KHz clock, 715000KHz mem clock, 4194304B L2$" + device: "CUDA: 6.0, Cores: 56, GPU clock: 1.4805 GHz, Memory bandwidth: 732 GB/s, L2 cache: 4 MB" hlo: "(f32[3,3]{1,0}, s8[72]{0}) custom-call(f32[3,3]{1,0}, f32[3,3]{1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[],\"gemm_backend_config\":{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"],\"algorithm\":\"ALG_UNSET\"},\"epilogue\":\"DEFAULT\",\"lhs_stride\":\"9\",\"rhs_stride\":\"9\",\"grad_x\":false,\"grad_y\":false},\"force_earliest_schedule\":false}" result { gemm { @@ -23,7 +23,7 @@ results { } } results { - device: "sm_6.0 with 17071734784B RAM, 56 cores, 1480500KHz clock, 715000KHz mem clock, 4194304B L2$" + device: "CUDA: 6.0, Cores: 56, GPU clock: 1.4805 GHz, Memory bandwidth: 732 GB/s, L2 cache: 4 MB" hlo: "(f32[1,1,2,3]{3,2,1,0}, u8[0]{0}) custom-call(f32[1,2,4,4]{3,2,1,0}, f32[1,2,3,2]{3,2,1,0}), window={size=3x2}, dim_labels=bf01_oi01->bf01, custom_call_target=\"__cudnn$convForward\", backend_config={\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[],\"cudnn_conv_backend_config\":{\"activation_mode\":\"kNone\",\"conv_result_scale\":1,\"side_input_scale\":0,\"leakyrelu_alpha\":0},\"force_earliest_schedule\":false}" result { run_time { From 55e262703aa6717a7817f5557a411e7796d00606 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Fri, 19 Jul 2024 04:06:35 -0700 Subject: [PATCH 011/376] Remove some debug logging. Probably accidentally submitted at some point. PiperOrigin-RevId: 653959438 --- xla/mlir_hlo/mhlo/IR/hlo_ops.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/xla/mlir_hlo/mhlo/IR/hlo_ops.cc b/xla/mlir_hlo/mhlo/IR/hlo_ops.cc index e3a27323046756..0de939289537b9 100644 --- a/xla/mlir_hlo/mhlo/IR/hlo_ops.cc +++ b/xla/mlir_hlo/mhlo/IR/hlo_ops.cc @@ -1171,9 +1171,7 @@ LogicalResult reifyGatherShape(Op* op, OpBuilder& builder, ValueRange operands, }; SmallVector shapeValues; auto getSliceDim = [&sliceSizes](int64_t index) -> Value { - llvm::errs() << "ABOUT TO FAIL\n"; auto ret = sliceSizes[index]; - llvm::errs() << "DID NOT FAIL\n"; return ret; }; hlo::reifyGatherDimSizes(resultRank, getStartIndicesDim, getSliceDim, From 2c1812d5958aa285e69fa0e54502eb103d4374eb Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Fri, 19 Jul 2024 04:13:22 -0700 Subject: [PATCH 012/376] PR #14968: [GPU] Enable sharding of autotuning by default. Imported from GitHub PR https://github.com/openxla/xla/pull/14968 Requires https://github.com/openxla/xla/pull/14881 @PatriosTheGreat Copybara import of the project: -- 94562fa1e73d8031832c0c0ed78b064cc7248aea by Ilia Sergachev : [GPU] Enable sharding of autotuning by default. Merging this change closes #14968 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/14968 from openxla:enable_sharded_autotuning 94562fa1e73d8031832c0c0ed78b064cc7248aea PiperOrigin-RevId: 653961038 --- xla/debug_options_flags.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc index 05a8c216e613bb..6046ff64ed6e3a 100644 --- a/xla/debug_options_flags.cc +++ b/xla/debug_options_flags.cc @@ -273,7 +273,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_use_shardy(false); - opts.set_xla_gpu_shard_autotuning(false); + opts.set_xla_gpu_shard_autotuning(true); opts.set_xla_syntax_sugar_async_ops(false); From 78a16ba2fa33a13e44e2162c3dbe065150168915 Mon Sep 17 00:00:00 2001 From: Patrick Toulme <135739773+ptoulme-aws@users.noreply.github.com> Date: Fri, 19 Jul 2024 04:47:46 -0700 Subject: [PATCH 013/376] PR #15002: Add unique channel id enforcer pass Imported from GitHub PR https://github.com/openxla/xla/pull/15002 We have found it is not guaranteed after all transformations, partitioning, while loop unrolling etc that all channel ids will be unique. Rather than debug this throughout XLA it is simpler to just add a pass that mandates unique channel ids, and changes channel ids to make them unique. Issue: #14600 Copybara import of the project: -- 476473182f6211b298b6d88d6a06bb128b7978ed by ptoulme-aws : Add unique channel id enforcer pass Merging this change closes #15002 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/15002 from ptoulme-aws:unique_channel_id_new 476473182f6211b298b6d88d6a06bb128b7978ed PiperOrigin-RevId: 653968079 --- xla/service/BUILD | 23 ++++ xla/service/gpu/BUILD | 1 + xla/service/gpu/gpu_compiler.cc | 2 + xla/service/unique_channel_id_enforcer.cc | 56 +++++++++ xla/service/unique_channel_id_enforcer.h | 43 +++++++ .../unique_channel_id_enforcer_test.cc | 108 ++++++++++++++++++ 6 files changed, 233 insertions(+) create mode 100644 xla/service/unique_channel_id_enforcer.cc create mode 100644 xla/service/unique_channel_id_enforcer.h create mode 100644 xla/service/unique_channel_id_enforcer_test.cc diff --git a/xla/service/BUILD b/xla/service/BUILD index f154fbf2e98f17..5e302043a70f7c 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -6194,6 +6194,29 @@ xla_cc_test( ], ) +cc_library( + name = "unique_channel_id_enforcer", + srcs = ["unique_channel_id_enforcer.cc"], + hdrs = ["unique_channel_id_enforcer.h"], + deps = [ + ":hlo_pass", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "@com_google_absl//absl/status:statusor", + ], +) + +xla_cc_test( + name = "unique_channel_id_enforcer_test", + srcs = ["unique_channel_id_enforcer_test.cc"], + deps = [ + ":hlo_parser", + ":unique_channel_id_enforcer", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + ], +) + cc_library( name = "root_instruction_sinker", srcs = ["root_instruction_sinker.cc"], diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 6c9b9f34b2ffd6..dfff7d893fbf91 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -3262,6 +3262,7 @@ cc_library( "//xla/service:topk_rewriter", "//xla/service:transpose_folding", "//xla/service:tuple_simplifier", + "//xla/service:unique_channel_id_enforcer", "//xla/service:while_loop_all_reduce_code_motion", "//xla/service:while_loop_constant_sinking", "//xla/service:while_loop_simplifier", diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index 550476b74201bf..802f21af0bf3a4 100644 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -220,6 +220,7 @@ limitations under the License. #include "xla/service/topk_rewriter.h" #include "xla/service/transpose_folding.h" #include "xla/service/tuple_simplifier.h" +#include "xla/service/unique_channel_id_enforcer.h" #include "xla/service/while_loop_all_reduce_code_motion.h" #include "xla/service/while_loop_constant_sinking.h" #include "xla/service/while_loop_simplifier.h" @@ -2347,6 +2348,7 @@ absl::Status GpuCompiler::RunPreSchedulingPasses( HloModule* module, se::StreamExecutor* stream_exec) { HloPassPipeline pipeline("pre-scheduling-passes"); pipeline.AddPass(); + pipeline.AddPass(); return pipeline.Run(module).status(); } diff --git a/xla/service/unique_channel_id_enforcer.cc b/xla/service/unique_channel_id_enforcer.cc new file mode 100644 index 00000000000000..4762961fa07ea0 --- /dev/null +++ b/xla/service/unique_channel_id_enforcer.cc @@ -0,0 +1,56 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/unique_channel_id_enforcer.h" + +#include "absl/status/statusor.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/utils/hlo_query.h" + +namespace xla { + +absl::StatusOr UniqueChannelIdEnforcer::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + absl::flat_hash_set> used_channel_ids; + auto next_channel_id = hlo_query::NextChannelId(*module); + bool changed = false; + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (!hlo_query::IsCollectiveCommunicationOp(instruction->opcode())) + continue; + auto channel_id = instruction->channel_id(); + if (used_channel_ids.contains(channel_id)) { + if (assert_unique_channel_ids_) { + LOG(ERROR) << "Duplicate channel ID " << channel_id.value_or(-1) + << " found on instruction: " << instruction->ToString(); + return absl::InternalError(absl::StrFormat( + "Duplicate channel ID %d found on instruction: %s", + channel_id.value_or(-1), instruction->ToString())); + } + instruction->set_channel_id(next_channel_id); + used_channel_ids.insert(next_channel_id); + next_channel_id++; + changed = true; + } else { + used_channel_ids.insert(channel_id); + } + } + } + + return changed; +} + +} // namespace xla diff --git a/xla/service/unique_channel_id_enforcer.h b/xla/service/unique_channel_id_enforcer.h new file mode 100644 index 00000000000000..e64d49a40858c9 --- /dev/null +++ b/xla/service/unique_channel_id_enforcer.h @@ -0,0 +1,43 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_UNIQUE_CHANNEL_ID_ENFORCER_H_ +#define XLA_SERVICE_UNIQUE_CHANNEL_ID_ENFORCER_H_ + +#include "xla/service/hlo_pass_interface.h" + +namespace xla { +// A pass which enforces that every collective +// must have a unique channel id. +class UniqueChannelIdEnforcer : public HloModulePass { + public: + explicit UniqueChannelIdEnforcer(bool assert_unique_channel_ids = false) + : assert_unique_channel_ids_(assert_unique_channel_ids) {} + + absl::string_view name() const override { + return "unique-channel-id-enforcer"; + } + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + bool assert_unique_channel_ids_; +}; + +} // namespace xla + +#endif // XLA_SERVICE_UNIQUE_CHANNEL_ID_ENFORCER_H_ diff --git a/xla/service/unique_channel_id_enforcer_test.cc b/xla/service/unique_channel_id_enforcer_test.cc new file mode 100644 index 00000000000000..ff2ae49b8fcc24 --- /dev/null +++ b/xla/service/unique_channel_id_enforcer_test.cc @@ -0,0 +1,108 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/unique_channel_id_enforcer.h" + +#include "xla/service/hlo_parser.h" +#include "xla/tests/hlo_test_base.h" + +namespace xla { +namespace { + +using UniqueChannelIdEnforcerTest = HloTestBase; + +TEST_F(UniqueChannelIdEnforcerTest, EnsureUniqueChannelIdsAllGather) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + param0 = f32[8] parameter(0) + param1 = f32[8] parameter(1) + allgather0 = f32[32] all-gather(param0), channel_id=1, replica_groups={}, dimensions={0} + allgather1 = f32[32] all-gather(param1), channel_id=1, replica_groups={}, dimensions={0} + ROOT tuple = (f32[32], f32[32]) tuple(allgather0, allgather1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + UniqueChannelIdEnforcer enforcer; + TF_ASSERT_OK_AND_ASSIGN(bool changed, enforcer.Run(module.get())); + EXPECT_TRUE(changed); + + // Verify that channel IDs are unique for all-gather ops + std::optional all_gather1_channel_id; + std::optional all_gather2_channel_id; + + for (HloInstruction* inst : module->entry_computation()->instructions()) { + if (inst->opcode() == HloOpcode::kAllGather) { + if (!all_gather1_channel_id.has_value()) { + all_gather1_channel_id = inst->channel_id(); + } else { + all_gather2_channel_id = inst->channel_id(); + } + } + } + + ASSERT_TRUE(all_gather1_channel_id.has_value()); + ASSERT_TRUE(all_gather2_channel_id.has_value()); + EXPECT_NE(all_gather1_channel_id.value(), all_gather2_channel_id.value()); +} + +TEST_F(UniqueChannelIdEnforcerTest, ChannelIdsAlreadyUnique) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + param0 = f32[8] parameter(0) + param1 = f32[8] parameter(1) + allgather0 = f32[32] all-gather(param0), channel_id=1, replica_groups={}, dimensions={0} + allgather1 = f32[32] all-gather(param1), channel_id=2, replica_groups={}, dimensions={0} + ROOT tuple = (f32[32], f32[32]) tuple(allgather0, allgather1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + UniqueChannelIdEnforcer enforcer; + TF_ASSERT_OK_AND_ASSIGN(bool changed, enforcer.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(UniqueChannelIdEnforcerTest, DuplicateChannelIdsAssertTrue) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY entry { + param0 = f32[8] parameter(0) + param1 = f32[8] parameter(1) + allgather0 = f32[32] all-gather(param0), channel_id=1, replica_groups={}, dimensions={0} + allgather1 = f32[32] all-gather(param1), channel_id=1, replica_groups={}, dimensions={0} + ROOT tuple = (f32[32], f32[32]) tuple(allgather0, allgather1) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + UniqueChannelIdEnforcer enforcer(/*assert_unique_channel_ids=*/true); + auto status_or_changed = enforcer.Run(module.get()); + + EXPECT_FALSE(status_or_changed.ok()); +} + +} // namespace +} // namespace xla From 4447e8de16ed160da8867065a5ba2548ab1b88a0 Mon Sep 17 00:00:00 2001 From: Goran Flegar Date: Fri, 19 Jul 2024 04:51:29 -0700 Subject: [PATCH 014/376] Verify that wgmma is used for memory bound shapes PiperOrigin-RevId: 653968803 --- ...riton_fusion_emitter_device_legacy_test.cc | 32 ++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc index 9701fbe37f5e76..3eec1c4e934946 100644 --- a/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc +++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc @@ -304,7 +304,6 @@ CHECK: } )")); } - TEST_F(TritonTest, PredParametersAreTruncatedToI1) { const std::string kHloText = R"( HloModule m @@ -4774,6 +4773,37 @@ CHECK: wgmma )"); } +TEST_F(TritonGemmTest, WgmmaIsUsedForMemBoundShape) { + if (GetCudaComputeCapability().major != se::CudaComputeCapability::HOPPER) { + GTEST_SKIP() << "wgmma instruction is only available on Hopper"; + } + const std::string hlo_text = R"( +gemm_fusion_dot { + p0 = s8[128,128]{1,0} parameter(0) + p1 = bf16[128,16]{1,0} parameter(1) + convert = bf16[128,128]{1,0} convert(p0) + ROOT %dot = bf16[128,16]{1,0} dot(convert, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = s8[128,128]{1,0} parameter(0) + p1 = bf16[128,16]{1,0} parameter(1) + ROOT triton_gemm_fusion_dot = bf16[128,16]{1,0} fusion(p0, p1), kind=kCustom, + calls=gemm_fusion_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":128,"block_n":16,"block_k":16, + "split_k":1,"num_stages":1,"num_warps":4, + "num_ctas":1}}} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr verified_module, + ParseAndReturnVerifiedModule(hlo_text)); + CompileAndOptionallyVerifyPtx(std::move(verified_module), R"( +CHECK: wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 +)"); +} + // Test presence of default matmul config information // when gemm autotuner is not present in pipeline, // (which is currently the case on rocm). From b5ee0103cc35d5b328548643a9c86c4568ab18d5 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Fri, 19 Jul 2024 05:35:48 -0700 Subject: [PATCH 015/376] Update file visibility in BUILD files. PiperOrigin-RevId: 653978189 --- third_party/shardy/BUILD | 2 ++ 1 file changed, 2 insertions(+) diff --git a/third_party/shardy/BUILD b/third_party/shardy/BUILD index 3b946e563d4e30..ea1ecdb548c1f4 100644 --- a/third_party/shardy/BUILD +++ b/third_party/shardy/BUILD @@ -1,3 +1,5 @@ # Necessary for bazel to recognize this as a package. # copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) + +exports_files(srcs = ["workspace.bzl"]) From e589b4a985d44ab4f642b6adc978c269129f3ab1 Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Fri, 19 Jul 2024 05:57:21 -0700 Subject: [PATCH 016/376] [XLA:GPU] Pass operand to the construction of TiledHloInstruction. We use `tile_offset_indexing` to distinguish between tiles of the same instruction. Composing and simplifying indexing maps is expensive. For instruction inside the fusion that are not load/store, comparing `operand` pointers is a cheaper way to achieve the same effect. This change is a preparation to compute `tile_offset_indexing` only when necessary. PiperOrigin-RevId: 653982200 --- .../gpu/model/symbolic_tile_analysis.cc | 23 ++++++----- .../gpu/model/tiled_hlo_instruction.cc | 13 ++++--- xla/service/gpu/model/tiled_hlo_instruction.h | 39 +++++++++++-------- .../gpu/model/tiled_hlo_instruction_test.cc | 18 ++++----- 4 files changed, 50 insertions(+), 43 deletions(-) diff --git a/xla/service/gpu/model/symbolic_tile_analysis.cc b/xla/service/gpu/model/symbolic_tile_analysis.cc index 355d2bcaf1c43a..d72590debe3d2e 100644 --- a/xla/service/gpu/model/symbolic_tile_analysis.cc +++ b/xla/service/gpu/model/symbolic_tile_analysis.cc @@ -458,23 +458,22 @@ SymbolicTileAnalysis::ComputeTiledHloInstructions( *symbolic_tiled_hlo, output_tiling_info.output_tile_offset_indexing, context_)); - TF_ASSIGN_OR_RETURN( - std::unique_ptr tiled_hlo_holder, - TiledHloInstruction::Create( - symbolic_tiled_hlo->hlo(), std::move(tile_sizes), - std::move(tile_strides), std::move(tile_offset_indexing))); + llvm::SmallVector operands; + for (const SymbolicTiledHloInstruction* operand : + symbolic_tiled_hlo->operands()) { + operands.push_back(symbolic_to_tiled_hlo_map.at(operand)); + } + + TF_ASSIGN_OR_RETURN(std::unique_ptr tiled_hlo_holder, + TiledHloInstruction::Create( + symbolic_tiled_hlo->hlo(), std::move(operands), + std::move(tile_sizes), std::move(tile_strides), + std::move(tile_offset_indexing))); auto [tiled_hlo, inserted] = tiled_hlo_instructions_set.Insert(std::move(tiled_hlo_holder)); symbolic_to_tiled_hlo_map[symbolic_tiled_hlo.get()] = tiled_hlo; - - if (inserted) { - for (const SymbolicTiledHloInstruction* operand : - symbolic_tiled_hlo->operands()) { - tiled_hlo->AppendOperand(symbolic_to_tiled_hlo_map.at(operand)); - } - } } return TiledHloComputation::FromSortedTiledHloInstructions( tiled_hlo_instructions_set.ExtractData(), diff --git a/xla/service/gpu/model/tiled_hlo_instruction.cc b/xla/service/gpu/model/tiled_hlo_instruction.cc index 59a145369f7d01..4a2543970a1df1 100644 --- a/xla/service/gpu/model/tiled_hlo_instruction.cc +++ b/xla/service/gpu/model/tiled_hlo_instruction.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include "absl/memory/memory.h" #include "absl/status/status.h" @@ -33,10 +34,12 @@ namespace gpu { /*static*/ absl::StatusOr> -TiledHloInstruction::Create(const HloInstruction* hlo, - llvm::SmallVector tile_sizes, - llvm::SmallVector tile_strides, - IndexingMap tile_offsets_indexing) { +TiledHloInstruction::Create( + const HloInstruction* hlo, + llvm::SmallVector operands, + llvm::SmallVector tile_sizes, + llvm::SmallVector tile_strides, + IndexingMap tile_offsets_indexing) { int rank = hlo->shape().rank(); if (tile_sizes.size() != rank) { @@ -61,7 +64,7 @@ TiledHloInstruction::Create(const HloInstruction* hlo, } return absl::WrapUnique(new TiledHloInstruction( - hlo, std::move(tile_sizes), std::move(tile_strides), + hlo, std::move(operands), std::move(tile_sizes), std::move(tile_strides), std::move(tile_offsets_indexing))); } diff --git a/xla/service/gpu/model/tiled_hlo_instruction.h b/xla/service/gpu/model/tiled_hlo_instruction.h index 5978df377fc033..86c7969d06560e 100644 --- a/xla/service/gpu/model/tiled_hlo_instruction.h +++ b/xla/service/gpu/model/tiled_hlo_instruction.h @@ -48,13 +48,25 @@ class TiledHloInstruction { // * `tile_offsets_indexing` should have the number of dimensions equal to the // rank of the output tile and 0 symbols. static absl::StatusOr> Create( - const HloInstruction* hlo, llvm::SmallVector tile_sizes, + const HloInstruction* hlo, + llvm::SmallVector operands, + llvm::SmallVector tile_sizes, llvm::SmallVector tile_strides, + IndexingMap tile_offsets_indexing); // Returns the original HLO instruction. const HloInstruction* hlo() const { return hlo_; } + // Operands of the instruction in the tiled computation graph. + const TiledHloInstruction* operand(int64_t operand_id) const { + return operands_[operand_id]; + } + + const llvm::SmallVector& operands() const { + return operands_; + } + // Returns the tile sizes. The number of tile sizes is equal to the rank of // the output shape. const llvm::SmallVector& tile_sizes() const { return tile_sizes_; } @@ -73,18 +85,6 @@ class TiledHloInstruction { return tile_offsets_indexing_; } - const TiledHloInstruction* operand(int64_t operand_id) const { - return operands_[operand_id]; - } - - const std::vector& operands() const { - return operands_; - } - - void AppendOperand(TiledHloInstruction* operand) { - operands_.push_back(operand); - } - std::string ToString() const; // This allows GUnit to print TiledHloInstruction. @@ -95,10 +95,12 @@ class TiledHloInstruction { private: TiledHloInstruction(const HloInstruction* hlo, + llvm::SmallVector operands, llvm::SmallVector tile_sizes, llvm::SmallVector tile_strides, IndexingMap tile_offsets_indexing) : hlo_(hlo), + operands_(std::move(operands)), tile_sizes_(std::move(tile_sizes)), tile_strides_(std::move(tile_strides)), tile_offsets_indexing_(std::move(tile_offsets_indexing)) {} @@ -106,21 +108,22 @@ class TiledHloInstruction { // Pointer to the original HLO instruction. const HloInstruction* hlo_; + // Operands of the instruction in the tiled computation graph. + llvm::SmallVector operands_; + // Tile sizes and strides. llvm::SmallVector tile_sizes_; llvm::SmallVector tile_strides_; // Indexing map for tile offsets. IndexingMap tile_offsets_indexing_; - - // Operands of the instruction in the tiled computation graph. - std::vector operands_; }; inline bool operator==(const TiledHloInstruction& lhs, const TiledHloInstruction& rhs) { return lhs.hlo() == rhs.hlo() && lhs.tile_sizes() == rhs.tile_sizes() && lhs.tile_strides() == rhs.tile_strides() && + lhs.operands() == rhs.operands() && lhs.tile_offsets_indexing() == rhs.tile_offsets_indexing(); } @@ -133,11 +136,13 @@ template H AbslHashValue(H h, const TiledHloInstruction& tiled_hlo_instruction) { // There is no default hash implementation for llvm::SmallVector neither in // AbslHashValue nor in llvm::hash_value. We can use the available hash - // implementation for absl::Span instread. + // implementation for absl::Span instead. return H::combine( std::move(h), tiled_hlo_instruction.hlo(), absl::Span(tiled_hlo_instruction.tile_sizes()), absl::Span(tiled_hlo_instruction.tile_strides()), + absl::Span( + tiled_hlo_instruction.operands()), tiled_hlo_instruction.tile_offsets_indexing()); } diff --git a/xla/service/gpu/model/tiled_hlo_instruction_test.cc b/xla/service/gpu/model/tiled_hlo_instruction_test.cc index d9188b36f3a5ec..d49a666a89e462 100644 --- a/xla/service/gpu/model/tiled_hlo_instruction_test.cc +++ b/xla/service/gpu/model/tiled_hlo_instruction_test.cc @@ -49,16 +49,16 @@ TEST_F(TiledHloInstructionTest, TileSizesAndStridesShouldMatchHloShapeRank) { /*dim_upper_bounds=*/{8}, /*symbol_upper_bounds=*/{}); - EXPECT_THAT(TiledHloInstruction::Create(hlo.get(), /*tile_sizes=*/{16}, - /*tile_strides=*/{1, 1}, - block_id_to_tile_offsets_indexing) + EXPECT_THAT(TiledHloInstruction::Create( + hlo.get(), /*operands=*/{}, /*tile_sizes=*/{16}, + /*tile_strides=*/{1, 1}, block_id_to_tile_offsets_indexing) .status() .message(), HasSubstr("Number of tile sizes must be equal to the rank")); - EXPECT_THAT(TiledHloInstruction::Create(hlo.get(), /*tile_sizes=*/{16, 16}, - /*tile_strides=*/{1, 1, 1}, - block_id_to_tile_offsets_indexing) + EXPECT_THAT(TiledHloInstruction::Create( + hlo.get(), /*operands=*/{}, /*tile_sizes=*/{16, 16}, + /*tile_strides=*/{1, 1, 1}, block_id_to_tile_offsets_indexing) .status() .message(), HasSubstr("Number of tile strides must be equal to the rank")); @@ -76,9 +76,9 @@ TEST_F(TiledHloInstructionTest, /*symbol_upper_bounds=*/{}); EXPECT_THAT( - TiledHloInstruction::Create(hlo.get(), /*tile_sizes=*/{16, 16}, - /*tile_strides=*/{1, 1}, - tile_offsets_indexing1) + TiledHloInstruction::Create( + hlo.get(), /*operands=*/{}, /*tile_sizes=*/{16, 16}, + /*tile_strides=*/{1, 1}, tile_offsets_indexing1) .status() .message(), HasSubstr( From 97d2bb8710f93e4fa73aaa565808910692722f9a Mon Sep 17 00:00:00 2001 From: Philipp Hack Date: Fri, 19 Jul 2024 06:08:46 -0700 Subject: [PATCH 017/376] PR #15050: Quantized Collectives Imported from GitHub PR https://github.com/openxla/xla/pull/15050 Introduces a pass that can reduce the amount of data transferred in all-gather, all-to-all, collective-broadcast and collective-permute ops by exchanging the collective with a subsequent quantization or conversion to a narrower type. Copybara import of the project: -- 658710e6d4b518fc2efc860279ef31cf1cf9d8ea by Philipp Hack : Adds a pass that exchanges collectives with subsequent quantizations or narrowing type conversions. Merging this change closes #15050 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/15050 from philipphack:u_collective_quant_xla 658710e6d4b518fc2efc860279ef31cf1cf9d8ea PiperOrigin-RevId: 653984852 --- xla/service/BUILD | 33 +++ xla/service/collective_quantizer.cc | 191 +++++++++++++++++ xla/service/collective_quantizer.h | 49 +++++ xla/service/collective_quantizer_test.cc | 260 +++++++++++++++++++++++ xla/service/gpu/BUILD | 1 + xla/service/gpu/gpu_compiler.cc | 5 + xla/tests/collective_ops_e2e_test.cc | 29 +++ 7 files changed, 568 insertions(+) create mode 100644 xla/service/collective_quantizer.cc create mode 100644 xla/service/collective_quantizer.h create mode 100644 xla/service/collective_quantizer_test.cc diff --git a/xla/service/BUILD b/xla/service/BUILD index 5e302043a70f7c..38bb3ff4cc74b9 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -640,6 +640,39 @@ xla_cc_test( ], ) +cc_library( + name = "collective_quantizer", + srcs = ["collective_quantizer.cc"], + hdrs = ["collective_quantizer.h"], + deps = [ + ":hlo_pass", + ":pattern_matcher", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + ], +) + +xla_cc_test( + name = "collective_quantizer_test", + srcs = ["collective_quantizer_test.cc"], + deps = [ + ":collective_quantizer", + ":hlo_verifier", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_matchers", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:statusor", + ], +) + cc_library( name = "dump", srcs = ["dump.cc"], diff --git a/xla/service/collective_quantizer.cc b/xla/service/collective_quantizer.cc new file mode 100644 index 00000000000000..3a6ff8b76336c2 --- /dev/null +++ b/xla/service/collective_quantizer.cc @@ -0,0 +1,191 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/collective_quantizer.h" + +#include "xla/service/pattern_matcher.h" +#include "xla/shape_util.h" + +namespace xla { +namespace { + +namespace m = match; + +// Matches a broadcast of a scalar operand. +template +auto ScalarBroadcast(Args... args) { + return m::Broadcast(args...).WithPredicate([](const HloInstruction* instr) { + return ShapeUtil::IsScalar(instr->operand(0)->shape()); + }); +} + +// Matches a bitcast that preserves the element type of the operand. +auto BitcastPreservesElementType() { + return m::Bitcast().WithPredicate([](const HloInstruction* instr) { + return ShapeUtil::SameElementType(instr->shape(), + instr->operand(0)->shape()); + }); +} + +// Matches a type conversion to a type with a smaller byte size than that of the +// operand. +auto ConvertToNarrowerType() { + auto converts_to_narrower_type = [](const HloInstruction* instr) -> bool { + return ShapeUtil::ByteSizeOfPrimitiveType(instr->shape().element_type()) < + ShapeUtil::ByteSizeOfPrimitiveType( + instr->operand(0)->shape().element_type()); + }; + return m::Convert().WithPredicate(converts_to_narrower_type); +} + +// Returns true iff instr describes a quantization, i.e. a multiplication or +// division by a broadcasted scalar followed by a clamp and a type conversion, +// or a plain type conversion to a narrower type. Unary bitcast, copy, reshape +// or slice ops with one user may precede the quantization or type conversion. +bool IsSupportedQuantization(HloInstruction* instr, HloInstruction** convert, + HloInstruction** binary, HloInstruction** clamp, + HloInstruction** scale_bcast, + std::vector& unary_ops) { + std::vector ops; + while (instr->user_count() <= 1) { + if (Match(instr, m::AnyOf( + BitcastPreservesElementType(), m::Copy(), m::Reshape(), + m::Slice(), m::Multiply(), m::Divide(), m::Clamp()))) { + if (instr->user_count() > 0) { + ops.emplace_back(instr); + instr = instr->users()[0]; + continue; + } + break; + } + + if (Match(instr, ConvertToNarrowerType())) { + ops.emplace_back(instr); + break; + } + VLOG(5) << "Unsupported instruction."; + return false; + } + + // In the quantization case, the type conversion is preceded by a + // multiplication or division by a broadcasted scalar and a clamp instruction. + if (ops.size() > 2 && + (Match(ops.back(), + m::Convert(convert, m::Clamp(clamp, ScalarBroadcast(), + m::MultiplyAnyOrder( + binary, m::Op(), + ScalarBroadcast(scale_bcast)), + ScalarBroadcast()))) || + Match( + ops.back(), + m::Convert(convert, m::Clamp(clamp, ScalarBroadcast(), + m::Divide(binary, m::Op(), + ScalarBroadcast(scale_bcast)), + ScalarBroadcast()))))) { + unary_ops = {ops.begin(), ops.end() - 3}; + } else if (ops.size() > 0 && Match(ops.back(), m::Convert(convert))) { + unary_ops = {ops.begin(), ops.end() - 1}; + } else { + VLOG(5) << "Did not find type conversion or quantization pattern."; + return false; + } + + // The collected unary ops between collective and quantization/type conversion + // may only include bitcast, copy, reshape and slice instructions. + for (HloInstruction* unary_op : unary_ops) { + if (!Match(unary_op, m::AnyOf(m::Bitcast(), m::Copy(), + m::Reshape(), m::Slice()))) { + VLOG(5) << "Unexpected instruction in unary ops."; + return false; + } + } + return true; +} + +bool IsSupportedCollective(HloInstruction* instr) { + return instr->opcode() == HloOpcode::kAllGather || + instr->opcode() == HloOpcode::kAllToAll || + instr->opcode() == HloOpcode::kCollectiveBroadcast || + instr->opcode() == HloOpcode::kCollectivePermute; +} + +} // namespace + +absl::StatusOr CollectiveQuantizer::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + bool changed = false; + + for (HloComputation* comp : module->MakeComputationPostOrder()) { + for (HloInstruction* instr : comp->MakeInstructionPostOrder()) { + HloInstruction *binary = nullptr, *clamp, *convert, *scale_bcast; + std::vector unary_ops; + if (instr->user_count() == 1 && IsSupportedCollective(instr) && + IsSupportedQuantization(instr->users()[0], &convert, &binary, &clamp, + &scale_bcast, unary_ops)) { + HloInstruction* coll_operand = instr->mutable_operand(0); + HloInstruction *new_binary, *new_clamp; + // When there is a quantization, insert the scale and clamp ops. + if (binary) { + HloInstruction* new_scale_bcast = comp->AddInstruction( + scale_bcast->CloneWithNewShape(coll_operand->shape())); + new_binary = comp->AddInstruction(binary->CloneWithNewOperands( + coll_operand->shape(), {coll_operand, new_scale_bcast})); + HloInstruction* new_clamp_lower = comp->AddInstruction( + clamp->operand(0)->CloneWithNewShape(coll_operand->shape())); + HloInstruction* new_clamp_upper = comp->AddInstruction( + clamp->operand(2)->CloneWithNewShape(coll_operand->shape())); + new_clamp = comp->AddInstruction(clamp->CloneWithNewOperands( + coll_operand->shape(), + {new_clamp_lower, new_binary, new_clamp_upper})); + } + + // Move the collective past the conversion to the narrow type. + Shape new_convert_shape = ShapeUtil::ChangeElementType( + instr->operand(0)->shape(), convert->shape().element_type()); + HloInstruction* new_convert = + comp->AddInstruction(convert->CloneWithNewOperands( + new_convert_shape, {binary ? new_clamp : coll_operand})); + Shape new_collective_shape = ShapeUtil::ChangeElementType( + instr->shape(), convert->shape().element_type()); + HloInstruction* new_collective = comp->AddInstruction( + instr->CloneWithNewOperands(new_collective_shape, {new_convert})); + + // Sequentially apply the collected unary ops to the output of the + // quantized collective. + auto shift_unary_ops = [comp, &unary_ops](HloInstruction** x) -> void { + for (HloInstruction* unary_op : unary_ops) { + *x = comp->AddInstruction(unary_op->CloneWithNewOperands( + ShapeUtil::MakeShapeWithDenseLayout( + (*x)->shape().element_type(), + unary_op->shape().dimensions(), + unary_op->shape().layout().minor_to_major()), + {*x})); + } + }; + + shift_unary_ops(&new_collective); + TF_RETURN_IF_ERROR(convert->ReplaceAllUsesWith(new_collective)); + + changed = true; + VLOG(5) << "Quantized collective " << new_collective->ToShortString(); + } + } + } + + return changed; +} + +} // namespace xla diff --git a/xla/service/collective_quantizer.h b/xla/service/collective_quantizer.h new file mode 100644 index 00000000000000..642a8c0f238d83 --- /dev/null +++ b/xla/service/collective_quantizer.h @@ -0,0 +1,49 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_COLLECTIVE_QUANTIZER_H_ +#define XLA_SERVICE_COLLECTIVE_QUANTIZER_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_pass_interface.h" + +namespace xla { + +// Reduces the amount of data transferred in all-gather, all-to-all, +// collective-broadcast and collective-permute ops by exchanging the collectives +// with subsequent quantizations or type conversions to a narrower type. When +// present, unary ops such as bitcasts, copies, reshapes and slices between +// collective and quantization/type conversion are shifted, i.e. transforms +// +// collective --> unary --> quantization/type conversion +// +// into +// +// quantization/type conversion --> collective --> unary. +class CollectiveQuantizer : public HloModulePass { + public: + absl::string_view name() const override { return "collective-quantizer"; } + + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla + +#endif // XLA_SERVICE_COLLECTIVE_QUANTIZER_H_ diff --git a/xla/service/collective_quantizer_test.cc b/xla/service/collective_quantizer_test.cc new file mode 100644 index 00000000000000..a095e3ef4e19a1 --- /dev/null +++ b/xla/service/collective_quantizer_test.cc @@ -0,0 +1,260 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/collective_quantizer.h" + +#include + +#include +#include +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/utils/hlo_matchers.h" +#include "xla/service/hlo_verifier.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; + +class CollectiveQuantizerTest : public HloTestBase { + public: + absl::StatusOr RunCollectiveQuantizer(HloModule* module) { + CollectiveQuantizer collective_quantizer; + return collective_quantizer.Run(module, {}); + } +}; + +TEST_F(CollectiveQuantizerTest, AllGatherConvert) { + absl::string_view hlo_string = R"( + HloModule module + ENTRY entry { + param = bf16[8,4,8,128] parameter(0) + all-gather = bf16[8,32,8,128] all-gather(param), dimensions={1}, replica_groups={{0,1,2,3,4,5,6,7}}, channel_id=1 + ROOT convert = f8e4m3fn[8,32,8,128] convert(all-gather) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunCollectiveQuantizer(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::AllGather(op::Convert(op::Parameter()))); + HloInstruction* all_gather = module->entry_computation()->root_instruction(); + EXPECT_THAT(all_gather->shape().element_type(), F8E4M3FN); +} + +TEST_F(CollectiveQuantizerTest, AllGatherConvertUnary) { + absl::string_view hlo_string = R"( + HloModule module + ENTRY entry { + param = bf16[8,4,8,128] parameter(0) + all-gather = bf16[8,32,8,128] all-gather(param), dimensions={1}, replica_groups={{0,1,2,3,4,5,6,7}}, channel_id=1 + reshape = bf16[8,32,1024] reshape(all-gather) + slice = bf16[8,32,512] slice(reshape), slice={[0:8], [0:32], [256:768]} + ROOT convert = f8e4m3fn[8,32,512] convert(slice) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunCollectiveQuantizer(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Slice(op::Reshape(op::AllGather(op::Convert(op::Parameter()))))); + HloInstruction* all_gather = module->entry_computation()->root_instruction(); + EXPECT_THAT(all_gather->shape().element_type(), F8E4M3FN); +} + +TEST_F(CollectiveQuantizerTest, AllGatherQuantize) { + absl::string_view hlo_string = R"( + HloModule module + ENTRY entry { + param = bf16[8,4,8,128] parameter(0) + all-gather = bf16[8,32,8,128] all-gather(param), dimensions={1}, replica_groups={{0,1,2,3,4,5,6,7}}, channel_id=1 + scale = bf16[] parameter(1) + scale_bcast = bf16[8,32,8,128] broadcast(scale), dimensions={} + divide = bf16[8,32,8,128] divide(all-gather, scale_bcast) + clamp_lower = bf16[] constant(-448.0) + clamp_lower_bcast = bf16[8,32,8,128] broadcast(clamp_lower), dimensions={} + clamp_upper = bf16[] constant(448.0) + clamp_upper_bcast = bf16[8,32,8,128] broadcast(clamp_upper), dimensions={} + clamp = bf16[8,32,8,128] clamp(clamp_lower_bcast, divide, clamp_upper_bcast) + ROOT convert = f8e4m3fn[8,32,8,128] convert(clamp) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunCollectiveQuantizer(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::AllGather(op::Convert(op::Clamp( + op::Broadcast(), op::Divide(op::Parameter(), op::Broadcast()), + op::Broadcast())))); + HloInstruction* all_gather = module->entry_computation()->root_instruction(); + EXPECT_THAT(all_gather->shape().element_type(), F8E4M3FN); +} + +TEST_F(CollectiveQuantizerTest, AllToAllQuantize) { + absl::string_view hlo_string = R"( + HloModule module + ENTRY entry { + param = bf16[8,32,8,128] parameter(0) + all-to-all = bf16[8,32,8,128] all-to-all(param), dimensions={1}, replica_groups={{0,1,2,3,4,5,6,7}}, channel_id=1 + scale = bf16[] parameter(1) + scale_bcast = bf16[8,32,8,128] broadcast(scale), dimensions={} + divide = bf16[8,32,8,128] divide(all-to-all, scale_bcast) + clamp_lower = bf16[] constant(-448.0) + clamp_lower_bcast = bf16[8,32,8,128] broadcast(clamp_lower), dimensions={} + clamp_upper = bf16[] constant(448.0) + clamp_upper_bcast = bf16[8,32,8,128] broadcast(clamp_upper), dimensions={} + clamp = bf16[8,32,8,128] clamp(clamp_lower_bcast, divide, clamp_upper_bcast) + ROOT convert = f8e4m3fn[8,32,8,128] convert(clamp) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunCollectiveQuantizer(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::AllToAll(op::Convert(op::Clamp( + op::Broadcast(), op::Divide(op::Parameter(), op::Broadcast()), + op::Broadcast())))); + HloInstruction* all_to_all = module->entry_computation()->root_instruction(); + EXPECT_THAT(all_to_all->shape().element_type(), F8E4M3FN); +} + +TEST_F(CollectiveQuantizerTest, CollectiveBroadcastQuantize) { + absl::string_view hlo_string = R"( + HloModule module + ENTRY entry { + param = bf16[8,32,8,128] parameter(0) + collective-broadcast = bf16[8,32,8,128] collective-broadcast(param), replica_groups={{0,1,2,3,4,5,6,7}}, channel_id=1 + scale = bf16[] parameter(1) + scale_bcast = bf16[8,32,8,128] broadcast(scale), dimensions={} + divide = bf16[8,32,8,128] divide(collective-broadcast, scale_bcast) + clamp_lower = bf16[] constant(-448.0) + clamp_lower_bcast = bf16[8,32,8,128] broadcast(clamp_lower), dimensions={} + clamp_upper = bf16[] constant(448.0) + clamp_upper_bcast = bf16[8,32,8,128] broadcast(clamp_upper), dimensions={} + clamp = bf16[8,32,8,128] clamp(clamp_lower_bcast, divide, clamp_upper_bcast) + ROOT convert = f8e4m3fn[8,32,8,128] convert(clamp) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunCollectiveQuantizer(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::CollectiveBroadcast(op::Convert(op::Clamp( + op::Broadcast(), op::Divide(op::Parameter(), op::Broadcast()), + op::Broadcast())))); + HloInstruction* collective_broadcast = + module->entry_computation()->root_instruction(); + EXPECT_THAT(collective_broadcast->shape().element_type(), F8E4M3FN); +} + +TEST_F(CollectiveQuantizerTest, CollectivePermuteQuantize) { + absl::string_view hlo_string = R"( + HloModule module + ENTRY entry { + param = bf16[8,32,8,128] parameter(0) + collective-permute = bf16[8,32,8,128] collective-permute(param), source_target_pairs={{0,1},{2,3},{4,5},{6,7}}, channel_id=1 + scale = bf16[] parameter(1) + scale_bcast = bf16[8,32,8,128] broadcast(scale), dimensions={} + divide = bf16[8,32,8,128] divide(collective-permute, scale_bcast) + clamp_lower = bf16[] constant(-448.0) + clamp_lower_bcast = bf16[8,32,8,128] broadcast(clamp_lower), dimensions={} + clamp_upper = bf16[] constant(448.0) + clamp_upper_bcast = bf16[8,32,8,128] broadcast(clamp_upper), dimensions={} + clamp = bf16[8,32,8,128] clamp(clamp_lower_bcast, divide, clamp_upper_bcast) + ROOT convert = f8e4m3fn[8,32,8,128] convert(clamp) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunCollectiveQuantizer(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::CollectivePermute(op::Convert(op::Clamp( + op::Broadcast(), op::Divide(op::Parameter(), op::Broadcast()), + op::Broadcast())))); + HloInstruction* collective_permute = + module->entry_computation()->root_instruction(); + EXPECT_THAT(collective_permute->shape().element_type(), F8E4M3FN); +} + +TEST_F(CollectiveQuantizerTest, AllGatherQuantizeUnary) { + absl::string_view hlo_string = R"( + HloModule module + ENTRY entry { + param = bf16[8,4,8,128] parameter(0) + all-gather = bf16[8,32,8,128] all-gather(param), dimensions={1}, replica_groups={{0,1,2,3,4,5,6,7}}, channel_id=1 + reshape = bf16[8,32,1024] reshape(all-gather) + slice = bf16[8,32,512] slice(reshape), slice={[0:8], [0:32], [256:768]} + scale = bf16[] parameter(1) + scale_bcast = bf16[8,32,512] broadcast(scale), dimensions={} + divide = bf16[8,32,512] divide(slice, scale_bcast) + clamp_lower = bf16[] constant(-448.0) + clamp_lower_bcast = bf16[8,32,512] broadcast(clamp_lower), dimensions={} + clamp_upper = bf16[] constant(448.0) + clamp_upper_bcast = bf16[8,32,512] broadcast(clamp_upper), dimensions={} + clamp = bf16[8,32,512] clamp(clamp_lower_bcast, divide, clamp_upper_bcast) + ROOT convert = f8e4m3fn[8,32,512] convert(clamp) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunCollectiveQuantizer(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Slice(op::Reshape(op::AllGather(op::Convert(op::Clamp( + op::Broadcast(), op::Divide(op::Parameter(), op::Broadcast()), + op::Broadcast())))))); + HloInstruction* slice = module->entry_computation()->root_instruction(); + EXPECT_THAT(slice->shape().element_type(), F8E4M3FN); +} + +TEST_F(CollectiveQuantizerTest, AllGatherQuantizeMultiUser) { + absl::string_view hlo_string = R"( + HloModule module + ENTRY entry { + param = bf16[8,4,8,128] parameter(0) + all-gather = bf16[8,32,8,128] all-gather(param), dimensions={1}, replica_groups={{0,1,2,3,4,5,6,7}}, channel_id=1 + scale = bf16[] parameter(1) + scale_bcast = bf16[8,32,8,128] broadcast(scale), dimensions={} + divide = bf16[8,32,8,128] divide(all-gather, scale_bcast) + clamp_lower = bf16[] constant(-448.0) + clamp_lower_bcast = bf16[8,32,8,128] broadcast(clamp_lower), dimensions={} + clamp_upper = bf16[] constant(448.0) + clamp_upper_bcast = bf16[8,32,8,128] broadcast(clamp_upper), dimensions={} + clamp = bf16[8,32,8,128] clamp(clamp_lower_bcast, divide, clamp_upper_bcast) + add = bf16[8,32,8,128] add(divide, clamp) + ROOT convert = f8e4m3fn[8,32,8,128] convert(add) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunCollectiveQuantizer(module.get())); + EXPECT_FALSE(changed); +} + +} // namespace +} // namespace xla diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index dfff7d893fbf91..3eee391f579cc7 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -3186,6 +3186,7 @@ cc_library( "//xla/service:call_inliner", "//xla/service:collective_permute_decomposer", "//xla/service:collective_pipeliner", + "//xla/service:collective_quantizer", "//xla/service:collectives_schedule_linearizer", "//xla/service:comparison_expander", "//xla/service:compiler", diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index 802f21af0bf3a4..44aa2e8474cc51 100644 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -83,6 +83,7 @@ limitations under the License. #include "xla/service/call_inliner.h" #include "xla/service/collective_permute_decomposer.h" #include "xla/service/collective_pipeliner.h" +#include "xla/service/collective_quantizer.h" #include "xla/service/collectives_schedule_linearizer.h" #include "xla/service/comparison_expander.h" #include "xla/service/compiler.h" @@ -934,6 +935,10 @@ absl::Status RunCollectiveOptimizationPasses( // Remove dead computations left over after ar/rs promotion. collectives_pipeline.AddPass(); + collectives_pipeline.AddPass(); + // Remove dead computations after collective quantization. + collectives_pipeline.AddPass(); + // Run WhileLoopTripCountAnnotator after collective pipelining and before // layout assignment and fusion.This pass does some pattern-matching on // while bodies/conditions, and this is where the HLO is "nicest". diff --git a/xla/tests/collective_ops_e2e_test.cc b/xla/tests/collective_ops_e2e_test.cc index 7560ecb7404ec8..f1d1c78d28bb61 100644 --- a/xla/tests/collective_ops_e2e_test.cc +++ b/xla/tests/collective_ops_e2e_test.cc @@ -1074,5 +1074,34 @@ ENTRY entry { EXPECT_TRUE(executable->has_module()); } +TEST_F(CollectiveOpsTestE2E, AllToAllCollectiveQuantizer) { + absl::string_view kModuleReplicatedStr = R"( +HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f32[4,32,128]{2,1,0})->bf16[4,32,128]{2,1,0}}, num_partitions=4 +ENTRY entry { + param = f32[4,32,128]{2,1,0} parameter(0) + all-to-all = f32[4,32,128]{2,1,0} all-to-all(param), channel_id=1, replica_groups={{0,1,2,3}}, dimensions={1} + ROOT convert = bf16[4,32,128]{2,1,0} convert(all-to-all) +} +)"; + + const int64_t kNumReplicas = 1; + const int64_t kNumPartitions = 4; + + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + config.set_num_partitions(kNumPartitions); + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule(kModuleReplicatedStr, config)); + + TF_ASSERT_OK_AND_ASSIGN(auto executable, + CreateExecutable(std::move(module), + /*run_hlo_passes=*/true)); + EXPECT_TRUE(executable->has_module()); + HloInstruction* all_to_all = + FindInstruction(&executable->module(), HloOpcode::kAllToAll); + EXPECT_THAT(all_to_all, NotNull()); + EXPECT_EQ(all_to_all->shape().element_type(), BF16); +} + } // namespace } // namespace xla From 76c6ed0cfbe842e124afd025006552c405d33f4f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 19 Jul 2024 08:35:17 -0700 Subject: [PATCH 018/376] Disable test in msan to fix builds. PiperOrigin-RevId: 654016962 --- xla/service/gpu/BUILD | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 3eee391f579cc7..6d9480b4e72d0b 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -5571,6 +5571,8 @@ xla_cc_test( "//xla/tools/hlo_opt:gpu_specs/a100_sxm_80.txtpb", "//xla/tools/hlo_opt:gpu_specs/mi200.txtpb", ], + # TODO(b/354186833): Re-enable msan after fix. + tags = ["nomsan"], deps = if_cuda_is_configured([ # keep sorted ":autotuner_util", From 47f2e1d597d364d8f5e81736dc3ad562a76879fb Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Fri, 19 Jul 2024 09:51:14 -0700 Subject: [PATCH 019/376] Remove unused ROCm include PiperOrigin-RevId: 654036510 --- xla/service/gpu/gemm_algorithm_picker_test.cc | 4 ---- 1 file changed, 4 deletions(-) diff --git a/xla/service/gpu/gemm_algorithm_picker_test.cc b/xla/service/gpu/gemm_algorithm_picker_test.cc index 3017af4bef9528..e387aad44ef341 100644 --- a/xla/service/gpu/gemm_algorithm_picker_test.cc +++ b/xla/service/gpu/gemm_algorithm_picker_test.cc @@ -36,10 +36,6 @@ limitations under the License. #include "tsl/platform/test.h" #include "tsl/protobuf/dnn.pb.h" -#if TENSORFLOW_USE_ROCM -#include "rocm/rocm_config.h" -#endif - namespace xla::gpu { namespace { From 16ff35a7de31139215803a3357cd1e4e06cc1438 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Fri, 19 Jul 2024 10:03:30 -0700 Subject: [PATCH 020/376] [XLA:GPU] Always initialize all members of `DeviceDescription`. `DeviceDescription` has two constructors and it seems neither of them is guaranteed to initialize all members. E.g. the default constructor missed the `l2_cache_size_` and `fpus_per_core_` fields. This change moves all initialization directly to the field definition. This has two benefits: - if any constructor does not explicitly initialize a field, the initializer from the definition of the field will be used. - It will be less likely to forget to initialize a field when adding new fields. PiperOrigin-RevId: 654040293 --- xla/service/gpu/BUILD | 2 - xla/stream_executor/device_description.cc | 80 +++++++---------------- xla/stream_executor/device_description.h | 72 ++++++++++---------- 3 files changed, 63 insertions(+), 91 deletions(-) diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 6d9480b4e72d0b..3eee391f579cc7 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -5571,8 +5571,6 @@ xla_cc_test( "//xla/tools/hlo_opt:gpu_specs/a100_sxm_80.txtpb", "//xla/tools/hlo_opt:gpu_specs/mi200.txtpb", ], - # TODO(b/354186833): Re-enable msan after fix. - tags = ["nomsan"], deps = if_cuda_is_configured([ # keep sorted ":autotuner_util", diff --git a/xla/stream_executor/device_description.cc b/xla/stream_executor/device_description.cc index 0925a0ccf8eb3b..ca19e68ffdc382 100644 --- a/xla/stream_executor/device_description.cc +++ b/xla/stream_executor/device_description.cc @@ -26,62 +26,30 @@ limitations under the License. namespace stream_executor { -static const uint64_t kUninitializedUint64 = -1ULL; -/* static */ const char *DeviceDescription::kUndefinedString = ""; - -DeviceDescription::DeviceDescription() - : device_vendor_(kUndefinedString), - platform_version_(kUndefinedString), - driver_version_(kUndefinedString), - runtime_version_(kUndefinedString), - pci_bus_id_(kUndefinedString), - name_(kUndefinedString), - model_str_(kUndefinedString), - thread_dim_limit_(kUninitializedUint64, kUninitializedUint64, - kUninitializedUint64), - block_dim_limit_(kUninitializedUint64, kUninitializedUint64, - kUninitializedUint64), - threads_per_core_limit_(kUninitializedUint64), - threads_per_block_limit_(kUninitializedUint64), - threads_per_warp_(kUninitializedUint64), - registers_per_core_limit_(kUninitializedUint64), - registers_per_block_limit_(kUninitializedUint64), - device_address_bits_(kUninitializedUint64), - device_memory_size_(kUninitializedUint64), - memory_bandwidth_(kUninitializedUint64), - shared_memory_per_core_(kUninitializedUint64), - shared_memory_per_block_(kUninitializedUint64), - clock_rate_ghz_(-1.0), - numa_node_(-1), - core_count_(-1), - ecc_enabled_(false) {} - -DeviceDescription::DeviceDescription(const GpuDeviceInfoProto &proto) { - if (proto.has_cuda_compute_capability()) { - gpu_compute_capability_ = - stream_executor::CudaComputeCapability(proto.cuda_compute_capability()); - } else { - gpu_compute_capability_ = - stream_executor::RocmComputeCapability(proto.rocm_compute_capability()); - } - threads_per_block_limit_ = proto.threads_per_block_limit(); - threads_per_warp_ = proto.threads_per_warp(); - shared_memory_per_block_ = proto.shared_memory_per_block(); - shared_memory_per_block_optin_ = proto.shared_memory_per_block_optin(); - shared_memory_per_core_ = proto.shared_memory_per_core(); - threads_per_core_limit_ = proto.threads_per_core_limit(); - core_count_ = proto.core_count(); - fpus_per_core_ = proto.fpus_per_core(); - block_dim_limit_ = - BlockDim(proto.block_dim_limit_x(), proto.block_dim_limit_y(), - proto.block_dim_limit_z()); - memory_bandwidth_ = proto.memory_bandwidth(); - l2_cache_size_ = proto.l2_cache_size(); - clock_rate_ghz_ = proto.clock_rate_ghz(); - device_memory_size_ = proto.device_memory_size(); - registers_per_core_limit_ = proto.registers_per_core_limit(); - registers_per_block_limit_ = proto.registers_per_block_limit(); -} +DeviceDescription::DeviceDescription(const GpuDeviceInfoProto &proto) + : block_dim_limit_(BlockDim(proto.block_dim_limit_x(), + proto.block_dim_limit_y(), + proto.block_dim_limit_z())), + threads_per_core_limit_(proto.threads_per_core_limit()), + threads_per_block_limit_(proto.threads_per_block_limit()), + threads_per_warp_(proto.threads_per_warp()), + registers_per_core_limit_(proto.registers_per_core_limit()), + registers_per_block_limit_(proto.registers_per_block_limit()), + device_memory_size_(proto.device_memory_size()), + l2_cache_size_(proto.l2_cache_size()), + memory_bandwidth_(proto.memory_bandwidth()), + shared_memory_per_core_(proto.shared_memory_per_core()), + shared_memory_per_block_(proto.shared_memory_per_block()), + shared_memory_per_block_optin_(proto.shared_memory_per_block_optin()), + clock_rate_ghz_(proto.clock_rate_ghz()), + gpu_compute_capability_( + proto.has_cuda_compute_capability() + ? GpuComputeCapability(stream_executor::CudaComputeCapability( + proto.cuda_compute_capability())) + : GpuComputeCapability(stream_executor::RocmComputeCapability( + proto.rocm_compute_capability()))), + core_count_(proto.core_count()), + fpus_per_core_(proto.fpus_per_core()) {} GpuDeviceInfoProto DeviceDescription::ToGpuProto() const { stream_executor::GpuDeviceInfoProto proto; diff --git a/xla/stream_executor/device_description.h b/xla/stream_executor/device_description.h index 06401c36423612..9b7511747e0845 100644 --- a/xla/stream_executor/device_description.h +++ b/xla/stream_executor/device_description.h @@ -457,53 +457,59 @@ class DeviceDescription { // For string values that are not available via the underlying platform, this // value will be provided. - static const char *kUndefinedString; + static inline const char *const kUndefinedString = ""; private: friend class internal::DeviceDescriptionBuilder; - DeviceDescription(); + DeviceDescription() = default; // For description of the following members, see the corresponding accessor // above. // // N.B. If another field is added, update ToMap() above. - std::string device_vendor_; - std::string platform_version_; - std::string driver_version_; - std::string runtime_version_; - std::string pci_bus_id_; - std::string name_; - std::string model_str_; - - ThreadDim thread_dim_limit_; - BlockDim block_dim_limit_; - - int64_t threads_per_core_limit_; - int64_t threads_per_block_limit_; - int64_t threads_per_warp_; - - int64_t registers_per_core_limit_; - int64_t registers_per_block_limit_; - - int64_t device_address_bits_; - int64_t device_memory_size_; - int64_t l2_cache_size_; - int64_t memory_bandwidth_; + std::string device_vendor_ = kUndefinedString; + std::string platform_version_ = kUndefinedString; + std::string driver_version_ = kUndefinedString; + std::string runtime_version_ = kUndefinedString; + std::string pci_bus_id_ = kUndefinedString; + std::string name_ = kUndefinedString; + std::string model_str_ = kUndefinedString; + + template + static constexpr T kUninitialized = T(-1); + + ThreadDim thread_dim_limit_{kUninitialized, + kUninitialized, + kUninitialized}; + BlockDim block_dim_limit_{kUninitialized, kUninitialized, + kUninitialized}; + + int64_t threads_per_core_limit_ = kUninitialized; + int64_t threads_per_block_limit_ = kUninitialized; + int64_t threads_per_warp_ = kUninitialized; + + int64_t registers_per_core_limit_ = kUninitialized; + int64_t registers_per_block_limit_ = kUninitialized; + + int64_t device_address_bits_ = kUninitialized; + int64_t device_memory_size_ = kUninitialized; + int64_t l2_cache_size_ = kUninitialized; + int64_t memory_bandwidth_ = kUninitialized; // Shared memory limits on a given device. - int64_t shared_memory_per_core_; - int64_t shared_memory_per_block_; - int64_t shared_memory_per_block_optin_; + int64_t shared_memory_per_core_ = kUninitialized; + int64_t shared_memory_per_block_ = kUninitialized; + int64_t shared_memory_per_block_optin_ = kUninitialized; - float clock_rate_ghz_; + float clock_rate_ghz_ = kUninitialized; - GpuComputeCapability gpu_compute_capability_; + GpuComputeCapability gpu_compute_capability_{}; - int numa_node_; - int core_count_; - int fpus_per_core_; - bool ecc_enabled_; + int numa_node_ = kUninitialized; + int core_count_ = kUninitialized; + int fpus_per_core_ = kUninitialized; + bool ecc_enabled_ = false; }; namespace internal { From cd906c188f9c404ccc5e28f6b9e6f1a3a891ea8f Mon Sep 17 00:00:00 2001 From: Dirk Hornung Date: Fri, 19 Jul 2024 10:22:11 -0700 Subject: [PATCH 021/376] [XLA:GPU] Add arbitrary axis support to the Triton reduce emitter. PiperOrigin-RevId: 654046728 --- .../fusions/triton/triton_fusion_emitter.cc | 67 +++++---- .../triton_fusion_emitter_device_test.cc | 141 +++++++++++++++++- .../gpu/fusions/triton/triton_support.cc | 4 +- .../gpu/fusions/triton/triton_support_test.cc | 9 +- 4 files changed, 184 insertions(+), 37 deletions(-) diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc index 9be712a4f86959..13bf68ed1789df 100644 --- a/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc +++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc @@ -674,24 +674,20 @@ absl::StatusOr EmitReduce( absl::flat_hash_map& values, absl::string_view libdevice_path, const se::DeviceDescription& device_info) { + // At the moment, we should only emit a full reduction over a single + // dimension using a scalar as a neutral element. const HloReduceInstruction& hlo_reduce = *::xla::Cast(tiled_hlo_reduce.hlo()); - // At the moment, we should only emit a full reduction over the last axis of - // a single input. TF_RET_CHECK(hlo_reduce.operand_count() == 2); TF_RET_CHECK(hlo_reduce.dimensions().size() == 1); - TF_RET_CHECK(hlo_reduce.dimensions().front() == - hlo_reduce.operand(0)->shape().rank() - 1); - const int64_t row_len = hlo_reduce.operand(0)->shape().dimensions_minor(0); - const int64_t block_row = llvm::PowerOf2Ceil(row_len); Value input = values[tiled_hlo_reduce.operand(0)]; - Value neutral = values[tiled_hlo_reduce.operand(1)]; - llvm::ArrayRef input_shape = - mlir::cast(values[tiled_hlo_reduce.operand(0)].getType()) - .getShape(); - int64_t input_rank = input_shape.size(); + mlir::cast(input.getType()).getShape(); + absl::Span source_tensor_shape = + hlo_reduce.operand(0)->shape().dimensions(); + + int reduction_dimension = hlo_reduce.dimensions().front(); // Since every shape is padded to a power of 2 in Triton, the input tile may // be padded with arbitrary values. These values could affect the result of @@ -701,29 +697,42 @@ absl::StatusOr EmitReduce( // hlo_reduce.operand(1) is thus always the right choice to ensure that the // reduction is computed correctly, since it is the neutral value with regards // to the reducer. - if (block_row != row_len) { - Value mask = b.create( - ma::CmpIPredicate::slt, Range(b, block_row), - Splat(b, CreateConst(b, b.getI32Type(), row_len), block_row)); - - // Make the mask match the rank of the input---the mask starts out with - // rank 1. - mask = LeftExpandDimNTimes(b, mask, input_rank - 1); - mask = Broadcast(b, mlir::cast(mask), input_shape); - - Value broadcasted_neutral = Broadcast( - b, mlir::cast(LeftExpandDimNTimes(b, neutral, input_rank)), - input_shape); - - input = b.create(mask, input, broadcasted_neutral); + int64_t source_tensor_reduction_dimension_size = + source_tensor_shape[reduction_dimension]; + int64_t input_reduction_dimension_size = input_shape[reduction_dimension]; + if (input_reduction_dimension_size != + source_tensor_reduction_dimension_size) { + Value range = Range(b, input_reduction_dimension_size); + // Triton's broadcast requires that the rank of the source and broadcasted + // result are equal. + for (int i = 0; i < input_shape.size() - 1; i++) { + if (i < reduction_dimension) { + range = b.create(range, /*axis=*/0); + } else { + range = b.create(range, /*axis=*/i + 1); + } + } + Value mask = Broadcast(b, mlir::cast(range), input_shape); + Value constant = + CreateConst(b, b.getI32Type(), source_tensor_reduction_dimension_size); + Value constant_tensor = Splat(b, constant, input_shape); + mask = b.create(ma::CmpIPredicate::slt, mask, constant_tensor); + + Value neutral = values[tiled_hlo_reduce.operand(1)]; + // Triton's broadcast requires that the rank of the source and broadcasted + // result are equal. + for (int i = 0; i < input_shape.size(); i++) { + neutral = b.create(neutral, /*axis=*/0); + } + neutral = Broadcast(b, mlir::cast(neutral), input_shape); + input = b.create(mask, input, neutral); } // Triton actually only performs reductions on float32 inputs, and we must // thus upcast/downcast our input if its data type is different. - Value casted_input = Cast(b, input, b.getF32Type()); + input = Cast(b, input, b.getF32Type()); - mt::ReduceOp reduction = b.create( - SmallVector({casted_input}), (int)input_shape.size() - 1); + mt::ReduceOp reduction = b.create(input, reduction_dimension); { mlir::Location loc = b.getLoc(); mlir::Block* reducer = diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc index 4db67018e29f63..9ca1b90100e0a9 100644 --- a/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc +++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc @@ -58,7 +58,138 @@ class TritonEmitterTest : public GpuCodegenTest { } }; -TEST_F(TritonEmitterTest, TestGenericEmitterWithSingleParameter) { +TEST_F(TritonEmitterTest, ReductionOnMinormostAxisIsEmittedCorrectly) { + const std::string kHloText = R"( +HloModule t +maximum { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT maximum = f32[] maximum(Arg_0, Arg_1) +} + +triton_reduction_computation { + parameter_0 = f32[8,4] parameter(0) + constant_0 = f32[] constant(0) + ROOT reduce = f32[8] reduce(parameter_0, constant_0), dimensions={1}, to_apply=maximum +} + +ENTRY main { + param_0 = f32[8,4] parameter(0) + ROOT triton_reduction = f32[8] fusion(param_0), kind=kCustom, calls=triton_reduction_computation, backend_config={"fusion_backend_config":{"kind":"__triton","block_level_fusion_config":{"output_tile_sizes":["4"],"num_warps":"1"}}} +})"; + TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, + FromOutputTileSizes({4}), + "triton_reduction_computation", R"( +CHECK: "tt.reduce"(%[[LOAD:.*]]) <{axis = 1 : i32}> +)")); + + EXPECT_TRUE( + RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/0, /*arel=*/0})); +} + +TEST_F(TritonEmitterTest, ReductionOnMajormostAxisIsEmittedCorrectly) { + const std::string kHloText = R"( +HloModule t +maximum { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT maximum = f32[] maximum(Arg_0, Arg_1) +} + +triton_reduction_computation { + parameter_0 = f32[8,4] parameter(0) + constant_0 = f32[] constant(0) + ROOT reduce = f32[4] reduce(parameter_0, constant_0), dimensions={0}, to_apply=maximum +} + +ENTRY main { + param_0 = f32[8,4] parameter(0) + ROOT triton_reduction = f32[4] fusion(param_0), kind=kCustom, calls=triton_reduction_computation, backend_config={"fusion_backend_config":{"kind":"__triton","block_level_fusion_config":{"output_tile_sizes":["4"],"num_warps":"1"}}} +})"; + TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, + FromOutputTileSizes({4}), + "triton_reduction_computation", R"( +CHECK: "tt.reduce"(%[[LOAD:.*]]) <{axis = 0 : i32}> +)")); + + EXPECT_TRUE( + RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/0, /*arel=*/0})); +} + +TEST_F(TritonEmitterTest, ReductionOnIntermediateAxisIsEmittedCorrectly) { + const std::string kHloText = R"( +HloModule t +maximum { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT maximum = f32[] maximum(Arg_0, Arg_1) +} + +triton_reduction_computation { + parameter_0 = f32[5,5,5,5,3] parameter(0) + constant_0 = f32[] constant(0) + ROOT reduction = f32[5,5,5,3] reduce(parameter_0, constant_0), dimensions={2}, to_apply=maximum +} + +ENTRY main { + param_0 = f32[5,5,5,5,3] parameter(0) + ROOT triton_reduction = f32[5,5,5,3] fusion(param_0), kind=kCustom, calls=triton_reduction_computation, backend_config={"fusion_backend_config":{"kind":"__triton","block_level_fusion_config":{"output_tile_sizes":["4", "2", "5", "1"],"num_warps":"1"}}} +})"; + TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, + FromOutputTileSizes({4, 2, 5, 1}), + "triton_reduction_computation", R"( +CHECK: tt.make_range +CHECK-COUNT-4: tt.expand_dims +CHECK: "tt.reduce"(%[[SELECT:.*]]) <{axis = 2 : i32}> +)")); + + EXPECT_TRUE( + RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/0, /*arel=*/0})); +} + +TEST_F(TritonEmitterTest, TestReductionWithTileSizeLargerThanSourceTensor) { + const std::string kHloText = R"( +HloModule t +maximum { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT maximum = f32[] maximum(Arg_0, Arg_1) +} + +triton_reduction_computation { + parameter_0 = f32[5,3] parameter(0) + constant_0 = f32[] constant(0) + ROOT reduce = f32[3] reduce(parameter_0, constant_0), dimensions={0}, to_apply=maximum +} + +ENTRY main { + param_0 = f32[5,3] parameter(0) + ROOT triton_reduction = f32[3] fusion(param_0), kind=kCustom, calls=triton_reduction_computation, backend_config={"fusion_backend_config":{"kind":"__triton","block_level_fusion_config":{"output_tile_sizes":["3"],"num_warps":"1"}}} +})"; + TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, + FromOutputTileSizes({3}), + "triton_reduction_computation", R"( +; Make sure input reduction tile is padded with a neutral value. +CHECK: %[[LOAD:.*]] = tt.load +CHECK: %[[RANGE:.*]] = tt.make_range +CHECK: %[[EXPAND:.*]] = tt.expand_dims %[[RANGE]] +CHECK: %[[BROADCAST:.*]] = tt.broadcast %[[EXPAND]] +CHECK: %[[CMPI:.*]] = arith.cmpi slt, %[[BROADCAST]] +CHECK: %[[SELECT:.*]] = arith.select %[[CMPI]], %[[LOAD]] +CHECK: "tt.reduce"(%[[SELECT]]) <{axis = 0 : i32}> +CHECK: ^bb0(%[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32): +CHECK: %[[MAXIMUM:.*]] = arith.maximumf %[[ARG2]], %[[ARG3]] : f32 +CHECK: tt.reduce.return %[[MAXIMUM]] : f32 +CHECK: }) +)")); + + EXPECT_TRUE( + RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/0, /*arel=*/0})); +} + +// TODO(b/353484968): Tests that don't run RunAndCompareNoHloPasses should be +// moved to deviceless test file. +TEST_F(TritonEmitterTest, TestSoftmaxEmitterWithSingleParameter) { const std::string kHloText = R"( HloModule t add { @@ -109,6 +240,8 @@ CHECK: } )")); } +// TODO(b/353484968): Tests that don't run RunAndCompareNoHloPasses should be +// moved to deviceless test file. TEST_F(TritonEmitterTest, TestGenericEmitterWithMultipleParameters) { const std::string kHloText = R"( HloModule t @@ -432,6 +565,8 @@ ENTRY main { RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/0, /*arel=*/0})); } +// TODO(b/353484968): Tests that don't run RunAndCompareNoHloPasses should be +// moved to deviceless test file. TEST_F(TritonEmitterTest, EmitterFailsIfComputeCapabilityIsBelowAmpere) { const std::string kHloText = R"( triton_computation { @@ -472,6 +607,8 @@ ENTRY entry { "(compute capability 8.0) and up, but got"))); } +// TODO(b/353484968): Tests that don't run RunAndCompareNoHloPasses should be +// moved to deviceless test file. TEST_F(TritonEmitterTest, TestGenericEmitterReductionFusion) { const std::string kHloText = R"( HloModule t @@ -510,6 +647,8 @@ CHECK: tt.store {{.*}} : !tt.ptr> )")); } +// TODO(b/353484968): Tests that don't run RunAndCompareNoHloPasses should be +// moved to deviceless test file. TEST_F(TritonEmitterTest, TestGenericEmitterWithSoftMaxSingleParameter) { const std::string kHloText = R"( HloModule t diff --git a/xla/service/gpu/fusions/triton/triton_support.cc b/xla/service/gpu/fusions/triton/triton_support.cc index 0565ecf23fff8b..8f5246d6cd80ac 100644 --- a/xla/service/gpu/fusions/triton/triton_support.cc +++ b/xla/service/gpu/fusions/triton/triton_support.cc @@ -309,9 +309,7 @@ CodegenDecision CanTritonHandleReduce( return "Unsupported reduction computation by Triton."; } - if (reduce.dimensions().size() == 1 && - reduce.dimensions().front() == reduce.operand(0)->shape().rank() - 1 && - reduce.operand_count() == 2) { + if (reduce.dimensions().size() == 1 && reduce.operand_count() == 2) { return CodegenDecision{}; } return "Reduction is not a row-reduction of a single operand."; diff --git a/xla/service/gpu/fusions/triton/triton_support_test.cc b/xla/service/gpu/fusions/triton/triton_support_test.cc index b8a2d25f42560f..d93817af5efc6e 100644 --- a/xla/service/gpu/fusions/triton/triton_support_test.cc +++ b/xla/service/gpu/fusions/triton/triton_support_test.cc @@ -125,7 +125,8 @@ bool DoesOpSupportType(HloOpcode opcode, PrimitiveType type) { case HloOpcode::kNegate: return type != PRED; default: - // Returning true by default ensures that newly added ops are not skipped. + // Returning true by default ensures that newly added ops are not + // skipped. return true; } } @@ -466,8 +467,8 @@ ENTRY triton_computation { RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}, cc); } -TEST_P(ReduceTest, - UnsupportedReduceWithNonLastReduceDimensionFailsGracefullyWithTriton) { +TEST_P(ReduceTest, IsTritonSupportedReduceWithNonLastReduceDimension) { + GTEST_SKIP() << "TODO(b/348565795): this test is currently broken."; auto [data_type, opcode, cc] = GetParam(); bool dtype_is_complex = data_type == C64 || data_type == C128; const std::string kHloTestTemplate = @@ -487,7 +488,7 @@ ENTRY triton_computation { TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); - EXPECT_FALSE(IsTritonSupportedInstruction(ti.Instruction(), cc)); + EXPECT_TRUE(IsTritonSupportedInstruction(ti.Instruction(), cc)); RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}, cc); } From 58233288b9404e2cab1fbb5e543bfb7697da740e Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 19 Jul 2024 10:43:52 -0700 Subject: [PATCH 022/376] Reverts 78a16ba2fa33a13e44e2162c3dbe065150168915 PiperOrigin-RevId: 654055008 --- xla/service/BUILD | 23 ---- xla/service/gpu/BUILD | 1 - xla/service/gpu/gpu_compiler.cc | 2 - xla/service/unique_channel_id_enforcer.cc | 56 --------- xla/service/unique_channel_id_enforcer.h | 43 ------- .../unique_channel_id_enforcer_test.cc | 108 ------------------ 6 files changed, 233 deletions(-) delete mode 100644 xla/service/unique_channel_id_enforcer.cc delete mode 100644 xla/service/unique_channel_id_enforcer.h delete mode 100644 xla/service/unique_channel_id_enforcer_test.cc diff --git a/xla/service/BUILD b/xla/service/BUILD index 38bb3ff4cc74b9..a255855b928b36 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -6227,29 +6227,6 @@ xla_cc_test( ], ) -cc_library( - name = "unique_channel_id_enforcer", - srcs = ["unique_channel_id_enforcer.cc"], - hdrs = ["unique_channel_id_enforcer.h"], - deps = [ - ":hlo_pass", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "@com_google_absl//absl/status:statusor", - ], -) - -xla_cc_test( - name = "unique_channel_id_enforcer_test", - srcs = ["unique_channel_id_enforcer_test.cc"], - deps = [ - ":hlo_parser", - ":unique_channel_id_enforcer", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - ], -) - cc_library( name = "root_instruction_sinker", srcs = ["root_instruction_sinker.cc"], diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 3eee391f579cc7..93dd425ac6d8e8 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -3263,7 +3263,6 @@ cc_library( "//xla/service:topk_rewriter", "//xla/service:transpose_folding", "//xla/service:tuple_simplifier", - "//xla/service:unique_channel_id_enforcer", "//xla/service:while_loop_all_reduce_code_motion", "//xla/service:while_loop_constant_sinking", "//xla/service:while_loop_simplifier", diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index 44aa2e8474cc51..3973e92b144a9b 100644 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -221,7 +221,6 @@ limitations under the License. #include "xla/service/topk_rewriter.h" #include "xla/service/transpose_folding.h" #include "xla/service/tuple_simplifier.h" -#include "xla/service/unique_channel_id_enforcer.h" #include "xla/service/while_loop_all_reduce_code_motion.h" #include "xla/service/while_loop_constant_sinking.h" #include "xla/service/while_loop_simplifier.h" @@ -2353,7 +2352,6 @@ absl::Status GpuCompiler::RunPreSchedulingPasses( HloModule* module, se::StreamExecutor* stream_exec) { HloPassPipeline pipeline("pre-scheduling-passes"); pipeline.AddPass(); - pipeline.AddPass(); return pipeline.Run(module).status(); } diff --git a/xla/service/unique_channel_id_enforcer.cc b/xla/service/unique_channel_id_enforcer.cc deleted file mode 100644 index 4762961fa07ea0..00000000000000 --- a/xla/service/unique_channel_id_enforcer.cc +++ /dev/null @@ -1,56 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/unique_channel_id_enforcer.h" - -#include "absl/status/statusor.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/utils/hlo_query.h" - -namespace xla { - -absl::StatusOr UniqueChannelIdEnforcer::Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) { - absl::flat_hash_set> used_channel_ids; - auto next_channel_id = hlo_query::NextChannelId(*module); - bool changed = false; - for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - if (!hlo_query::IsCollectiveCommunicationOp(instruction->opcode())) - continue; - auto channel_id = instruction->channel_id(); - if (used_channel_ids.contains(channel_id)) { - if (assert_unique_channel_ids_) { - LOG(ERROR) << "Duplicate channel ID " << channel_id.value_or(-1) - << " found on instruction: " << instruction->ToString(); - return absl::InternalError(absl::StrFormat( - "Duplicate channel ID %d found on instruction: %s", - channel_id.value_or(-1), instruction->ToString())); - } - instruction->set_channel_id(next_channel_id); - used_channel_ids.insert(next_channel_id); - next_channel_id++; - changed = true; - } else { - used_channel_ids.insert(channel_id); - } - } - } - - return changed; -} - -} // namespace xla diff --git a/xla/service/unique_channel_id_enforcer.h b/xla/service/unique_channel_id_enforcer.h deleted file mode 100644 index e64d49a40858c9..00000000000000 --- a/xla/service/unique_channel_id_enforcer.h +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_UNIQUE_CHANNEL_ID_ENFORCER_H_ -#define XLA_SERVICE_UNIQUE_CHANNEL_ID_ENFORCER_H_ - -#include "xla/service/hlo_pass_interface.h" - -namespace xla { -// A pass which enforces that every collective -// must have a unique channel id. -class UniqueChannelIdEnforcer : public HloModulePass { - public: - explicit UniqueChannelIdEnforcer(bool assert_unique_channel_ids = false) - : assert_unique_channel_ids_(assert_unique_channel_ids) {} - - absl::string_view name() const override { - return "unique-channel-id-enforcer"; - } - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - private: - bool assert_unique_channel_ids_; -}; - -} // namespace xla - -#endif // XLA_SERVICE_UNIQUE_CHANNEL_ID_ENFORCER_H_ diff --git a/xla/service/unique_channel_id_enforcer_test.cc b/xla/service/unique_channel_id_enforcer_test.cc deleted file mode 100644 index ff2ae49b8fcc24..00000000000000 --- a/xla/service/unique_channel_id_enforcer_test.cc +++ /dev/null @@ -1,108 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/unique_channel_id_enforcer.h" - -#include "xla/service/hlo_parser.h" -#include "xla/tests/hlo_test_base.h" - -namespace xla { -namespace { - -using UniqueChannelIdEnforcerTest = HloTestBase; - -TEST_F(UniqueChannelIdEnforcerTest, EnsureUniqueChannelIdsAllGather) { - const char* const hlo_string = R"( -HloModule Module - -ENTRY entry { - param0 = f32[8] parameter(0) - param1 = f32[8] parameter(1) - allgather0 = f32[32] all-gather(param0), channel_id=1, replica_groups={}, dimensions={0} - allgather1 = f32[32] all-gather(param1), channel_id=1, replica_groups={}, dimensions={0} - ROOT tuple = (f32[32], f32[32]) tuple(allgather0, allgather1) -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - - UniqueChannelIdEnforcer enforcer; - TF_ASSERT_OK_AND_ASSIGN(bool changed, enforcer.Run(module.get())); - EXPECT_TRUE(changed); - - // Verify that channel IDs are unique for all-gather ops - std::optional all_gather1_channel_id; - std::optional all_gather2_channel_id; - - for (HloInstruction* inst : module->entry_computation()->instructions()) { - if (inst->opcode() == HloOpcode::kAllGather) { - if (!all_gather1_channel_id.has_value()) { - all_gather1_channel_id = inst->channel_id(); - } else { - all_gather2_channel_id = inst->channel_id(); - } - } - } - - ASSERT_TRUE(all_gather1_channel_id.has_value()); - ASSERT_TRUE(all_gather2_channel_id.has_value()); - EXPECT_NE(all_gather1_channel_id.value(), all_gather2_channel_id.value()); -} - -TEST_F(UniqueChannelIdEnforcerTest, ChannelIdsAlreadyUnique) { - const char* const hlo_string = R"( -HloModule Module - -ENTRY entry { - param0 = f32[8] parameter(0) - param1 = f32[8] parameter(1) - allgather0 = f32[32] all-gather(param0), channel_id=1, replica_groups={}, dimensions={0} - allgather1 = f32[32] all-gather(param1), channel_id=2, replica_groups={}, dimensions={0} - ROOT tuple = (f32[32], f32[32]) tuple(allgather0, allgather1) -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - - UniqueChannelIdEnforcer enforcer; - TF_ASSERT_OK_AND_ASSIGN(bool changed, enforcer.Run(module.get())); - EXPECT_FALSE(changed); -} - -TEST_F(UniqueChannelIdEnforcerTest, DuplicateChannelIdsAssertTrue) { - const char* const hlo_string = R"( - HloModule Module - - ENTRY entry { - param0 = f32[8] parameter(0) - param1 = f32[8] parameter(1) - allgather0 = f32[32] all-gather(param0), channel_id=1, replica_groups={}, dimensions={0} - allgather1 = f32[32] all-gather(param1), channel_id=1, replica_groups={}, dimensions={0} - ROOT tuple = (f32[32], f32[32]) tuple(allgather0, allgather1) - } - )"; - - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - UniqueChannelIdEnforcer enforcer(/*assert_unique_channel_ids=*/true); - auto status_or_changed = enforcer.Run(module.get()); - - EXPECT_FALSE(status_or_changed.ok()); -} - -} // namespace -} // namespace xla From 50849afacd3c4bba2e18c0410fc7f040868cbecc Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 19 Jul 2024 10:53:38 -0700 Subject: [PATCH 023/376] [XLA:Python] More fixes for nanobind 2.0. Mark several more enum types as arithmetic since some downstream users cast to and from integer types. PiperOrigin-RevId: 654058324 --- xla/python/ops.cc | 5 +++-- xla/python/xla.cc | 2 +- xla/python/xla_compiler.cc | 3 ++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/xla/python/ops.cc b/xla/python/ops.cc index ebcbb44781f36b..87f5333645a702 100644 --- a/xla/python/ops.cc +++ b/xla/python/ops.cc @@ -297,7 +297,7 @@ void BuildOpsSubmodule(nb::module_& m) { .value("TRANSPOSE", TriangularSolveOptions::TRANSPOSE) .value("ADJOINT", TriangularSolveOptions::ADJOINT); - nb::enum_(ops, "RandomAlgorithm") + nb::enum_(ops, "RandomAlgorithm", nb::is_arithmetic()) .value("RNG_DEFAULT", RandomAlgorithm::RNG_DEFAULT) .value("RNG_THREE_FRY", RandomAlgorithm::RNG_THREE_FRY) .value("RNG_PHILOX", RandomAlgorithm::RNG_PHILOX); @@ -307,7 +307,8 @@ void BuildOpsSubmodule(nb::module_& m) { .value("SCHEDULE_LATEST", CustomCallSchedule::SCHEDULE_LATEST) .value("SCHEDULE_EARLIEST", CustomCallSchedule::SCHEDULE_EARLIEST); - nb::enum_(ops, "CustomCallApiVersion") + nb::enum_(ops, "CustomCallApiVersion", + nb::is_arithmetic()) .value("API_VERSION_ORIGINAL", CustomCallApiVersion::API_VERSION_ORIGINAL) .value("API_VERSION_STATUS_RETURNING", CustomCallApiVersion::API_VERSION_STATUS_RETURNING) diff --git a/xla/python/xla.cc b/xla/python/xla.cc index ec09a0395b2506..19a9d94e1d1b7b 100644 --- a/xla/python/xla.cc +++ b/xla/python/xla.cc @@ -182,7 +182,7 @@ NB_MODULE(xla_extension, m_nb) { PyExc_RuntimeError); // Types - nb::enum_(m_nb, "PrimitiveType") + nb::enum_(m_nb, "PrimitiveType", nb::is_arithmetic()) .value("PRIMITIVE_TYPE_INVALID", PRIMITIVE_TYPE_INVALID) .value("PRED", PRED) .value("S4", S4) diff --git a/xla/python/xla_compiler.cc b/xla/python/xla_compiler.cc index 5dbeeaac60b578..3be39c7e43eb65 100644 --- a/xla/python/xla_compiler.cc +++ b/xla/python/xla_compiler.cc @@ -1275,7 +1275,8 @@ void BuildXlaCompilerSubmodule(nb::module_& m) { options.set_allow_spmd_sharding_propagation_to_output(v); }); - nb::enum_ op_sharding_type(m, "OpSharding_Type"); + nb::enum_ op_sharding_type(m, "OpSharding_Type", + nb::is_arithmetic()); op_sharding_type.value("REPLICATED", OpSharding::REPLICATED) .value("MAXIMAL", OpSharding::MAXIMAL) .value("MANUAL", OpSharding::MANUAL) From 253b924596abb312bad9e6de241157ea0d4121fd Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 19 Jul 2024 11:00:48 -0700 Subject: [PATCH 024/376] Reverts afc28e65255c3018843682000faf322a05e06c3d PiperOrigin-RevId: 654061294 --- xla/service/hlo_cse.cc | 3 --- xla/service/hlo_cse.h | 16 ++++++---------- xla/service/hlo_cse_test.cc | 23 ----------------------- xla/service/sharding_propagation.cc | 3 +-- 4 files changed, 7 insertions(+), 38 deletions(-) diff --git a/xla/service/hlo_cse.cc b/xla/service/hlo_cse.cc index 68d1aef8ca3da1..8f7810c9be2772 100644 --- a/xla/service/hlo_cse.cc +++ b/xla/service/hlo_cse.cc @@ -312,9 +312,6 @@ absl::StatusOr HloCSE::Run( representatives(/*N=*/computation->instruction_count() + 1, absl::Hash{}, cse_equal); for (auto instruction : computation->MakeInstructionPostOrder()) { - if (instructions_to_skip_.contains(instruction)) { - continue; - } // If the instruction has zero operands (constants, parameters, etc.) skip // over it. if (instruction->operand_count() == 0 && diff --git a/xla/service/hlo_cse.h b/xla/service/hlo_cse.h index 15a62bebb782a6..03496234c007c8 100644 --- a/xla/service/hlo_cse.h +++ b/xla/service/hlo_cse.h @@ -16,8 +16,6 @@ limitations under the License. #ifndef XLA_SERVICE_HLO_CSE_H_ #define XLA_SERVICE_HLO_CSE_H_ -#include "absl/container/flat_hash_set.h" -#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" @@ -33,18 +31,17 @@ class HloCSE : public HloModulePass { // transformation. Otherwise, layout is ignored. // If ignore_control_dependencies is true, the pass will ignore control deps // when replacing instructions with their equivalents. - explicit HloCSE( - bool is_layout_sensitive, bool only_fusion_computations = false, - bool ignore_control_dependencies = false, bool only_scalars = false, - bool is_sharding_sensitive = true, bool allow_compatible_sharding = false, - absl::flat_hash_set instructions_to_skip = {}) + explicit HloCSE(bool is_layout_sensitive, + bool only_fusion_computations = false, + bool ignore_control_dependencies = false, + bool only_scalars = false, bool is_sharding_sensitive = true, + bool allow_compatible_sharding = false) : is_layout_sensitive_(is_layout_sensitive), only_fusion_computations_(only_fusion_computations), ignore_control_dependencies_(ignore_control_dependencies), only_scalars_(only_scalars), is_sharding_sensitive_(is_sharding_sensitive), - allow_compatible_sharding_(allow_compatible_sharding), - instructions_to_skip_(instructions_to_skip) {} + allow_compatible_sharding_(allow_compatible_sharding) {} ~HloCSE() override = default; absl::string_view name() const override { return "cse"; } @@ -62,7 +59,6 @@ class HloCSE : public HloModulePass { const bool only_scalars_; const bool is_sharding_sensitive_; const bool allow_compatible_sharding_; - absl::flat_hash_set instructions_to_skip_; }; } // namespace xla diff --git a/xla/service/hlo_cse_test.cc b/xla/service/hlo_cse_test.cc index e5d5630ed62d48..c08902d9826b00 100644 --- a/xla/service/hlo_cse_test.cc +++ b/xla/service/hlo_cse_test.cc @@ -543,29 +543,6 @@ ENTRY %entry { EXPECT_FALSE(cse.Run(m.get()).value()); } -TEST_F(HloCseTest, DoNotCombineSkippedOps) { - const char* const hlo_string = R"( -HloModule module - -ENTRY %entry { - constant = bf16[] constant(0) - broadcast.0 = bf16[14,4,32768,3072]{3,2,1,0} broadcast(constant), dimensions={}, sharding={devices=[1,1,8,1,8]<=[64] last_tile_dim_replicate} - broadcast.1 = bf16[14,4,32768,3072]{3,2,1,0} broadcast(constant), dimensions={}, sharding={devices=[1,1,8,8]<=[64]} - ROOT tuple = (bf16[14,4,32768,3072]{3,2,1,0}, bf16[14,4,32768,3072]{3,2,1,0}) tuple(broadcast.0, broadcast.1) -})"; - TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); - HloInstruction* broadcast_0 = FindInstruction(m.get(), "broadcast.0"); - HloCSE cse(/*is_layout_sensitive=*/false, - /*only_fusion_computations=*/false, - /*ignore_control_dependencies=*/false, - /*only_scalars=*/false, - /*is_sharding_sensitive=*/true, - /*allow_compatible_sharding=*/true, - /*instructions_to_skip=*/{broadcast_0}); - EXPECT_FALSE(cse.Run(m.get()).value()); - XLA_VLOG_LINES(0, m->ToString()); -} - TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { // Test that two calls to an impure function are not commoned. RNG // is the source of the impurity. diff --git a/xla/service/sharding_propagation.cc b/xla/service/sharding_propagation.cc index ca78f979acaaf6..d6fac61352a6b0 100644 --- a/xla/service/sharding_propagation.cc +++ b/xla/service/sharding_propagation.cc @@ -3385,8 +3385,7 @@ absl::StatusOr ShardingPropagation::Run( /*ignore_control_dependencies=*/false, /*only_scalars=*/false, /*is_sharding_sensitive=*/true, - /*allow_compatible_sharding=*/true, - /*instructions_to_skip=*/provided_shardings); + /*allow_compatible_sharding=*/false); TF_RETURN_IF_ERROR(pass.Run(module, execution_threads).status()); // CSE may invalidate stored HloInstruction pointers, so we need to remove From 15a79c53e0bd38a6dcf3f268cf18b3120187ac9a Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Fri, 19 Jul 2024 11:04:21 -0700 Subject: [PATCH 025/376] Move Stream::DoHostCallbackWithStatus implementation to derived classes of Stream. This avoids eventual down_casts to the correct kind of stream in all the StreamExecutor derived classes that were happening. PiperOrigin-RevId: 654062693 --- xla/backends/interpreter/executor.cc | 6 ----- xla/backends/interpreter/executor.h | 3 --- xla/stream_executor/cuda/cuda_executor.cc | 19 --------------- xla/stream_executor/gpu/BUILD | 3 +++ xla/stream_executor/gpu/gpu_executor.h | 9 -------- xla/stream_executor/gpu/gpu_stream.cc | 25 ++++++++++++++++++++ xla/stream_executor/gpu/gpu_stream.h | 4 ++++ xla/stream_executor/host/host_executor.cc | 6 ----- xla/stream_executor/host/host_executor.h | 3 --- xla/stream_executor/host/host_stream.cc | 8 +++++++ xla/stream_executor/host/host_stream.h | 2 ++ xla/stream_executor/mock_stream_executor.h | 3 --- xla/stream_executor/rocm/rocm_executor.cc | 19 --------------- xla/stream_executor/stream.h | 9 ++++++-- xla/stream_executor/stream_common.cc | 16 ------------- xla/stream_executor/stream_common.h | 3 --- xla/stream_executor/stream_executor.h | 4 ---- xla/stream_executor/tpu/tpu_executor.cc | 20 ---------------- xla/stream_executor/tpu/tpu_executor.h | 3 --- xla/stream_executor/tpu/tpu_stream.h | 27 ++++++++++++++++++++-- 20 files changed, 74 insertions(+), 118 deletions(-) diff --git a/xla/backends/interpreter/executor.cc b/xla/backends/interpreter/executor.cc index 2476b8023ef219..b9efdb27d99303 100644 --- a/xla/backends/interpreter/executor.cc +++ b/xla/backends/interpreter/executor.cc @@ -56,12 +56,6 @@ absl::Status XlaInterpreterExecutor::SynchronousMemcpy( return absl::OkStatus(); } -bool XlaInterpreterExecutor::HostCallback( - Stream *stream, absl::AnyInvocable callback) { - AsExecutorStream(stream)->EnqueueTaskWithStatus(std::move(callback)); - return true; -} - absl::Status XlaInterpreterExecutor::BlockHostUntilDone(Stream *stream) { return AsExecutorStream(stream)->BlockUntilDone(); } diff --git a/xla/backends/interpreter/executor.h b/xla/backends/interpreter/executor.h index e12c11002ba055..8ca0cd9c357ef0 100644 --- a/xla/backends/interpreter/executor.h +++ b/xla/backends/interpreter/executor.h @@ -125,9 +125,6 @@ class XlaInterpreterExecutor : public StreamExecutorCommon { const DeviceMemoryBase &dev_src, uint64_t size) override; - bool HostCallback(Stream *stream, - absl::AnyInvocable callback) override; - void DeallocateStream(Stream *stream) override {} absl::Status BlockHostUntilDone(Stream *stream) override; diff --git a/xla/stream_executor/cuda/cuda_executor.cc b/xla/stream_executor/cuda/cuda_executor.cc index 26c36fd28025a0..bda62ea3a97300 100644 --- a/xla/stream_executor/cuda/cuda_executor.cc +++ b/xla/stream_executor/cuda/cuda_executor.cc @@ -575,25 +575,6 @@ absl::Status GpuExecutor::Memset(Stream* stream, DeviceMemoryBase* location, AsGpuStreamValue(stream)); } -bool GpuExecutor::HostCallback(Stream* stream, - absl::AnyInvocable callback) { - auto callback_ptr = - new absl::AnyInvocable([cb = std::move(callback)]() mutable { - absl::Status s = std::move(cb)(); - if (!s.ok()) { - LOG(WARNING) << "Host callback failed: " << s; - } - }); - return GpuDriver::AddStreamCallback(context_, AsGpuStreamValue(stream), - InternalHostCallback, callback_ptr); -} - -/* static */ void GpuExecutor::InternalHostCallback(void* data) { - auto* callback = reinterpret_cast*>(data); - std::move (*callback)(); - delete callback; -} - void GpuExecutor::DeallocateStream(Stream* stream) { { absl::MutexLock lock(&mu_); diff --git a/xla/stream_executor/gpu/BUILD b/xla/stream_executor/gpu/BUILD index 50127d6bd320c4..307eb68ab442eb 100644 --- a/xla/stream_executor/gpu/BUILD +++ b/xla/stream_executor/gpu/BUILD @@ -310,9 +310,11 @@ gpu_only_cc_library( ":gpu_executor_header", ":gpu_types_header", "//xla/stream_executor:device_memory", + "//xla/stream_executor:event", "//xla/stream_executor:platform", "//xla/stream_executor:stream", "//xla/stream_executor:stream_common", + "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings:string_view", ], @@ -332,6 +334,7 @@ gpu_only_cc_library( "//xla/stream_executor:platform", "//xla/stream_executor:stream", "//xla/stream_executor:stream_common", + "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", diff --git a/xla/stream_executor/gpu/gpu_executor.h b/xla/stream_executor/gpu/gpu_executor.h index f90eca074d24a7..65e53359898372 100644 --- a/xla/stream_executor/gpu/gpu_executor.h +++ b/xla/stream_executor/gpu/gpu_executor.h @@ -204,9 +204,6 @@ class GpuExecutor : public StreamExecutorCommon { absl::Status Memset(Stream* stream, DeviceMemoryBase* location, uint8_t pattern, uint64_t size) override; - bool HostCallback(Stream* stream, - absl::AnyInvocable callback) override; - void DeallocateStream(Stream* stream) override; absl::Status BlockHostUntilDone(Stream* stream) override; @@ -303,12 +300,6 @@ class GpuExecutor : public StreamExecutorCommon { uint64_t GetArgumentLoggingMode() const { return argument_logging_mode_; } private: - // Host callback landing routine invoked by CUDA. - // data: User-provided callback provided to HostCallback() above, captured - // as a std::function. Allocated/initialized inside - // HostCallback() and owned and deleted by this call. - static void InternalHostCallback(void* data); - // Collects metadata for the specified kernel. absl::Status GetKernelMetadata(GpuKernel* cuda_kernel, KernelMetadata* kernel_metadata); diff --git a/xla/stream_executor/gpu/gpu_stream.cc b/xla/stream_executor/gpu/gpu_stream.cc index de00ae6fa0571d..557ea215220e1b 100644 --- a/xla/stream_executor/gpu/gpu_stream.cc +++ b/xla/stream_executor/gpu/gpu_stream.cc @@ -16,8 +16,10 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_stream.h" #include +#include #include +#include "absl/functional/any_invocable.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -36,6 +38,14 @@ limitations under the License. namespace stream_executor { namespace gpu { +namespace { +void InternalHostCallback(void* data) { + auto* callback = reinterpret_cast*>(data); + std::move (*callback)(); + delete callback; +} +} // namespace + bool GpuStream::Init() { int priority = [&]() { if (std::holds_alternative(stream_priority_)) { @@ -145,6 +155,21 @@ absl::Status GpuStream::WaitFor(Event* event) { "error recording waiting for event on stream %p", this)); } } +absl::Status GpuStream::DoHostCallbackWithStatus( + absl::AnyInvocable callback) { + auto callback_ptr = + new absl::AnyInvocable([cb = std::move(callback)]() mutable { + absl::Status s = std::move(cb)(); + if (!s.ok()) { + LOG(WARNING) << "Host callback failed: " << s; + } + }); + if (GpuDriver::AddStreamCallback(parent_->gpu_context(), gpu_stream(), + InternalHostCallback, callback_ptr)) { + return absl::OkStatus(); + } + return absl::InternalError("Failed to host callback."); +} void GpuStream::Destroy() { if (completed_event_ != nullptr) { diff --git a/xla/stream_executor/gpu/gpu_stream.h b/xla/stream_executor/gpu/gpu_stream.h index e066b4b9761c6b..c855a78f306172 100644 --- a/xla/stream_executor/gpu/gpu_stream.h +++ b/xla/stream_executor/gpu/gpu_stream.h @@ -22,9 +22,11 @@ limitations under the License. #include #include +#include "absl/functional/any_invocable.h" #include "absl/log/check.h" #include "absl/strings/string_view.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/event.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/platform.h" @@ -109,6 +111,8 @@ class GpuStream : public StreamCommon { uint64_t size) override; absl::Status Memcpy(DeviceMemoryBase* gpu_dst, const DeviceMemoryBase& gpu_src, uint64_t size) override; + absl::Status DoHostCallbackWithStatus( + absl::AnyInvocable callback) override; void set_name(absl::string_view name) override; diff --git a/xla/stream_executor/host/host_executor.cc b/xla/stream_executor/host/host_executor.cc index 0748cc33d8012f..ac1d22583d0fde 100644 --- a/xla/stream_executor/host/host_executor.cc +++ b/xla/stream_executor/host/host_executor.cc @@ -167,12 +167,6 @@ absl::Status HostExecutor::SynchronousMemcpy(void* host_dst, return absl::OkStatus(); } -bool HostExecutor::HostCallback( - Stream* stream, absl::AnyInvocable callback) { - AsHostStream(stream)->EnqueueTaskWithStatus(std::move(callback)); - return true; -} - void HostExecutor::DeallocateStream(Stream* stream) {} absl::StatusOr> HostExecutor::CreateEvent() { diff --git a/xla/stream_executor/host/host_executor.h b/xla/stream_executor/host/host_executor.h index 182bbf22f9ad76..18ec5a739faca5 100644 --- a/xla/stream_executor/host/host_executor.h +++ b/xla/stream_executor/host/host_executor.h @@ -104,9 +104,6 @@ class HostExecutor : public StreamExecutorCommon { const DeviceMemoryBase& gpu_src, uint64_t size) override; - bool HostCallback(Stream* stream, - absl::AnyInvocable callback) override; - void DeallocateStream(Stream* stream) override; absl::Status BlockHostUntilDone(Stream* stream) override; diff --git a/xla/stream_executor/host/host_stream.cc b/xla/stream_executor/host/host_stream.cc index 61ea26d3dbcc33..ed6e040431e478 100644 --- a/xla/stream_executor/host/host_stream.cc +++ b/xla/stream_executor/host/host_stream.cc @@ -135,6 +135,14 @@ absl::Status HostStream::RecordEvent(Event* event) { return absl::OkStatus(); } +absl::Status HostStream::DoHostCallbackWithStatus( + absl::AnyInvocable callback) { + if (EnqueueTaskWithStatus(std::move(callback))) { + return absl::OkStatus(); + } + return absl::InternalError("Failed to host callback."); +} + bool HostStream::EnqueueTaskWithStatus( absl::AnyInvocable task) { CHECK(task != nullptr); diff --git a/xla/stream_executor/host/host_stream.h b/xla/stream_executor/host/host_stream.h index 15fdbcafaa253f..ed1bbc2011f48f 100644 --- a/xla/stream_executor/host/host_stream.h +++ b/xla/stream_executor/host/host_stream.h @@ -63,6 +63,8 @@ class HostStream : public StreamCommon { const DeviceMemoryBase& gpu_src, uint64_t size) override; absl::Status Memcpy(void* host_dst, const DeviceMemoryBase& gpu_src, uint64_t size) override; + absl::Status DoHostCallbackWithStatus( + absl::AnyInvocable callback) override; private: bool WorkAvailable() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); diff --git a/xla/stream_executor/mock_stream_executor.h b/xla/stream_executor/mock_stream_executor.h index b5eb0754fdaca6..3787be1133b5d4 100644 --- a/xla/stream_executor/mock_stream_executor.h +++ b/xla/stream_executor/mock_stream_executor.h @@ -108,9 +108,6 @@ class MockStreamExecutor : public StreamExecutor { (Stream * stream, DeviceMemoryBase* location, uint8_t pattern, uint64_t size), (override)); - MOCK_METHOD(bool, HostCallback, - (Stream * stream, absl::AnyInvocable callback), - (override)); MOCK_METHOD(void, DeallocateStream, (Stream * stream), (override)); MOCK_METHOD(absl::Status, BlockHostUntilDone, (Stream * stream), (override)); MOCK_METHOD(absl::Status, EnablePeerAccessTo, (StreamExecutor * other), diff --git a/xla/stream_executor/rocm/rocm_executor.cc b/xla/stream_executor/rocm/rocm_executor.cc index b3cca7a3cf2531..d02c595a246dd2 100644 --- a/xla/stream_executor/rocm/rocm_executor.cc +++ b/xla/stream_executor/rocm/rocm_executor.cc @@ -484,25 +484,6 @@ absl::Status GpuExecutor::Memset(Stream* stream, DeviceMemoryBase* location, AsGpuStreamValue(stream)); } -bool GpuExecutor::HostCallback(Stream* stream, - absl::AnyInvocable callback) { - auto callback_ptr = - new absl::AnyInvocable([cb = std::move(callback)]() mutable { - absl::Status s = std::move(cb)(); - if (!s.ok()) { - LOG(WARNING) << "Host callback failed: " << s; - } - }); - return GpuDriver::AddStreamCallback(context_, AsGpuStreamValue(stream), - InternalHostCallback, callback_ptr); -} - -/* static */ void GpuExecutor::InternalHostCallback(void* data) { - auto* callback = reinterpret_cast*>(data); - std::move (*callback)(); - delete callback; -} - void GpuExecutor::DeallocateStream(Stream* stream) { { absl::MutexLock lock(&mu_); diff --git a/xla/stream_executor/stream.h b/xla/stream_executor/stream.h index aad2b86dfb9bf6..66625500c02eb9 100644 --- a/xla/stream_executor/stream.h +++ b/xla/stream_executor/stream.h @@ -22,6 +22,7 @@ limitations under the License. #define XLA_STREAM_EXECUTOR_STREAM_H_ #include +#include #include #include "absl/functional/any_invocable.h" @@ -234,8 +235,12 @@ class Stream { // This is kept for backward compatibility. Future code should use // DoHostCallbackWithStatus and explicitly return a success status. // TODO(b/112125301): Eventually remove this method. - virtual absl::Status DoHostCallback( - absl::AnyInvocable callback) = 0; + absl::Status DoHostCallback(absl::AnyInvocable callback) { + return DoHostCallbackWithStatus([cb = std::move(callback)]() mutable { + std::move(cb)(); + return absl::OkStatus(); + }); + } // Entrains onto the stream a callback to the host (from the device). // Host callbacks block/occupy the stream just as device functions diff --git a/xla/stream_executor/stream_common.cc b/xla/stream_executor/stream_common.cc index adf662bf0d0aca..048623da37c01a 100644 --- a/xla/stream_executor/stream_common.cc +++ b/xla/stream_executor/stream_common.cc @@ -143,22 +143,6 @@ void StreamCommon::ReturnSubStream(Stream *sub_stream) { << sub_stream; } -absl::Status StreamCommon::DoHostCallback( - absl::AnyInvocable callback) { - return DoHostCallbackWithStatus([cb = std::move(callback)]() mutable { - std::move(cb)(); - return absl::OkStatus(); - }); -} - -absl::Status StreamCommon::DoHostCallbackWithStatus( - absl::AnyInvocable callback) { - if (parent_->HostCallback(this, std::move(callback))) { - return absl::OkStatus(); - } - return absl::InternalError("failed to host callback"); -} - void StreamCommon::CheckError(bool operation_retcode) { if (operation_retcode) { return; diff --git a/xla/stream_executor/stream_common.h b/xla/stream_executor/stream_common.h index 463d8589a9c986..3d2ade72ff12e3 100644 --- a/xla/stream_executor/stream_common.h +++ b/xla/stream_executor/stream_common.h @@ -69,9 +69,6 @@ class StreamCommon : public Stream { TF_LOCKS_EXCLUDED(mu_); void ReturnSubStream(Stream *sub_stream) override TF_LOCKS_EXCLUDED(mu_); absl::Status BlockHostUntilDone() override TF_LOCKS_EXCLUDED(mu_); - absl::Status DoHostCallback(absl::AnyInvocable callback) override; - absl::Status DoHostCallbackWithStatus( - absl::AnyInvocable callback) override; StreamExecutor *parent() const override { CHECK(parent_ != nullptr); return parent_; diff --git a/xla/stream_executor/stream_executor.h b/xla/stream_executor/stream_executor.h index c4265ffa44a2af..60fc20de835fb7 100644 --- a/xla/stream_executor/stream_executor.h +++ b/xla/stream_executor/stream_executor.h @@ -252,10 +252,6 @@ class StreamExecutor { return absl::InternalError("Not implemented"); } - // Enqueues on a stream a user-specified function to be run on the host. - virtual bool HostCallback(Stream* stream, - absl::AnyInvocable callback) = 0; - // Deallocates stream resources on the underlying platform. virtual void DeallocateStream(Stream* stream) = 0; diff --git a/xla/stream_executor/tpu/tpu_executor.cc b/xla/stream_executor/tpu/tpu_executor.cc index e17b318cfcd18f..596b4ee10c8d09 100644 --- a/xla/stream_executor/tpu/tpu_executor.cc +++ b/xla/stream_executor/tpu/tpu_executor.cc @@ -205,26 +205,6 @@ absl::Status TpuExecutor::EnqueueCompactionOnStreamForHbm( return status.status(); } -struct HostCallbackContext { - absl::AnyInvocable callback; -}; - -TSL_Status* HostCallbackTrampoline(void* ctx) { - HostCallbackContext* host_ctx = reinterpret_cast(ctx); - absl::Status status = std::move(host_ctx->callback)(); - TSL_Status* c_status = ExecutorApiFn()->TpuStatus_CreateFn( - status.raw_code(), absl::StatusMessageAsCStr(status)); - delete host_ctx; - return c_status; -} - -bool TpuExecutor::HostCallback(Stream* stream, - absl::AnyInvocable callback) { - HostCallbackContext* ctx = new HostCallbackContext{std::move(callback)}; - return ExecutorApiFn()->TpuExecutor_HostCallbackFn( - executor_, get_stream(stream), &HostCallbackTrampoline, ctx); -} - absl::StatusOr> TpuExecutor::CreateDeviceDescription() const { StatusHelper status; diff --git a/xla/stream_executor/tpu/tpu_executor.h b/xla/stream_executor/tpu/tpu_executor.h index 2aca25d2ace85c..d5b719787f4e6a 100644 --- a/xla/stream_executor/tpu/tpu_executor.h +++ b/xla/stream_executor/tpu/tpu_executor.h @@ -95,9 +95,6 @@ class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface { absl::StatusOr> CreateEvent() override; - bool HostCallback(Stream* stream, - absl::AnyInvocable callback) override; - bool SynchronizeAllActivity() override; absl::Status SynchronousMemcpy(DeviceMemoryBase* device_dst, diff --git a/xla/stream_executor/tpu/tpu_stream.h b/xla/stream_executor/tpu/tpu_stream.h index 79bc6eef8cfbe4..298773caef6e2a 100644 --- a/xla/stream_executor/tpu/tpu_stream.h +++ b/xla/stream_executor/tpu/tpu_stream.h @@ -1,5 +1,3 @@ -#include "xla/stream_executor/event.h" -#include "xla/stream_executor/stream.h" /* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); @@ -19,9 +17,13 @@ limitations under the License. #define XLA_STREAM_EXECUTOR_TPU_TPU_STREAM_H_ #include +#include +#include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/tpu/c_api_conversions.h" #include "xla/stream_executor/tpu/c_api_decl.h" @@ -146,6 +148,27 @@ class TpuStream : public tensorflow::tpu::TpuStreamInterface { se_executor_, stream_, host_dst, &se_base, size, status.c_status); return status.status(); } + struct HostCallbackContext { + absl::AnyInvocable callback; + }; + static TSL_Status* HostCallbackTrampoline(void* ctx) { + HostCallbackContext* host_ctx = reinterpret_cast(ctx); + absl::Status status = std::move(host_ctx->callback)(); + TSL_Status* c_status = + stream_executor::tpu::ExecutorApiFn()->TpuStatus_CreateFn( + status.raw_code(), absl::StatusMessageAsCStr(status)); + delete host_ctx; + return c_status; + } + absl::Status DoHostCallbackWithStatus( + absl::AnyInvocable callback) override { + HostCallbackContext* ctx = new HostCallbackContext{std::move(callback)}; + if (stream_executor::tpu::ExecutorApiFn()->TpuExecutor_HostCallbackFn( + se_executor_, stream_, &HostCallbackTrampoline, ctx)) { + return absl::OkStatus(); + } + return absl::InternalError("Failed to host callback."); + } SE_Stream* se_stream() const { return stream_; } From e608ef43d7a4e719d40366a2f09e75ddd793862b Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Fri, 19 Jul 2024 11:16:12 -0700 Subject: [PATCH 026/376] Replace Copybara rule with `#if TSL_IS_IN_OSS` Current rule to replace `DownCastToGenerated` is only necessary due to old protobuf version PiperOrigin-RevId: 654066741 --- xla/service/BUILD | 1 + xla/service/compilation_environments.h | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/xla/service/BUILD b/xla/service/BUILD index a255855b928b36..75b2c53fe3fd91 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -7700,6 +7700,7 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", + "@tsl//tsl/platform", "@tsl//tsl/platform:casts", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", # fixdeps: keep diff --git a/xla/service/compilation_environments.h b/xla/service/compilation_environments.h index fd57a053947346..08a79df01e09cd 100644 --- a/xla/service/compilation_environments.h +++ b/xla/service/compilation_environments.h @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/xla.pb.h" #include "tsl/platform/casts.h" +#include "tsl/platform/platform.h" #include "tsl/platform/protobuf.h" namespace xla { @@ -143,7 +144,12 @@ T& CompilationEnvironments::GetMutableEnv() { it = environments_.find(descriptor); } + // TODO(b/302086111): Remove after XLA has an updated protobuf version. +#if TSL_IS_IN_OSS return tensorflow::down_cast(*it->second); +#else + return tsl::protobuf::DownCastToGenerated(*it->second); +#endif } template From 473b1d6ea18a7660495e57924d1dc6498eb0e057 Mon Sep 17 00:00:00 2001 From: Mason Chang Date: Fri, 19 Jul 2024 12:15:26 -0700 Subject: [PATCH 027/376] Rollback due to internal breaking change Reverts c8c376e96b5e3b23ecb150d574d96e19448d7a8c PiperOrigin-RevId: 654086100 --- xla/hlo/ir/BUILD | 12 +- xla/hlo/ir/hlo_instruction.cc | 5 +- xla/hlo/ir/hlo_instruction_utils.cc | 17 --- xla/hlo/ir/hlo_instruction_utils.h | 4 - xla/hlo/transforms/BUILD | 36 ----- xla/hlo/transforms/hlo_broadcast_splitter.cc | 57 -------- xla/hlo/transforms/hlo_broadcast_splitter.h | 41 ------ .../transforms/hlo_broadcast_splitter_test.cc | 121 ----------------- xla/service/BUILD | 6 +- xla/service/gpu/BUILD | 1 + xla/service/hlo_cse.cc | 84 +++--------- xla/service/hlo_cse.h | 9 +- xla/service/hlo_cse_test.cc | 102 -------------- xla/service/sharding_propagation.cc | 125 ------------------ xla/service/sharding_propagation_test.cc | 3 +- 15 files changed, 33 insertions(+), 590 deletions(-) delete mode 100644 xla/hlo/transforms/hlo_broadcast_splitter.cc delete mode 100644 xla/hlo/transforms/hlo_broadcast_splitter.h delete mode 100644 xla/hlo/transforms/hlo_broadcast_splitter_test.cc diff --git a/xla/hlo/ir/BUILD b/xla/hlo/ir/BUILD index b9cbfe8d14ac56..b396fbee70940d 100644 --- a/xla/hlo/ir/BUILD +++ b/xla/hlo/ir/BUILD @@ -28,7 +28,6 @@ cc_library( "hlo_frontend_attributes.cc", "hlo_input_output_alias_config.cc", "hlo_instruction.cc", - "hlo_instruction_utils.cc", "hlo_instructions.cc", "hlo_module.cc", "hlo_module_metadata.cc", @@ -50,7 +49,6 @@ cc_library( "hlo_frontend_attributes.h", "hlo_input_output_alias_config.h", "hlo_instruction.h", - "hlo_instruction_utils.h", "hlo_instructions.h", "hlo_module.h", "hlo_module_metadata.h", @@ -156,6 +154,16 @@ cc_library( ], ) +cc_library( + name = "hlo_instruction_utils", + srcs = ["hlo_instruction_utils.cc"], + hdrs = ["hlo_instruction_utils.h"], + deps = [ + ":hlo", + "@com_google_absl//absl/algorithm:container", + ], +) + cc_library( name = "hlo_reachability", srcs = ["hlo_reachability.cc"], diff --git a/xla/hlo/ir/hlo_instruction.cc b/xla/hlo/ir/hlo_instruction.cc index 7f9ea776a661d4..97aea062027a09 100644 --- a/xla/hlo/ir/hlo_instruction.cc +++ b/xla/hlo/ir/hlo_instruction.cc @@ -56,7 +56,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_domain_metadata.h" #include "xla/hlo/ir/hlo_frontend_attributes.h" -#include "xla/hlo/ir/hlo_instruction_utils.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_op_metadata.h" @@ -2789,8 +2788,8 @@ bool HloInstruction::IdenticalInternal( : ShapeUtil::Compatible(shape(), other.shape()))) { return false; } - if (sharding_sensitive && - !hlo_instruction_utils::HasEquivalentShardings(*this, other)) { + if (sharding_sensitive && has_sharding() && other.has_sharding() && + sharding() != other.sharding()) { return false; } if (operands().size() != other.operands().size()) { diff --git a/xla/hlo/ir/hlo_instruction_utils.cc b/xla/hlo/ir/hlo_instruction_utils.cc index 286d60d71910ac..52eac784f085d5 100644 --- a/xla/hlo/ir/hlo_instruction_utils.cc +++ b/xla/hlo/ir/hlo_instruction_utils.cc @@ -49,22 +49,5 @@ void AddOrUpdateVectorOfPairsAsAttribute(HloInstruction* instr, instr->set_frontend_attributes(attributes); } -bool HasEquivalentShardings(const HloInstruction& lhs, - const HloInstruction& rhs) { - if (lhs.has_sharding() && !lhs.sharding().IsReplicated() && - !rhs.has_sharding()) { - return false; - } - if (!lhs.has_sharding() && rhs.has_sharding() && - !rhs.sharding().IsReplicated()) { - return false; - } - if (lhs.has_sharding() && rhs.has_sharding() && - lhs.sharding() != rhs.sharding()) { - return false; - } - return true; -} - } // namespace hlo_instruction_utils } // namespace xla diff --git a/xla/hlo/ir/hlo_instruction_utils.h b/xla/hlo/ir/hlo_instruction_utils.h index 563fb02736ad26..3721f0e65b3200 100644 --- a/xla/hlo/ir/hlo_instruction_utils.h +++ b/xla/hlo/ir/hlo_instruction_utils.h @@ -31,10 +31,6 @@ void AddOrUpdateVectorOfPairsAsAttribute( HloInstruction* instr, std::string attr_name, std::vector> intervals); -// Check if two shardings are equivalent. -bool HasEquivalentShardings(const HloInstruction& lhs, - const HloInstruction& rhs); - } // namespace hlo_instruction_utils } // namespace xla diff --git a/xla/hlo/transforms/BUILD b/xla/hlo/transforms/BUILD index 8f959acab9cc96..65d2368f0d0453 100644 --- a/xla/hlo/transforms/BUILD +++ b/xla/hlo/transforms/BUILD @@ -17,42 +17,6 @@ package_group( ], ) -cc_library( - name = "hlo_broadcast_splitter", - srcs = ["hlo_broadcast_splitter.cc"], - hdrs = ["hlo_broadcast_splitter.h"], - deps = [ - "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@tsl//tsl/platform:errors", - ], -) - -xla_cc_test( - name = "hlo_broadcast_splitter_test", - srcs = ["hlo_broadcast_splitter_test.cc"], - deps = [ - ":hlo_broadcast_splitter", - "//xla:test", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_parser", - "//xla/tests:hlo_test_base", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", - "@tsl//tsl/platform:statusor", - ], -) - cc_library( name = "hlo_constant_splitter", srcs = ["hlo_constant_splitter.cc"], diff --git a/xla/hlo/transforms/hlo_broadcast_splitter.cc b/xla/hlo/transforms/hlo_broadcast_splitter.cc deleted file mode 100644 index 8d24e84cfd98dc..00000000000000 --- a/xla/hlo/transforms/hlo_broadcast_splitter.cc +++ /dev/null @@ -1,57 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/hlo/transforms/hlo_broadcast_splitter.h" - -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "tsl/platform/errors.h" - -namespace xla { - -absl::StatusOr HloBroadcastSplitter::Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) { - bool changed = false; - absl::flat_hash_set seen_broadcast_operands; - for (HloComputation* computation : module->computations(execution_threads)) { - for (HloInstruction* instruction : - computation->MakeInstructionPostOrder()) { - for (int64_t i = 0; i < instruction->operand_count(); ++i) { - HloInstruction* operand = instruction->mutable_operand(i); - if (operand->opcode() == HloOpcode::kBroadcast) { - if (seen_broadcast_operands.contains(operand)) { - HloInstruction* cloned_broadcast = - operand->AddInstruction(operand->Clone()); - TF_RETURN_IF_ERROR( - operand->ReplaceUseWith(instruction, i, cloned_broadcast)); - } else { - seen_broadcast_operands.insert(operand); - } - } - } - } - } - return changed; -} - -} // namespace xla diff --git a/xla/hlo/transforms/hlo_broadcast_splitter.h b/xla/hlo/transforms/hlo_broadcast_splitter.h deleted file mode 100644 index d7136a01d92ff2..00000000000000 --- a/xla/hlo/transforms/hlo_broadcast_splitter.h +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef XLA_HLO_TRANSFORMS_HLO_BROADCAST_SPLITTER_H_ -#define XLA_HLO_TRANSFORMS_HLO_BROADCAST_SPLITTER_H_ - -#include "absl/container/flat_hash_set.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/service/hlo_pass_interface.h" - -namespace xla { - -// Splits the broadcast instructions such that they have a single user. This -// aggressively duplicates all broadcasts and relies on DCE to clean up the -// duplicates after propagation and partitioning. -class HloBroadcastSplitter : public HloModulePass { - public: - HloBroadcastSplitter() = default; - absl::string_view name() const override { return "hlo-broadcast-splitter"; } - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; - -} // namespace xla - -#endif // XLA_HLO_TRANSFORMS_HLO_BROADCAST_SPLITTER_H_ diff --git a/xla/hlo/transforms/hlo_broadcast_splitter_test.cc b/xla/hlo/transforms/hlo_broadcast_splitter_test.cc deleted file mode 100644 index 0105f4a1ba52d6..00000000000000 --- a/xla/hlo/transforms/hlo_broadcast_splitter_test.cc +++ /dev/null @@ -1,121 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/hlo/transforms/hlo_broadcast_splitter.h" - -#include -#include "absl/container/flat_hash_set.h" -#include "absl/log/log.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/hlo_parser.h" -#include "xla/test.h" -#include "xla/tests/hlo_test_base.h" -#include "xla/util.h" -#include "tsl/lib/core/status_test_util.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace { - -using HloBroadcastSplitterTest = HloTestBase; - -TEST_F(HloBroadcastSplitterTest, SplitBroadcast) { - const char* module_str = R"( - HloModule test_module - - ENTRY entry_computation { - param = (f32[], f32[1024,1024], f32[1024,1024]) parameter(0), - sharding={{replicated}, {devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}, {devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}} - gte0 = f32[] get-tuple-element(param), index=0 - gte1 = f32[1024,1024] get-tuple-element(param), index=1 - gte2 = f32[1024,1024] get-tuple-element(param), index=2 - broadcast = f32[1024,1024] broadcast(gte0), dimensions={} - add1 = f32[1024,1024] add(broadcast, gte1) - add2 = f32[1024,1024] add(broadcast, gte2) - ROOT root = (f32[1024,1024], f32[1024,1024]) tuple(add1, add2) - } - )"; - - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnUnverifiedModule(module_str)); - TF_ASSERT_OK(HloBroadcastSplitter().Run(module.get()).status()); - - VLOG(1) << module->ToString(); - // Check that every broadcast has at most one user. - for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kBroadcast) { - EXPECT_LE(instruction->user_count(), 1); - } - } - } -} - -TEST_F(HloBroadcastSplitterTest, SplitBroadcastWithinWhileLoop) { - const char* module_str = R"( - -%cond { - %vars.cond = (s32[], f32[1024,1024], f32[1024,1024], f32[1024,1024], f32[1024,1024]) parameter(0) - %count.cond = s32[] get-tuple-element(%vars.cond), index=0 - %limit = s32[] constant(10) - ROOT %lt = pred[] compare(%count.cond, %limit), direction=LT -} - -%body { - %param = (s32[], f32[1024,1024], f32[1024,1024], f32[1024,1024], f32[1024,1024]) parameter(0) - %count = s32[] get-tuple-element(%param), index=0 - %broadcast1 = f32[1024,1024] get-tuple-element(%param), index=1 - %lhs = f32[1024,1024] get-tuple-element(%param), index=2 - %broadcast2 = f32[1024,1024] get-tuple-element(%param), index=3 - %rhs = f32[1024,1024] get-tuple-element(%param), index=4 - add1 = f32[1024,1024] add(broadcast1, lhs) - add2 = f32[1024,1024] add(broadcast2, rhs) - ROOT %tuple = (s32[], f32[1024,1024], f32[1024,1024], f32[1024,1024], f32[1024,1024]) tuple(%count, %broadcast1, %add1, %broadcast2, %add2) -} - -ENTRY %entry { - param = (f32[], f32[1024,1024], f32[1024,1024]) parameter(0), - sharding={{replicated}, {devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}, {devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}} - gte0 = f32[] get-tuple-element(param), index=0 - gte1 = f32[1024,1024] get-tuple-element(param), index=1 - gte2 = f32[1024,1024] get-tuple-element(param), index=2 - broadcast = f32[1024,1024] broadcast(gte0), dimensions={} - zero = s32[] constant(0) - tuple = (s32[], f32[1024,1024], f32[1024,1024], f32[1024,1024], f32[1024,1024]) tuple(zero, broadcast, gte1, broadcast, gte2) - while = (s32[], f32[1024,1024], f32[1024,1024], f32[1024,1024], f32[1024,1024]) while(%tuple), body=%body, condition=%cond - ROOT %copy = (s32[], f32[1024,1024], f32[1024,1024], f32[1024,1024], f32[1024,1024]) copy(%while) -})"; - - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnUnverifiedModule(module_str)); - TF_ASSERT_OK(HloBroadcastSplitter().Run(module.get()).status()); - - VLOG(1) << module->ToString(); - // Check that the broadcast are duplicated for multiple usage in the same - // user. - absl::flat_hash_set broadcasts; - for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kBroadcast) { - EXPECT_FALSE(broadcasts.contains(instruction)); - broadcasts.insert(instruction); - } - } - } -} - -} // namespace -} // namespace xla diff --git a/xla/service/BUILD b/xla/service/BUILD index 75b2c53fe3fd91..73db9ba1a81a21 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -600,6 +600,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", + "//xla/hlo/ir:hlo_instruction_utils", "//xla/hlo/utils:hlo_query", "//xla/service:hlo_parser", "@com_google_absl//absl/algorithm:container", @@ -803,10 +804,8 @@ cc_library( ":call_graph", ":custom_call_sharding_helper", ":dot_as_convolution_util", - ":hlo_cse", ":hlo_graph_dumper", ":hlo_pass", - ":hlo_pass_pipeline", ":host_memory_offload_annotations_hdr", "//xla:array", "//xla:protobuf_util", @@ -2715,6 +2714,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", + "//xla/hlo/ir:hlo_instruction_utils", "//xla/hlo/utils:hlo_sharding_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -5489,8 +5489,6 @@ cc_library( "//xla:literal", "//xla:shape_util", "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_sharding_util", - "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", "@tsl//tsl/platform:errors", diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 93dd425ac6d8e8..404386ce015a55 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -5615,6 +5615,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/ir:hlo_instruction_utils", "//xla/hlo/utils:hlo_query", "//xla/service:collective_ops_utils", "//xla/service:flatten_call_graph", diff --git a/xla/service/hlo_cse.cc b/xla/service/hlo_cse.cc index 8f7810c9be2772..2594fa392a5c1c 100644 --- a/xla/service/hlo_cse.cc +++ b/xla/service/hlo_cse.cc @@ -15,23 +15,18 @@ limitations under the License. #include "xla/service/hlo_cse.h" -#include #include #include #include #include -#include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instruction_utils.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/ir/hlo_sharding.h" -#include "xla/hlo/utils/hlo_sharding_util.h" #include "xla/literal.h" #include "xla/service/hlo_domain_map.h" #include "xla/shape_util.h" @@ -41,18 +36,11 @@ namespace xla { namespace { -template +template struct ConstantKey { template friend H AbslHashValue(H h, const ConstantKey& key) { h = H::combine(std::move(h), key.domain); - if (kIsShardingSensitive) { - if (key.hlo->has_sharding()) { - h = H::combine(std::move(h), key.hlo->sharding()); - } else { - h = H::combine(std::move(h), HloSharding::Replicate()); - } - } return Literal::Hash( std::move(h), key.hlo->literal()); } @@ -61,9 +49,7 @@ struct ConstantKey { (kIsLayoutSensitive ? Shape::Equal() : Shape::Equal().IgnoreLayout())( lhs.hlo->shape(), rhs.hlo->shape()) && - lhs.hlo->literal().Equal(rhs.hlo->literal(), kIsLayoutSensitive) && - (!kIsShardingSensitive || - hlo_instruction_utils::HasEquivalentShardings(*lhs.hlo, *rhs.hlo)); + lhs.hlo->literal().Equal(rhs.hlo->literal(), kIsLayoutSensitive); } HloConstantInstruction* hlo; int64_t domain; @@ -74,7 +60,7 @@ struct ConstantKey { // // While we're here, also combine identical iota instructions, since they need // similar treatment. -template +template absl::StatusOr CombineConstants(HloComputation* computation, bool only_scalars) { // Populating the domain map is somewhat expensive -- only do it if there are @@ -91,8 +77,7 @@ absl::StatusOr CombineConstants(HloComputation* computation, // Map from the literal hash of a constant or the shape hash of an iota all // equivalent instructions. This avoids extreme quadratic behavior with many // scalar constants. - absl::flat_hash_set> - constants; + absl::flat_hash_set> constants; int64_t combined = 0; auto inst_it = computation->instructions().begin(); while (inst_it != computation->instructions().end()) { @@ -108,11 +93,9 @@ absl::StatusOr CombineConstants(HloComputation* computation, HloInstruction* match = nullptr; if (auto* constant_inst = DynCast(instruction)) { - auto insert_result = constants.insert( - ConstantKey{ - constant_inst, - (domain_map != nullptr ? domain_map->GetDomainId(instruction) - : 0)}); + auto insert_result = constants.insert(ConstantKey{ + constant_inst, + (domain_map != nullptr ? domain_map->GetDomainId(instruction) : 0)}); if (!insert_result.second) { match = insert_result.first->hlo; } @@ -262,26 +245,9 @@ absl::StatusOr HloCSE::Run( }; auto cse_equal = [&](const CseKey& lhs, const CseKey& rhs) { - if (lhs.hlo->IdenticalIgnoringCommutativeOperandOrder( - *rhs.hlo, eq_instructions, eq_computations, is_layout_sensitive_, - /*sharding_sensitive=*/false)) { - bool equal = true; - // Check if the shardings are equal or compatible. - if (is_sharding_sensitive_) { - equal = - hlo_instruction_utils::HasEquivalentShardings(*lhs.hlo, *rhs.hlo); - if (!equal && allow_compatible_sharding_ && lhs.hlo->has_sharding() && - rhs.hlo->has_sharding()) { - HloSharding lhs_sharding = lhs.hlo->sharding(); - equal |= (hlo_sharding_util::IsSubTilingOrEqualSharding( - lhs.hlo->shape(), lhs_sharding, rhs.hlo->sharding()) || - hlo_sharding_util::IsSubTilingOrEqualSharding( - lhs.hlo->shape(), rhs.hlo->sharding(), lhs_sharding)); - } - } - return equal; - } - return false; + return lhs.hlo->IdenticalIgnoringCommutativeOperandOrder( + *rhs.hlo, eq_instructions, eq_computations, is_layout_sensitive_, + /*sharding_sensitive=*/true); }; for (auto* computation : module->computations(execution_threads)) { @@ -289,20 +255,11 @@ absl::StatusOr HloCSE::Run( continue; } - bool combined; - if (is_layout_sensitive_ && is_sharding_sensitive_) { - combined = - CombineConstants(computation, only_scalars_).value(); - } else if (is_layout_sensitive_ && !is_sharding_sensitive_) { - combined = - CombineConstants(computation, only_scalars_).value(); - } else if (!is_layout_sensitive_ && is_sharding_sensitive_) { - combined = - CombineConstants(computation, only_scalars_).value(); - } else { - combined = - CombineConstants(computation, only_scalars_).value(); - } + TF_ASSIGN_OR_RETURN( + bool combined, + is_layout_sensitive_ + ? CombineConstants(computation, only_scalars_) + : CombineConstants(computation, only_scalars_)); changed |= combined; // HLO instructions are grouped into equivalency classes by using the @@ -331,17 +288,6 @@ absl::StatusOr HloCSE::Run( auto pair = representatives.insert(CseKey{instruction}); if (!pair.second) { HloInstruction* equivalent_instruction = pair.first->hlo; - if (is_sharding_sensitive_ && allow_compatible_sharding_ && - instruction->has_sharding() && - equivalent_instruction->has_sharding() && - instruction->sharding() != equivalent_instruction->sharding()) { - if (hlo_sharding_util::IsSubTilingOrEqualSharding( - instruction->shape(), instruction->sharding(), - equivalent_instruction->sharding())) { - equivalent_instruction->set_sharding(instruction->sharding()); - } - } - TF_RETURN_IF_ERROR( instruction->ReplaceAllUsesWith(equivalent_instruction)); TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands( diff --git a/xla/service/hlo_cse.h b/xla/service/hlo_cse.h index 03496234c007c8..1ccab0d5872eb0 100644 --- a/xla/service/hlo_cse.h +++ b/xla/service/hlo_cse.h @@ -34,14 +34,11 @@ class HloCSE : public HloModulePass { explicit HloCSE(bool is_layout_sensitive, bool only_fusion_computations = false, bool ignore_control_dependencies = false, - bool only_scalars = false, bool is_sharding_sensitive = true, - bool allow_compatible_sharding = false) + bool only_scalars = false) : is_layout_sensitive_(is_layout_sensitive), only_fusion_computations_(only_fusion_computations), ignore_control_dependencies_(ignore_control_dependencies), - only_scalars_(only_scalars), - is_sharding_sensitive_(is_sharding_sensitive), - allow_compatible_sharding_(allow_compatible_sharding) {} + only_scalars_(only_scalars) {} ~HloCSE() override = default; absl::string_view name() const override { return "cse"; } @@ -57,8 +54,6 @@ class HloCSE : public HloModulePass { const bool only_fusion_computations_; const bool ignore_control_dependencies_; const bool only_scalars_; - const bool is_sharding_sensitive_; - const bool allow_compatible_sharding_; }; } // namespace xla diff --git a/xla/service/hlo_cse_test.cc b/xla/service/hlo_cse_test.cc index c08902d9826b00..106eea0923b0be 100644 --- a/xla/service/hlo_cse_test.cc +++ b/xla/service/hlo_cse_test.cc @@ -937,108 +937,6 @@ TEST_F(HloCseTest, MultiOutputFusion) { EXPECT_EQ(add0, add1); } -TEST_F(HloCseTest, CombineWithCompatibleShardings) { - const char* const hlo_string = R"( -HloModule module_entry - -%body (param: (s32[], f32[1024,1024], f32[1024,1024], f32[1024,1024], f32[1024,1024])) -> (s32[], f32[1024,1024], f32[1024,1024], f32[1024,1024], f32[1024,1024]) { - %param = (s32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) parameter(0) - %count = s32[] get-tuple-element((s32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) %param), index=0 - %broadcast1 = f32[1024,1024]{1,0} get-tuple-element((s32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) %param), index=1 - %lhs = f32[1024,1024]{1,0} get-tuple-element((s32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) %param), index=2 - %add1 = f32[1024,1024]{1,0} add(f32[1024,1024]{1,0} %broadcast1, f32[1024,1024]{1,0} %lhs) - %broadcast2 = f32[1024,1024]{1,0} get-tuple-element((s32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) %param), index=3 - %rhs = f32[1024,1024]{1,0} get-tuple-element((s32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) %param), index=4 - %add2 = f32[1024,1024]{1,0} add(f32[1024,1024]{1,0} %broadcast2, f32[1024,1024]{1,0} %rhs) - ROOT %tuple = (s32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) tuple(s32[] %count, f32[1024,1024]{1,0} %broadcast1, f32[1024,1024]{1,0} %add1, f32[1024,1024]{1,0} %broadcast2, f32[1024,1024]{1,0} %add2) -} - -%cond (vars.cond: (s32[], f32[1024,1024], f32[1024,1024], f32[1024,1024], f32[1024,1024])) -> pred[] { - %vars.cond = (s32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) parameter(0) - %count.cond = s32[] get-tuple-element((s32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) %vars.cond), index=0 - %limit = s32[] constant(10) - ROOT %lt = pred[] compare(s32[] %count.cond, s32[] %limit), direction=LT -} - -ENTRY %entry (param.1: (f32[], f32[1024,1024], f32[1024,1024])) -> (s32[], f32[1024,1024], f32[1024,1024], f32[1024,1024], f32[1024,1024]) { - %param.1 = (f32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) parameter(0), sharding={{replicated}, {devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}, {devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}} - %gte0 = f32[] get-tuple-element((f32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) %param.1), index=0 - %broadcast = f32[1024,1024]{1,0} broadcast(f32[] %gte0), dimensions={} - %zero = s32[] constant(0) - %broadcast.clone = f32[1024,1024]{1,0} broadcast(f32[] %gte0), dimensions={}, sharding={devices=[2,2]0,1,2,3} - %gte1 = f32[1024,1024]{1,0} get-tuple-element((f32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) %param.1), index=1 - %broadcast.clone.1 = f32[1024,1024]{1,0} broadcast(f32[] %gte0), dimensions={}, sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate} - %gte2 = f32[1024,1024]{1,0} get-tuple-element((f32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) %param.1), index=2 - %tuple.1 = (s32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) tuple(s32[] %zero, f32[1024,1024]{1,0} %broadcast.clone, f32[1024,1024]{1,0} %gte1, f32[1024,1024]{1,0} %broadcast.clone.1, f32[1024,1024]{1,0} %gte2) - %while = (s32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) while((s32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) %tuple.1), condition=%cond, body=%body - ROOT %copy = (s32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) copy((s32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) %while) -})"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - HloCSE cse(/*is_layout_sensitive=*/false, - /*only_fusion_computations=*/false, - /*ignore_control_dependencies=*/false, - /*only_scalars=*/false, - /*is_sharding_sensitive=*/true, - /*allow_compatible_sharding=*/true); - TF_ASSERT_OK_AND_ASSIGN(bool result, cse.Run(module.get())); - VLOG(1) << module->ToString(); - EXPECT_TRUE(result); - auto root = module->entry_computation()->root_instruction(); - auto tuple = root->operand(0)->operand(0); - auto broadcast1 = tuple->operand(1); - auto broadcast2 = tuple->operand(3); - EXPECT_EQ(broadcast1, broadcast2); -} - -TEST_F(HloCseTest, DoNotCombineWithCompatibleShardings) { - const char* const hlo_string = R"( -HloModule module_entry - -%body (param: (s32[], f32[1024,1024], f32[1024,1024], f32[1024,1024], f32[1024,1024])) -> (s32[], f32[1024,1024], f32[1024,1024], f32[1024,1024], f32[1024,1024]) { - %param = (s32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) parameter(0) - %count = s32[] get-tuple-element((s32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) %param), index=0 - %broadcast1 = f32[1024,1024]{1,0} get-tuple-element((s32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) %param), index=1 - %lhs = f32[1024,1024]{1,0} get-tuple-element((s32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) %param), index=2 - %add1 = f32[1024,1024]{1,0} add(f32[1024,1024]{1,0} %broadcast1, f32[1024,1024]{1,0} %lhs) - %broadcast2 = f32[1024,1024]{1,0} get-tuple-element((s32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) %param), index=3 - %rhs = f32[1024,1024]{1,0} get-tuple-element((s32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) %param), index=4 - %add2 = f32[1024,1024]{1,0} add(f32[1024,1024]{1,0} %broadcast2, f32[1024,1024]{1,0} %rhs) - ROOT %tuple = (s32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) tuple(s32[] %count, f32[1024,1024]{1,0} %broadcast1, f32[1024,1024]{1,0} %add1, f32[1024,1024]{1,0} %broadcast2, f32[1024,1024]{1,0} %add2) -} - -%cond (vars.cond: (s32[], f32[1024,1024], f32[1024,1024], f32[1024,1024], f32[1024,1024])) -> pred[] { - %vars.cond = (s32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) parameter(0) - %count.cond = s32[] get-tuple-element((s32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) %vars.cond), index=0 - %limit = s32[] constant(10) - ROOT %lt = pred[] compare(s32[] %count.cond, s32[] %limit), direction=LT -} - -ENTRY %entry (param.1: (f32[], f32[1024,1024], f32[1024,1024])) -> (s32[], f32[1024,1024], f32[1024,1024], f32[1024,1024], f32[1024,1024]) { - %param.1 = (f32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) parameter(0), sharding={{replicated}, {devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}, {devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}} - %gte0 = f32[] get-tuple-element((f32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) %param.1), index=0 - %broadcast = f32[1024,1024]{1,0} broadcast(f32[] %gte0), dimensions={} - %zero = s32[] constant(0) - %broadcast.clone = f32[1024,1024]{1,0} broadcast(f32[] %gte0), dimensions={}, sharding={devices=[2,2]0,1,2,3} - %gte1 = f32[1024,1024]{1,0} get-tuple-element((f32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) %param.1), index=1 - %broadcast.clone.1 = f32[1024,1024]{1,0} broadcast(f32[] %gte0), dimensions={}, sharding={devices=[2,1,2]0,2,1,3 last_tile_dim_replicate} - %gte2 = f32[1024,1024]{1,0} get-tuple-element((f32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) %param.1), index=2 - %tuple.1 = (s32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) tuple(s32[] %zero, f32[1024,1024]{1,0} %broadcast.clone, f32[1024,1024]{1,0} %gte1, f32[1024,1024]{1,0} %broadcast.clone.1, f32[1024,1024]{1,0} %gte2) - %while = (s32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) while((s32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) %tuple.1), condition=%cond, body=%body - ROOT %copy = (s32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) copy((s32[], f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) %while) -})"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - HloCSE cse(/*is_layout_sensitive=*/false, - /*only_fusion_computations=*/false, - /*ignore_control_dependencies=*/false, - /*only_scalars=*/false, - /*is_sharding_sensitive=*/true); - TF_ASSERT_OK_AND_ASSIGN(bool result, cse.Run(module.get())); - VLOG(1) << module->ToString(); - EXPECT_FALSE(result); -} - class HloCseCommutativeOpTest : public HloCseTest, public ::testing::WithParamInterface {}; diff --git a/xla/service/sharding_propagation.cc b/xla/service/sharding_propagation.cc index d6fac61352a6b0..91b291e7d6156d 100644 --- a/xla/service/sharding_propagation.cc +++ b/xla/service/sharding_propagation.cc @@ -25,7 +25,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -37,7 +36,6 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/strings/str_join.h" -#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/array.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -50,8 +48,6 @@ limitations under the License. #include "xla/hlo/utils/hlo_sharding_util.h" #include "xla/protobuf_util.h" #include "xla/service/dot_as_convolution_util.h" -#include "xla/service/hlo_cse.h" -#include "xla/service/hlo_pass_pipeline.h" #include "xla/service/host_memory_offload_annotations.h" #include "xla/service/spmd/shard_barrier_partitioner.h" #include "xla/shape.h" @@ -68,94 +64,6 @@ limitations under the License. namespace xla { namespace { -// Remove stale shard group instructions after a module has been changed. -void RemoveStaleShardGroupInstructions( - HloModule* module, - const absl::flat_hash_set& execution_threads, - absl::flat_hash_map>& - shard_group_id_to_shard_group) { - absl::flat_hash_map - instruction_to_shard_group_id; - absl::flat_hash_set not_stale; - for (auto& [shard_group_id, shard_group] : shard_group_id_to_shard_group) { - for (HloInstruction* instruction : shard_group) { - instruction_to_shard_group_id[instruction] = shard_group_id; - } - } - for (auto computation : module->computations(execution_threads)) { - for (auto instruction : computation->instructions()) { - if (instruction_to_shard_group_id.contains(instruction)) { - not_stale.insert(instruction); - } - } - } - for (auto& [instruction, shard_group_id] : instruction_to_shard_group_id) { - if (!not_stale.contains(instruction)) { - shard_group_id_to_shard_group[shard_group_id].erase(instruction); - } - } -} - -template -using IsHloInstructionPointer = - typename std::enable_if_t || - std::is_same_v>; - -template > -void RemoveStaleSetInstructions( - HloModule* module, - const absl::flat_hash_set& execution_threads, - absl::flat_hash_set& set) { - absl::flat_hash_set not_stale; - for (auto computation : module->computations(execution_threads)) { - for (auto instruction : computation->instructions()) { - if (set.contains(instruction)) { - not_stale.insert(instruction); - } - } - } - set = std::move(not_stale); -} - -template > -void RemoveStaleMapInstructions( - HloModule* module, - const absl::flat_hash_set& execution_threads, - absl::flat_hash_map& map) { - absl::flat_hash_set not_stale; - for (auto computation : module->computations(execution_threads)) { - for (auto instruction : computation->instructions()) { - if (map.contains(instruction)) { - not_stale.insert(instruction); - } - } - } - absl::erase_if(map, [¬_stale](const auto& p) { - return !not_stale.contains(p.first); - }); -} - -void RemoveStaleComputationMap( - HloModule* module, - const absl::flat_hash_set& execution_threads, - ShardingPropagation::ComputationMap& computation_map) { - absl::flat_hash_set computation_map_instructions; - for (auto& [_, instruction] : computation_map) { - computation_map_instructions.insert(instruction); - } - absl::flat_hash_set not_stale; - for (auto computation : module->computations(execution_threads)) { - for (auto instruction : computation->instructions()) { - if (computation_map_instructions.contains(instruction)) { - not_stale.insert(instruction); - } - } - } - absl::erase_if(computation_map, [¬_stale](const auto& p) { - return !not_stale.contains(p.second); - }); -} - // Returning the improved sharding of an instruction from some other sharding. std::optional ReturnImprovedSharding( HloSharding sharding, HloInstruction* instruction, @@ -3375,39 +3283,6 @@ absl::StatusOr ShardingPropagation::Run( run_to_fix_point(aggressiveness, /*propagate_shard_group=*/true)); } - if (changed) { - // Run CSE again to remove any duplicate ops with the same sharding or - // compatible shardings. - HloPassPipeline pass("sharding-propation-cse"); - pass.AddPass( - /*is_layout_sensitive=*/false, - /*only_fusion_computations=*/false, - /*ignore_control_dependencies=*/false, - /*only_scalars=*/false, - /*is_sharding_sensitive=*/true, - /*allow_compatible_sharding=*/false); - TF_RETURN_IF_ERROR(pass.Run(module, execution_threads).status()); - - // CSE may invalidate stored HloInstruction pointers, so we need to remove - // stale shard group instructions. - call_graph = CallGraph::Build(module); - RemoveStaleShardGroupInstructions(module, execution_threads, - shard_group_id_to_shard_as_group); - RemoveStaleShardGroupInstructions(module, execution_threads, - shard_group_id_to_shard_like_group); - RemoveStaleSetInstructions(module, execution_threads, provided_shardings); - RemoveStaleMapInstructions(module, execution_threads, - instruction_to_shard_group_id); - RemoveStaleMapInstructions(module, execution_threads, unspecified_dims); - RemoveStaleComputationMap(module, execution_threads, computation_map); - if (cse_prevention_only_) { - RemoveStaleMapInstructions(module, execution_threads, *original_sharding); - } - // propagate sharding again to update the sharding of the CSE'd ops. - TF_RETURN_IF_ERROR(run_to_fix_point(/*aggressiveness=*/3, - /*propagate_shard_group=*/false)); - } - // Align the shardings from the same shard_as group so that they will adopt // the same sharding. for (const auto& [shard_as_group_id, shard_as_group] : diff --git a/xla/service/sharding_propagation_test.cc b/xla/service/sharding_propagation_test.cc index 76977b282c28b5..58a5d9f3ea5531 100644 --- a/xla/service/sharding_propagation_test.cc +++ b/xla/service/sharding_propagation_test.cc @@ -12003,10 +12003,9 @@ HloModule pjit_f ENTRY main.11 { Arg_0.1 = bf16[384,1408]{1,0} parameter(0), sharding={devices=[1,16,512]<=[8,16,64]T(1,0,2) last_tile_dim_replicate} - Arg_0.2 = bf16[384,1408]{1,0} parameter(1), sharding={devices=[1,16,512]<=[8,16,64]T(1,0,2) last_tile_dim_replicate} broadcast.4 = bf16[8,384,1408]{2,1,0} broadcast(Arg_0.1), dimensions={1,2} custom-call.5 = bf16[8,384,1408]{2,1,0} custom-call(broadcast.4), custom_call_target="Sharding", custom_call_has_side_effect=true, sharding={unknown shard_as 1} - broadcast.2 = bf16[8,384,1408]{2,1,0} broadcast(Arg_0.2), dimensions={1,2} + broadcast.2 = bf16[8,384,1408]{2,1,0} broadcast(Arg_0.1), dimensions={1,2} custom-call.3 = bf16[8,384,1408]{2,1,0} custom-call(broadcast.2), custom_call_target="Sharding", sharding={devices=[8,1,1,1024]<=[8192] last_tile_dim_replicate}, backend_config="unspecified_dims=[1,2]" custom-call.6 = bf16[8,384,1408]{2,1,0} custom-call(custom-call.3), custom_call_target="Sharding", custom_call_has_side_effect=true, sharding={unknown shard_as 1} %shard-barrier-to = bf16[8,384,1408]{2,1,0} custom-call(%custom-call.6), custom_call_target="ShardBarrierTo", custom_call_has_side_effect=true From 57365f440d9afb2e9a10c431c5074c998b1fe2c7 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Fri, 19 Jul 2024 12:26:00 -0700 Subject: [PATCH 028/376] Create a base class for GpuTimer, and a method to create objects of that class on Stream. PiperOrigin-RevId: 654089412 --- xla/stream_executor/BUILD | 11 +++ xla/stream_executor/cuda/BUILD | 1 + xla/stream_executor/cuda/cuda_executor.cc | 7 ++ xla/stream_executor/event_based_timer.h | 38 +++++++++++ xla/stream_executor/gpu/BUILD | 8 +++ xla/stream_executor/gpu/gpu_executor.h | 6 ++ xla/stream_executor/gpu/gpu_stream.cc | 7 ++ xla/stream_executor/gpu/gpu_stream.h | 4 ++ xla/stream_executor/gpu/gpu_timer.cc | 83 +++++++++++++++-------- xla/stream_executor/gpu/gpu_timer.h | 14 ++-- xla/stream_executor/rocm/rocm_executor.cc | 6 ++ xla/stream_executor/stream.h | 17 +++++ 12 files changed, 168 insertions(+), 34 deletions(-) create mode 100644 xla/stream_executor/event_based_timer.h diff --git a/xla/stream_executor/BUILD b/xla/stream_executor/BUILD index 5631a5e6ba593c..2c084fee8894d4 100644 --- a/xla/stream_executor/BUILD +++ b/xla/stream_executor/BUILD @@ -91,6 +91,7 @@ cc_library( ":device_description_proto_cc", ":dnn", ":event", + ":event_based_timer", ":fft", ":host_memory_allocation", # build_cleaner: keep ":host_or_device_scalar", @@ -476,6 +477,7 @@ cc_library( ":device_description", ":device_memory", ":event", + ":event_based_timer", ":kernel", ":launch_dim", ":platform", @@ -523,6 +525,15 @@ cc_library( deps = ["@com_google_absl//absl/strings:str_format"], ) +cc_library( + name = "event_based_timer", + hdrs = ["event_based_timer.h"], + deps = [ + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/time", + ], +) + cc_library( name = "command_buffer", hdrs = ["command_buffer.h"], diff --git a/xla/stream_executor/cuda/BUILD b/xla/stream_executor/cuda/BUILD index 917785c848ecc4..03b5db71e0deeb 100644 --- a/xla/stream_executor/cuda/BUILD +++ b/xla/stream_executor/cuda/BUILD @@ -759,6 +759,7 @@ cuda_only_cc_library( "//xla/stream_executor:command_buffer", "//xla/stream_executor:dnn", "//xla/stream_executor:event", + "//xla/stream_executor:event_based_timer", "//xla/stream_executor:fft", "//xla/stream_executor:kernel_spec", "//xla/stream_executor:module_spec", diff --git a/xla/stream_executor/cuda/cuda_executor.cc b/xla/stream_executor/cuda/cuda_executor.cc index bda62ea3a97300..1aa01f5469aa2f 100644 --- a/xla/stream_executor/cuda/cuda_executor.cc +++ b/xla/stream_executor/cuda/cuda_executor.cc @@ -32,6 +32,7 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/event.h" +#include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/fft.h" #include "xla/stream_executor/gpu/gpu_diagnostics.h" #include "xla/stream_executor/kernel_spec.h" @@ -265,6 +266,12 @@ absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, return absl::OkStatus(); } +absl::StatusOr> +GpuExecutor::CreateEventBasedTimer(GpuStream* stream, bool use_delay_kernel) { + // TODO(b/301020144) Move this all to the appropriate Executor class. + return GpuTimer::CreateEventBasedTimer(stream, use_delay_kernel); +} + bool GpuExecutor::UnloadGpuBinary(const void* gpu_binary) { auto module_it = gpu_binary_to_module_.find(gpu_binary); if (gpu_binary_to_module_.end() == module_it) { diff --git a/xla/stream_executor/event_based_timer.h b/xla/stream_executor/event_based_timer.h new file mode 100644 index 00000000000000..2283f34619cff5 --- /dev/null +++ b/xla/stream_executor/event_based_timer.h @@ -0,0 +1,38 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_EVENT_BASED_TIMER_H_ +#define XLA_STREAM_EXECUTOR_EVENT_BASED_TIMER_H_ + +#include "absl/status/statusor.h" +#include "absl/time/time.h" + +namespace stream_executor { + +// This class defines an interface for an Event-based timer. It allows the +// timing via Events from the creation of an EventBasedTimer to some arbitrary +// later point when the GetElapsedDuration method is called. +class EventBasedTimer { + public: + virtual ~EventBasedTimer() = default; + + // Stops the timer on the first call and returns the elapsed duration. + // Subsequent calls error out. + virtual absl::StatusOr GetElapsedDuration() = 0; +}; + +} // namespace stream_executor + +#endif // XLA_STREAM_EXECUTOR_EVENT_BASED_TIMER_H_ diff --git a/xla/stream_executor/gpu/BUILD b/xla/stream_executor/gpu/BUILD index 307eb68ab442eb..8175356bada4be 100644 --- a/xla/stream_executor/gpu/BUILD +++ b/xla/stream_executor/gpu/BUILD @@ -215,6 +215,7 @@ gpu_only_cc_library( "//xla/stream_executor:device_memory", "//xla/stream_executor:dnn", "//xla/stream_executor:event", + "//xla/stream_executor:event_based_timer", "//xla/stream_executor:fft", "//xla/stream_executor:host_memory_allocation", "//xla/stream_executor:kernel", @@ -311,6 +312,7 @@ gpu_only_cc_library( ":gpu_types_header", "//xla/stream_executor:device_memory", "//xla/stream_executor:event", + "//xla/stream_executor:event_based_timer", "//xla/stream_executor:platform", "//xla/stream_executor:stream", "//xla/stream_executor:stream_common", @@ -331,6 +333,7 @@ gpu_only_cc_library( ":gpu_types_header", "//xla/stream_executor:device_memory", "//xla/stream_executor:event", + "//xla/stream_executor:event_based_timer", "//xla/stream_executor:platform", "//xla/stream_executor:stream", "//xla/stream_executor:stream_common", @@ -365,6 +368,9 @@ gpu_kernel_library( "gpu_timer_kernel_cuda.cu.cc", ], tags = ["manual"], + visibility = internal_visibility([ + "//xla/stream_executor:__subpackages__", + ]), deps = [ ":gpu_driver_header", ":gpu_executor_header", @@ -408,6 +414,8 @@ gpu_only_cc_library( ":gpu_stream", ":gpu_types_header", "//xla/stream_executor", + "//xla/stream_executor:event", + "//xla/stream_executor:event_based_timer", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", diff --git a/xla/stream_executor/gpu/gpu_executor.h b/xla/stream_executor/gpu/gpu_executor.h index 65e53359898372..275f0c99e87ea0 100644 --- a/xla/stream_executor/gpu/gpu_executor.h +++ b/xla/stream_executor/gpu/gpu_executor.h @@ -1,3 +1,4 @@ +#include "xla/stream_executor/event_based_timer.h" /* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); @@ -67,6 +68,7 @@ namespace gpu { class GpuKernel; class GpuCommandBuffer; +class GpuStream; // CUDA-platform implementation of the platform-agnostic // StreamExecutor. @@ -299,6 +301,10 @@ class GpuExecutor : public StreamExecutorCommon { uint64_t GetArgumentLoggingMode() const { return argument_logging_mode_; } + // Creates an EventBasedTimer for the given stream. + absl::StatusOr> CreateEventBasedTimer( + GpuStream* stream, bool use_delay_kernel); + private: // Collects metadata for the specified kernel. absl::Status GetKernelMetadata(GpuKernel* cuda_kernel, diff --git a/xla/stream_executor/gpu/gpu_stream.cc b/xla/stream_executor/gpu/gpu_stream.cc index 557ea215220e1b..e1c943a1171d95 100644 --- a/xla/stream_executor/gpu/gpu_stream.cc +++ b/xla/stream_executor/gpu/gpu_stream.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_stream.h" #include +#include #include #include @@ -26,6 +27,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/event.h" +#include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_executor.h" @@ -193,6 +195,11 @@ void GpuStream::set_name(absl::string_view name) { reinterpret_cast(gpu_stream()), name_); } +absl::StatusOr> +GpuStream::CreateEventBasedTimer(bool use_delay_kernel) { + return parent_->CreateEventBasedTimer(this, use_delay_kernel); +} + GpuStream* AsGpuStream(Stream* stream) { DCHECK(stream != nullptr); return static_cast(stream); diff --git a/xla/stream_executor/gpu/gpu_stream.h b/xla/stream_executor/gpu/gpu_stream.h index c855a78f306172..984084e062b636 100644 --- a/xla/stream_executor/gpu/gpu_stream.h +++ b/xla/stream_executor/gpu/gpu_stream.h @@ -20,6 +20,7 @@ limitations under the License. #define XLA_STREAM_EXECUTOR_GPU_GPU_STREAM_H_ #include +#include #include #include "absl/functional/any_invocable.h" @@ -27,6 +28,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/event.h" +#include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/platform.h" @@ -115,6 +117,8 @@ class GpuStream : public StreamCommon { absl::AnyInvocable callback) override; void set_name(absl::string_view name) override; + absl::StatusOr> CreateEventBasedTimer( + bool use_delay_kernel) override; private: GpuExecutor* parent_; // Executor that spawned this stream. diff --git a/xla/stream_executor/gpu/gpu_timer.cc b/xla/stream_executor/gpu/gpu_timer.cc index d802bd8545a28b..1fa2f2c4f4bc04 100644 --- a/xla/stream_executor/gpu/gpu_timer.cc +++ b/xla/stream_executor/gpu/gpu_timer.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include -#include +#include #include #include #include @@ -31,6 +31,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/time/time.h" #include "absl/utility/utility.h" +#include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_semaphore.h" @@ -65,6 +66,37 @@ bool ShouldLaunchDelayKernel() { return value; } +absl::Status CreateGpuTimerParts(Stream* real_stream, bool use_delay_kernel, + GpuExecutor*& parent, + GpuEventHandle& start_event, + GpuEventHandle& stop_event, + GpuSemaphore& semaphore) { + GpuStream* stream = AsGpuStream(real_stream); + parent = stream->parent(); + GpuContext* context = parent->gpu_context(); + TF_RETURN_IF_ERROR(GpuDriver::InitEvent(context, &start_event, + GpuDriver::EventFlags::kDefault)); + TF_RETURN_IF_ERROR(GpuDriver::InitEvent(context, &stop_event, + GpuDriver::EventFlags::kDefault)); + CHECK(start_event != nullptr && stop_event != nullptr); + if (!use_delay_kernel) { + LOG(WARNING) + << "Skipping the delay kernel, measurement accuracy will be reduced"; + } + + if (use_delay_kernel && ShouldLaunchDelayKernel()) { + TF_ASSIGN_OR_RETURN(bool is_supported, DelayKernelIsSupported(stream)); + + if (is_supported) { + TF_ASSIGN_OR_RETURN(semaphore, LaunchDelayKernel(real_stream)); + } + } + + // The start event goes after the delay kernel in the stream + TF_RETURN_IF_ERROR(GpuDriver::RecordEvent(parent->gpu_context(), start_event, + stream->gpu_stream())); + return absl::OkStatus(); +} } // namespace /*deprecated*/ /*static*/ absl::StatusOr GpuTimer::Create( @@ -88,35 +120,30 @@ bool ShouldLaunchDelayKernel() { /*static*/ absl::StatusOr GpuTimer::Create(Stream* real_stream, bool use_delay_kernel) { - GpuStream* stream = AsGpuStream(real_stream); - GpuExecutor* parent = stream->parent(); - GpuContext* context = parent->gpu_context(); - GpuEventHandle start_event; - TF_RETURN_IF_ERROR(GpuDriver::InitEvent(context, &start_event, - GpuDriver::EventFlags::kDefault)); - GpuEventHandle stop_event; - TF_RETURN_IF_ERROR(GpuDriver::InitEvent(context, &stop_event, - GpuDriver::EventFlags::kDefault)); - CHECK(start_event != nullptr && stop_event != nullptr); + GpuExecutor* parent = nullptr; + GpuEventHandle start_event = nullptr; + GpuEventHandle stop_event = nullptr; GpuSemaphore semaphore{}; - if (!use_delay_kernel) { - LOG(WARNING) - << "Skipping the delay kernel, measurement accuracy will be reduced"; - } - - if (use_delay_kernel && ShouldLaunchDelayKernel()) { - TF_ASSIGN_OR_RETURN(bool is_supported, DelayKernelIsSupported(stream)); - - if (is_supported) { - TF_ASSIGN_OR_RETURN(semaphore, LaunchDelayKernel(real_stream)); - } - } + TF_RETURN_IF_ERROR(CreateGpuTimerParts(real_stream, use_delay_kernel, parent, + start_event, stop_event, semaphore)); + return absl::StatusOr{absl::in_place, + parent, + start_event, + stop_event, + AsGpuStream(real_stream), + std::move(semaphore)}; +} - // The start event goes after the delay kernel in the stream - TF_RETURN_IF_ERROR(GpuDriver::RecordEvent(parent->gpu_context(), start_event, - stream->gpu_stream())); - return absl::StatusOr{absl::in_place, parent, start_event, - stop_event, stream, std::move(semaphore)}; +absl::StatusOr> +GpuTimer::CreateEventBasedTimer(Stream* stream, bool use_delay_kernel) { + GpuExecutor* parent = nullptr; + GpuEventHandle start_event = nullptr; + GpuEventHandle stop_event = nullptr; + GpuSemaphore semaphore{}; + TF_RETURN_IF_ERROR(CreateGpuTimerParts(stream, use_delay_kernel, parent, + start_event, stop_event, semaphore)); + return std::make_unique(parent, start_event, stop_event, + AsGpuStream(stream), std::move(semaphore)); } /*static*/ void GpuTimer::ReturnRandomDurationsForTesting() { diff --git a/xla/stream_executor/gpu/gpu_timer.h b/xla/stream_executor/gpu/gpu_timer.h index 955e5a203e5d58..ea8eaa852960f8 100644 --- a/xla/stream_executor/gpu/gpu_timer.h +++ b/xla/stream_executor/gpu/gpu_timer.h @@ -16,11 +16,13 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_GPU_GPU_TIMER_H_ #define XLA_STREAM_EXECUTOR_GPU_GPU_TIMER_H_ -#include +#include #include #include "absl/status/statusor.h" #include "absl/time/time.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_semaphore.h" #include "xla/stream_executor/gpu/gpu_types.h" @@ -44,9 +46,11 @@ class GpuStream; // an end event is queued and the delay kernel exits. This allows the device // execution time of the tasks queued to the stream while the timer is active // to be measured more accurately. -class GpuTimer { +class GpuTimer : public EventBasedTimer { public: static absl::StatusOr Create(Stream* stream, bool use_delay_kernel); + static absl::StatusOr> CreateEventBasedTimer( + Stream* stream, bool use_delay_kernel); [[deprecated("Pass Stream* not GpuStream*")]] static absl::StatusOr Create(GpuStream* stream); @@ -77,11 +81,9 @@ class GpuTimer { return *this; } - ~GpuTimer(); + ~GpuTimer() override; - // Stops the timer on the first call and returns the elapsed duration. - // Subsequent calls error out. - absl::StatusOr GetElapsedDuration(); + absl::StatusOr GetElapsedDuration() override; private: GpuExecutor* parent_; diff --git a/xla/stream_executor/rocm/rocm_executor.cc b/xla/stream_executor/rocm/rocm_executor.cc index d02c595a246dd2..8d5aa981419e59 100644 --- a/xla/stream_executor/rocm/rocm_executor.cc +++ b/xla/stream_executor/rocm/rocm_executor.cc @@ -184,6 +184,12 @@ GpuExecutor::CreateOrShareConstant(Stream* stream, return shared_constant; } +absl::StatusOr> +GpuExecutor::CreateEventBasedTimer(GpuStream* stream, bool use_delay_kernel) { + // TODO(b/301020144) Move this all to the appropriate Executor class. + return GpuTimer::CreateEventBasedTimer(stream, use_delay_kernel); +} + bool GpuExecutor::UnloadGpuBinary(const void* gpu_binary) { auto module_it = gpu_binary_to_module_.find(gpu_binary); if (gpu_binary_to_module_.end() == module_it) { diff --git a/xla/stream_executor/stream.h b/xla/stream_executor/stream.h index 66625500c02eb9..71cdf9a35b8da7 100644 --- a/xla/stream_executor/stream.h +++ b/xla/stream_executor/stream.h @@ -22,6 +22,7 @@ limitations under the License. #define XLA_STREAM_EXECUTOR_STREAM_H_ #include +#include #include #include @@ -34,6 +35,7 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/event.h" +#include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" @@ -283,6 +285,21 @@ class Stream { // Get/set a name for a stream, which can be shown in profiling tools virtual absl::string_view name() const = 0; virtual void set_name(absl::string_view name) = 0; + + // Create an EventBasedTimer that can be used to time operations on this + // stream using Events. + // + // If use_delay_kernel is true, the timer will launch a delay kernel into the + // stream and queue a start event immediately afterwards. This delay kernel + // blocks execution on the stream until EventBasedTimer::GetElapsedDuration() + // is called, at which point an end event is queued and the delay kernel + // exits. This allows the device execution time of the tasks queued to the + // stream while the timer is active to be measured more accurately. + virtual absl::StatusOr> + CreateEventBasedTimer(bool use_delay_kernel) { + return absl::UnimplementedError( + "This stream does not support EventBasedTimers."); + } }; template From eb320cc876ec5b1ffc1d6da781b7cbe4aee92be0 Mon Sep 17 00:00:00 2001 From: Reed Wanderman-Milne Date: Fri, 19 Jul 2024 12:40:22 -0700 Subject: [PATCH 029/376] Fix comment of GetParticipatingDevicesGroups. The example was previously wrong. The example was already tested here https://github.com/openxla/xla/blob/6dfc6ace222f7e10a2c04c1fa0bcd856d066cce3/xla/service/collective_ops_utils_test.cc#L382 and now the example matches the test. Another example was also added. PiperOrigin-RevId: 654093983 --- xla/service/collective_ops_utils.h | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/xla/service/collective_ops_utils.h b/xla/service/collective_ops_utils.h index f06be498e3c233..c611d57a6e6264 100644 --- a/xla/service/collective_ops_utils.h +++ b/xla/service/collective_ops_utils.h @@ -131,11 +131,20 @@ absl::StatusOr GetCollectiveOpGroupMode( // // For example: // device_assignment={{33, 34}, {44, 45}, {55, 56}} 3 replicas 2 partitions -// group_mode=CollectiveOpGroupMode::kCrossReplica // replica_groups={{0}, {1, 2}} +// group_mode=CollectiveOpGroupMode::kCrossReplica +// +// This functions returns {{33}, {34}, {44, 45}, {55, 56}}. +// Partition 0 has 2 subgroups of participating devices {33}, {44, 55} and +// partition 1 has 2 subgroups of participating devices {34}, {45, 56}. +// +// Another example: +// device_assignment={{33, 34}, {44, 45}, {55, 56}} 3 replicas 2 partitions +// replica_groups={{0}, {1, 2}, {3, 4, 5}} +// group_mode=CollectiveOpGroupMode::kFlattenedID // -// This functions returns {{33, 34}, {44, 45, 55, 56}} -// There are 2 subgroups of participating devices {33, 34}, {44, 45, 55, 56}. +// This functions returns {{33}, {34, 44}, {45, 55, 56}}. The replica_ids map +// into a flattened version of device_assignment. absl::StatusOr>> GetParticipatingDevicesGroups(const DeviceAssignment& device_assignment, absl::Span replica_groups, From d5659f40d5fa61bd1e70fdc78c53b06946ce097d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 19 Jul 2024 13:30:29 -0700 Subject: [PATCH 030/376] Reverts 15f346bd3b6d22d5f990dc7c0f4a02579a031487 PiperOrigin-RevId: 654108401 --- xla/service/sharding_propagation.cc | 21 +--- xla/service/sharding_propagation_test.cc | 130 ----------------------- 2 files changed, 1 insertion(+), 150 deletions(-) diff --git a/xla/service/sharding_propagation.cc b/xla/service/sharding_propagation.cc index 91b291e7d6156d..0a8e3cf14a42f4 100644 --- a/xla/service/sharding_propagation.cc +++ b/xla/service/sharding_propagation.cc @@ -1465,30 +1465,11 @@ absl::StatusOr ProcessShardingInstruction( for (auto it = instructions.rbegin(); it != instructions.rend(); ++it) { HloInstruction* instruction = *it; if (instruction->IsCustomCall("Sharding")) { + HloSharding original_sharding = instruction->sharding(); TF_RET_CHECK(instruction->has_sharding()) << "Sharding instruction must have a sharding attribute"; - HloSharding original_sharding = instruction->sharding(); VLOG(3) << "ProcessShardingInstruction: " << instruction->ToString(); - // Simplify consecutive Sharding custom-call instructions. If both - // shardings are tiled, we do not simplify the instruction since these - // two shardings can guide the partitioner. An example is - // https://github.com/google/jax/issues/21562. - HloInstruction* operand = instruction->mutable_operand(0); - if (!original_sharding.IsUnknown() && - operand->IsCustomCall("Sharding") && operand->user_count() == 1 && - !(original_sharding.IsTiled() && operand->sharding().IsTiled())) { - operand->set_sharding(original_sharding); - TF_ASSIGN_OR_RETURN( - std::ignore, - computation->ReplaceInstruction( - instruction, operand, /*preserve_sharding=*/false, - /*relay_control_dependency=*/false, - /*remove_unused_operands=*/false)); - changed = true; - continue; - } - std::vector unspec_dims; TF_RETURN_IF_ERROR(sharding_op_util::ParseAttributes( Cast(instruction)->opaque(), diff --git a/xla/service/sharding_propagation_test.cc b/xla/service/sharding_propagation_test.cc index 58a5d9f3ea5531..ac04389c805878 100644 --- a/xla/service/sharding_propagation_test.cc +++ b/xla/service/sharding_propagation_test.cc @@ -15,7 +15,6 @@ limitations under the License. #include "xla/service/sharding_propagation.h" -#include #include #include #include @@ -12071,134 +12070,5 @@ ENTRY %elementwise { "last_tile_dim_replicate}}")); } -TEST_F(ShardingPropagationTest, RedundantShardingInstruction1) { - const char* const hlo_string = R"( -HloModule module - -ENTRY %main.6 { - %p0 = f32[32,96] parameter(0), sharding={replicated} - %add.0 = f32[32,96] add(%p0, %p0) - %custom-call.0 = f32[32,96] custom-call(%add.0), custom_call_target="Sharding", sharding={replicated} - %custom-call.1 = f32[32,96] custom-call(%custom-call.0), custom_call_target="Sharding", sharding={devices=[2,2]<=[4]} - ROOT %add.1 = f32[32,96] add(%custom-call.1, %custom-call.1) -})"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - TF_ASSERT_OK_AND_ASSIGN( - bool changed, - ShardingPropagation(/*is_spmd=*/true, /*propagate_metadata=*/true, - /*allow_spmd_sharding_propagation_to_output=*/{true}) - .Run(module.get())); - EXPECT_TRUE(changed); - XLA_VLOG_LINES(1, module->ToString()); - - int64_t num_copy = 0; - for (const HloInstruction* instruction : - module->entry_computation()->instructions()) { - if (instruction->opcode() == HloOpcode::kCopy) { - EXPECT_THAT(instruction, op::Sharding("{devices=[2,2]<=[4]}")); - num_copy++; - } - } - EXPECT_EQ(num_copy, 1); -} - -TEST_F(ShardingPropagationTest, RedundantShardingInstruction2) { - const char* const hlo_string = R"( -HloModule module - -ENTRY %main.6 { - %p0 = f32[32,96] parameter(0), sharding={replicated} - %add.0 = f32[32,96] add(%p0, %p0) - %custom-call.0 = f32[32,96] custom-call(%add.0), custom_call_target="Sharding", sharding={maximal device=0} - %custom-call.1 = f32[32,96] custom-call(%custom-call.0), custom_call_target="Sharding", sharding={maximal device=1} - %custom-call.2 = f32[32,96] custom-call(%custom-call.1), custom_call_target="Sharding", sharding={maximal device=2} - ROOT %add.1 = f32[32,96] add(%custom-call.2, %custom-call.2) -})"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - TF_ASSERT_OK_AND_ASSIGN( - bool changed, - ShardingPropagation(/*is_spmd=*/true, /*propagate_metadata=*/true, - /*allow_spmd_sharding_propagation_to_output=*/{true}) - .Run(module.get())); - EXPECT_TRUE(changed); - XLA_VLOG_LINES(1, module->ToString()); - - int64_t num_copy = 0; - for (const HloInstruction* instruction : - module->entry_computation()->instructions()) { - if (instruction->opcode() == HloOpcode::kCopy) { - EXPECT_THAT(instruction, op::Sharding("{maximal device=2}")); - num_copy++; - } - } - EXPECT_EQ(num_copy, 1); -} - -TEST_F(ShardingPropagationTest, RedundantShardingInstruction3) { - // This target is similar to RedundantShardingInstruction1, except that - // %custom-call.0 has two users. - const char* const hlo_string = R"( -HloModule module - -ENTRY %main.6 { - %p0 = f32[32,96] parameter(0), sharding={replicated} - %add.0 = f32[32,96] add(%p0, %p0) - %custom-call.0 = f32[32,96] custom-call(%add.0), custom_call_target="Sharding", sharding={replicated} - %custom-call.1 = f32[32,96] custom-call(%custom-call.0), custom_call_target="Sharding", sharding={devices=[2,2]<=[4]} - ROOT %add.1 = f32[32,96] add(%custom-call.0, %custom-call.1) -})"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - TF_ASSERT_OK_AND_ASSIGN( - bool changed, - ShardingPropagation(/*is_spmd=*/true, /*propagate_metadata=*/true, - /*allow_spmd_sharding_propagation_to_output=*/{true}) - .Run(module.get())); - EXPECT_TRUE(changed); - XLA_VLOG_LINES(1, module->ToString()); - - int64_t num_copy = 0; - for (const HloInstruction* instruction : - module->entry_computation()->instructions()) { - if (instruction->opcode() == HloOpcode::kCopy) { - num_copy++; - } - } - EXPECT_EQ(num_copy, 2); -} - -TEST_F(ShardingPropagationTest, RedundantShardingInstruction4) { - const char* const hlo_string = R"( -HloModule module - -ENTRY %main.6 { - %p0 = f32[32,96] parameter(0), sharding={replicated} - %add.0 = f32[32,96] add(%p0, %p0) - %custom-call.0 = f32[32,96] custom-call(%add.0), custom_call_target="Sharding", sharding={devices=[2,2]<=[2,2]T(1,0)} - %custom-call.1 = f32[32,96] custom-call(%custom-call.0), custom_call_target="Sharding", sharding={devices=[2,2]<=[4]} - ROOT %add.1 = f32[32,96] add(%custom-call.1, %custom-call.1) -})"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - TF_ASSERT_OK_AND_ASSIGN( - bool changed, - ShardingPropagation(/*is_spmd=*/true, /*propagate_metadata=*/true, - /*allow_spmd_sharding_propagation_to_output=*/{true}) - .Run(module.get())); - EXPECT_TRUE(changed); - XLA_VLOG_LINES(1, module->ToString()); - - int64_t num_copy = 0; - for (const HloInstruction* instruction : - module->entry_computation()->instructions()) { - if (instruction->opcode() == HloOpcode::kCopy) { - num_copy++; - } - } - EXPECT_EQ(num_copy, 2); -} - } // namespace } // namespace xla From 2b46421677c9433e22a9437f2ec41d1cdc32dd17 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Fri, 19 Jul 2024 15:06:55 -0700 Subject: [PATCH 031/376] In `build.py`, stop trying to pull container after first successful pull PiperOrigin-RevId: 654137443 --- build_tools/build.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/build_tools/build.py b/build_tools/build.py index 2d10390af1ca6b..848fbb407bef75 100755 --- a/build_tools/build.py +++ b/build_tools/build.py @@ -100,7 +100,9 @@ def _pull_docker_image_with_retries(self, retries=3) -> None: """Pulls docker image with retries to avoid transient rate limit errors.""" for _ in range(retries): pull_proc = sh(["docker", "pull", self.image_url], check=False) - if pull_proc.returncode != 0: + if pull_proc.returncode == 0: + break # Don't keep pulling after successful pull. + else: time.sleep(15) # write SHA of image to the sponge config From 89f05104d8d2d2b8a393eb43662458ea38f49a80 Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Fri, 19 Jul 2024 15:24:23 -0700 Subject: [PATCH 032/376] Integrate StableHLO at openxla/stablehlo@531816f0 PiperOrigin-RevId: 654142429 --- third_party/stablehlo/temporary.patch | 12 ++ third_party/stablehlo/workspace.bzl | 4 +- .../mhlo/transforms/map_stablehlo_to_hlo_op.h | 1 + .../Dialect/chlo/chlo_legalize_to_mhlo.mlir | 156 +++++++++--------- 4 files changed, 93 insertions(+), 80 deletions(-) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 8b137891791fe9..e304a133e9af3a 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -1 +1,13 @@ +diff --ruN a/stablehlo/stablehlo/reference/Tensor.cpp b/stablehlo/stablehlo/reference/Tensor.cpp +--- stablehlo/stablehlo/reference/Tensor.cpp ++++ stablehlo/stablehlo/reference/Tensor.cpp +@@ -423,7 +423,7 @@ + getType().print(os); + os << " {"; + Index idx{}; +- printHelper(os, *this, getShape(), idx, /*index=*/1); ++ printHelper(os, *this, getShape(), idx, /*indent=*/1); + os << "}"; + } + diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index c64a1b3410e1d6..48b2b101ae8a1a 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "7e749c84f933e092872ac0bcc2ecab40abf24f5d" - STABLEHLO_SHA256 = "5d49659462b7ab85f9683a64e5fba9c663c60af49d03527c51fcdb0b9c18bde4" + STABLEHLO_COMMIT = "531816f07e0db010a676c23fc66fe0a1a2e2d648" + STABLEHLO_SHA256 = "5a0b6a4dbe739793f1c4ea7d117aac81edaa18e2f2fe795fc3ffe6a2e9be2ac8" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/xla/mlir_hlo/mhlo/transforms/map_stablehlo_to_hlo_op.h b/xla/mlir_hlo/mhlo/transforms/map_stablehlo_to_hlo_op.h index 39d6de77380ae8..390dfb805c12c1 100644 --- a/xla/mlir_hlo/mhlo/transforms/map_stablehlo_to_hlo_op.h +++ b/xla/mlir_hlo/mhlo/transforms/map_stablehlo_to_hlo_op.h @@ -149,6 +149,7 @@ MAP_STABLEHLO_TO_HLO(SortOp) MAP_STABLEHLO_TO_HLO(SqrtOp) MAP_STABLEHLO_TO_HLO(SubtractOp) MAP_STABLEHLO_TO_HLO(TanhOp) +MAP_STABLEHLO_TO_HLO(TanOp) MAP_STABLEHLO_TO_HLO(TorchIndexSelectOp) MAP_STABLEHLO_TO_HLO(TransposeOp) MAP_STABLEHLO_TO_HLO(TriangularSolveOp) diff --git a/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir b/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir index 533a9ac640aa8a..4324c8e7731b2b 100644 --- a/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir +++ b/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir @@ -698,8 +698,8 @@ func.func @erf_bf16(%arg : tensor) -> tensor { // CHECK-LABEL: @acosh // CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { func.func @acosh(%arg: tensor) -> tensor { - // CHECK: %[[VAL_1:.*]] = mhlo.constant dense<6.550400e+04> : tensor - // CHECK: %[[VAL_2:.*]] = mhlo.constant dense<2.000000e+00> : tensor + // CHECK-DAG: %[[VAL_1:.*]] = mhlo.constant dense<6.550400e+04> : tensor + // CHECK-DAG: %[[VAL_2:.*]] = mhlo.constant dense<2.000000e+00> : tensor // CHECK: %[[VAL_3:.*]] = mhlo.divide %[[VAL_1]], %[[VAL_2]] : tensor // CHECK: %[[VAL_4:.*]] = mhlo.compare GE, %[[VAL_0]], %[[VAL_3]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_5:.*]] = mhlo.log %[[VAL_2]] : tensor @@ -753,8 +753,8 @@ func.func @acosh_complex_f32(%arg : tensor>) -> tensor // CHECK: %[[VAL_26:.*]] = mhlo.log %[[VAL_25]] : tensor // CHECK: %[[VAL_27:.*]] = mhlo.log %[[VAL_22]] : tensor // CHECK: %[[VAL_28:.*]] = mhlo.add %[[VAL_26]], %[[VAL_27]] : tensor - // CHECK: %[[VAL_29:.*]] = mhlo.constant dense<5.000000e-01> : tensor - // CHECK: %[[VAL_30:.*]] = mhlo.constant dense<0x7F800000> : tensor + // CHECK-DAG: %[[VAL_29:.*]] = mhlo.constant dense<5.000000e-01> : tensor + // CHECK-DAG: %[[VAL_30:.*]] = mhlo.constant dense<0x7F800000> : tensor // CHECK: %[[VAL_31:.*]] = mhlo.compare EQ, %[[VAL_7]], %[[VAL_30]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_32:.*]] = mhlo.not %[[VAL_31]] : tensor // CHECK: %[[VAL_33:.*]] = mhlo.and %[[VAL_21]], %[[VAL_32]] : tensor @@ -969,13 +969,13 @@ func.func @erfc_f64(%arg : tensor) -> tensor { // CHECK-NEXT: %[[TMP_103:.*]] = mhlo.add %[[TMP_101]], %[[TMP_102]] // CHECK-NEXT: %[[TMP_104:.*]] = mhlo.divide %[[TMP_81]], %[[TMP_103]] // CHECK-NEXT: %[[TMP_105:.*]] = mhlo.constant dense<8.000000e+00> - // CHECK-NEXT: %[[TMP_106:.*]] = mhlo.compare LT, %[[TMP_3]], %[[TMP_105]], NOTYPE + // CHECK-NEXT: %[[TMP_106:.*]] = mhlo.compare LT, %[[TMP_3]], %[[TMP_105]] // CHECK-NEXT: %[[TMP_107:.*]] = mhlo.select %[[TMP_106]], %[[TMP_61]], %[[TMP_104]] // CHECK-NEXT: %[[TMP_108:.*]] = mhlo.constant dense<-709.78271289338397> - // CHECK-NEXT: %[[TMP_109:.*]] = mhlo.compare LT, %[[TMP_1]], %[[TMP_108]], NOTYPE + // CHECK-NEXT: %[[TMP_109:.*]] = mhlo.compare LT, %[[TMP_1]], %[[TMP_108]] // CHECK-NEXT: %[[TMP_110:.*]] = mhlo.constant dense<0.000000e+00> // CHECK-NEXT: %[[TMP_111:.*]] = mhlo.select %[[TMP_109]], %[[TMP_110]], %[[TMP_107]] - // CHECK-NEXT: %[[TMP_113:.*]] = mhlo.compare LT, %[[ARG]], %[[TMP_110]], NOTYPE + // CHECK-NEXT: %[[TMP_113:.*]] = mhlo.compare LT, %[[ARG]], %[[TMP_110]] // CHECK-NEXT: %[[TMP_114:.*]] = mhlo.constant dense<2.000000e+00> // CHECK-NEXT: %[[TMP_115:.*]] = mhlo.subtract %[[TMP_114]], %[[TMP_111]] // CHECK-NEXT: %[[TMP_116:.*]] = mhlo.select %[[TMP_113]], %[[TMP_115]], %[[TMP_111]] @@ -1014,7 +1014,7 @@ func.func @erfc_f64(%arg : tensor) -> tensor { // CHECK-NEXT: %[[TMP_155:.*]] = mhlo.divide %[[TMP_135]], %[[TMP_154]] // CHECK-NEXT: %[[TMP_156:.*]] = mhlo.subtract %[[TMP_117]], %[[TMP_155]] // CHECK-NEXT: %[[TMP_157:.*]] = mhlo.abs %[[ARG]] - // CHECK-NEXT: %[[TMP_159:.*]] = mhlo.compare LT, %[[TMP_157]], %[[TMP_117]], NOTYPE + // CHECK-NEXT: %[[TMP_159:.*]] = mhlo.compare LT, %[[TMP_157]], %[[TMP_117]] // CHECK-NEXT: %[[RESULT:.*]] = mhlo.select %[[TMP_159]], %[[TMP_156]], %[[TMP_116]] // CHECK-NEXT: return %[[RESULT]] %1 = "chlo.erfc"(%arg) : (tensor) -> tensor @@ -1035,7 +1035,7 @@ func.func @erfc_f32(%arg : tensor) -> tensor { // CHECK: %[[TMP_7:.*]] = mhlo.divide %[[TMP_3]], %[[TMP_2]] // CHECK: %[[TMP_8:.*]] = mhlo.multiply %[[TMP_5]], %[[TMP_7]] // CHECK: %[[TMP_9:.*]] = mhlo.constant dense<2.000000e+00> - // CHECK: %[[TMP_10:.*]] = mhlo.compare LT, %[[TMP_2]], %[[TMP_9]], NOTYPE + // CHECK: %[[TMP_10:.*]] = mhlo.compare LT, %[[TMP_2]], %[[TMP_9]] // CHECK: %[[TMP_13:.*]] = mhlo.constant dense<2.326820e-02> // CHECK: %[[TMP_15:.*]] = mhlo.multiply %[[TMP_13]], %[[TMP_4]] // CHECK: %[[TMP_16:.*]] = mhlo.constant dense<-0.138703942> @@ -1086,10 +1086,10 @@ func.func @erfc_f32(%arg : tensor) -> tensor { // CHECK: %[[TMP_64:.*]] = mhlo.select %[[TMP_10]], %[[TMP_38]], %[[TMP_63]] // CHECK: %[[TMP_65:.*]] = mhlo.multiply %[[TMP_8]], %[[TMP_64]] // CHECK: %[[TMP_66:.*]] = mhlo.constant dense<-88.7228394> - // CHECK: %[[TMP_67:.*]] = mhlo.compare LT, %[[TMP_1]], %[[TMP_66]], NOTYPE + // CHECK: %[[TMP_67:.*]] = mhlo.compare LT, %[[TMP_1]], %[[TMP_66]] // CHECK: %[[TMP_68:.*]] = mhlo.constant dense<0.000000e+00> // CHECK: %[[TMP_69:.*]] = mhlo.select %[[TMP_67]], %[[TMP_68]], %[[TMP_65]] - // CHECK: %[[TMP_71:.*]] = mhlo.compare LT, %[[ARG]], %[[TMP_68]], NOTYPE + // CHECK: %[[TMP_71:.*]] = mhlo.compare LT, %[[ARG]], %[[TMP_68]] // CHECK: %[[TMP_73:.*]] = mhlo.subtract %[[TMP_9]], %[[TMP_69]] // CHECK: %[[TMP_74:.*]] = mhlo.select %[[TMP_71]], %[[TMP_73]], %[[TMP_69]] // CHECK: %[[TMP_75:.*]] = mhlo.constant dense<1.000000e+00> @@ -1116,7 +1116,7 @@ func.func @erfc_f32(%arg : tensor) -> tensor { // CHECK: %[[TMP_99:.*]] = mhlo.multiply %[[ARG]], %[[TMP_98]] // CHECK: %[[TMP_100:.*]] = mhlo.subtract %[[TMP_75]], %[[TMP_99]] // CHECK: %[[TMP_101:.*]] = mhlo.abs %[[ARG]] - // CHECK: %[[TMP_103:.*]] = mhlo.compare LT, %[[TMP_101]], %[[TMP_75]], NOTYPE + // CHECK: %[[TMP_103:.*]] = mhlo.compare LT, %[[TMP_101]], %[[TMP_75]] // CHECK: %[[RESULT:.*]] = mhlo.select %[[TMP_103]], %[[TMP_100]], %[[TMP_74]] // CHECK: return %[[RESULT]] %1 = "chlo.erfc"(%arg) : (tensor) -> tensor @@ -1190,7 +1190,7 @@ func.func @is_neg_inf_f32(%arg : tensor) -> tensor { // CHECK-SAME: (%[[ARG:.*]]: tensor) func.func @lgamma_f64(%arg : tensor) -> tensor { // CHECK: %[[TMP_1:.*]] = mhlo.constant dense<5.000000e-01> - // CHECK: %[[TMP_9:.*]] = mhlo.compare LT, %[[ARG]], %[[TMP_1]], NOTYPE + // CHECK: %[[TMP_9:.*]] = mhlo.compare LT, %[[ARG]], %[[TMP_1]] // CHECK: %[[TMP_10:.*]] = mhlo.negate %[[ARG]] // CHECK: %[[TMP_2:.*]] = mhlo.constant dense<1.000000e+00> // CHECK: %[[TMP_11:.*]] = mhlo.subtract %[[ARG]], %[[TMP_2]] @@ -1253,7 +1253,7 @@ func.func @lgamma_f64(%arg : tensor) -> tensor { // CHECK: %[[TMP_64:.*]] = mhlo.abs %[[ARG]] // CHECK: %[[TMP_65:.*]] = mhlo.floor %[[TMP_64]] // CHECK: %[[TMP_66:.*]] = mhlo.subtract %[[TMP_64]], %[[TMP_65]] - // CHECK: %[[TMP_67:.*]] = mhlo.compare LT, %[[TMP_1]], %[[TMP_66]], NOTYPE + // CHECK: %[[TMP_67:.*]] = mhlo.compare LT, %[[TMP_1]], %[[TMP_66]] // CHECK: %[[TMP_68:.*]] = mhlo.subtract %[[TMP_2]], %[[TMP_66]] // CHECK: %[[TMP_69:.*]] = mhlo.select %[[TMP_67]], %[[TMP_68]], %[[TMP_66]] // CHECK: %[[TMP_3:.*]] = mhlo.constant dense<3.1415926535897931> @@ -1283,7 +1283,7 @@ func.func @lgamma_f64(%arg : tensor) -> tensor { // CHECK-SAME: (%[[ARG:.*]]: tensor) func.func @lgamma_f32(%arg : tensor) -> tensor { // CHECK: %[[TMP_1:.*]] = mhlo.constant dense<5.000000e-01> - // CHECK: %[[TMP_9:.*]] = mhlo.compare LT, %[[ARG]], %[[TMP_1]], NOTYPE + // CHECK: %[[TMP_9:.*]] = mhlo.compare LT, %[[ARG]], %[[TMP_1]] // CHECK: %[[TMP_10:.*]] = mhlo.negate %[[ARG]] // CHECK: %[[TMP_2:.*]] = mhlo.constant dense<1.000000e+00> // CHECK: %[[TMP_11:.*]] = mhlo.subtract %[[ARG]], %[[TMP_2]] @@ -1346,7 +1346,7 @@ func.func @lgamma_f32(%arg : tensor) -> tensor { // CHECK: %[[TMP_64:.*]] = mhlo.abs %[[ARG]] // CHECK: %[[TMP_65:.*]] = mhlo.floor %[[TMP_64]] // CHECK: %[[TMP_66:.*]] = mhlo.subtract %[[TMP_64]], %[[TMP_65]] - // CHECK: %[[TMP_67:.*]] = mhlo.compare LT, %[[TMP_1]], %[[TMP_66]], NOTYPE + // CHECK: %[[TMP_67:.*]] = mhlo.compare LT, %[[TMP_1]], %[[TMP_66]] // CHECK: %[[TMP_68:.*]] = mhlo.subtract %[[TMP_2]], %[[TMP_66]] // CHECK: %[[TMP_69:.*]] = mhlo.select %[[TMP_67]], %[[TMP_68]], %[[TMP_66]] // CHECK: %[[TMP_3:.*]] = mhlo.constant dense<3.14159274> @@ -1388,7 +1388,7 @@ func.func @lgamma_f16(%arg : tensor) -> tensor { // CHECK-SAME: (%[[ARG:.*]]: tensor) func.func @digamma_f64(%arg : tensor) -> tensor { // CHECK: %[[TMP_0:.*]] = mhlo.constant dense<5.000000e-01> - // CHECK: %[[TMP_1:.*]] = mhlo.compare LT, %arg0, %[[TMP_0]], NOTYPE + // CHECK: %[[TMP_1:.*]] = mhlo.compare LT, %arg0, %[[TMP_0]] // CHECK: %[[TMP_2:.*]] = mhlo.negate %arg0 // CHECK: %[[TMP_3:.*]] = mhlo.constant dense<1.000000e+00> // CHECK: %[[TMP_4:.*]] = mhlo.subtract %arg0, %[[TMP_3]] @@ -1483,9 +1483,9 @@ func.func @digamma_f64(%arg : tensor) -> tensor { // CHECK: %[[TMP_93:.*]] = mhlo.divide %[[TMP_91]], %[[TMP_92]] // CHECK: %[[TMP_94:.*]] = mhlo.subtract %[[TMP_82]], %[[TMP_93]] // CHECK: %[[TMP_95:.*]] = mhlo.select %[[TMP_1]], %[[TMP_94]], %[[TMP_82]] - // CHECK: %[[TMP_96:.*]] = mhlo.compare LE, %arg0, %[[TMP_6]], NOTYPE + // CHECK: %[[TMP_96:.*]] = mhlo.compare LE, %arg0, %[[TMP_6]] // CHECK: %[[TMP_97:.*]] = mhlo.floor %arg0 - // CHECK: %[[TMP_98:.*]] = mhlo.compare EQ, %arg0, %[[TMP_97]], NOTYPE + // CHECK: %[[TMP_98:.*]] = mhlo.compare EQ, %arg0, %[[TMP_97]] // CHECK: %[[TMP_99:.*]] = mhlo.and %[[TMP_96]], %[[TMP_98]] // CHECK: %[[TMP_100:.*]] = mhlo.constant dense<0x7FF8000000000000> // CHECK: %[[RES:.*]] = mhlo.select %[[TMP_99]], %[[TMP_100]], %[[TMP_95]] @@ -1500,7 +1500,7 @@ func.func @digamma_f64(%arg : tensor) -> tensor { // CHECK-SAME: (%[[ARG:.*]]: tensor) func.func @digamma_f32(%arg : tensor) -> tensor { // CHECK: %[[TMP_0:.*]] = mhlo.constant dense<5.000000e-01> - // CHECK: %[[TMP_1:.*]] = mhlo.compare LT, %arg0, %[[TMP_0]], NOTYPE + // CHECK: %[[TMP_1:.*]] = mhlo.compare LT, %arg0, %[[TMP_0]] // CHECK: %[[TMP_2:.*]] = mhlo.negate %arg0 // CHECK: %[[TMP_3:.*]] = mhlo.constant dense<1.000000e+00> // CHECK: %[[TMP_4:.*]] = mhlo.subtract %arg0, %[[TMP_3]] @@ -1595,9 +1595,9 @@ func.func @digamma_f32(%arg : tensor) -> tensor { // CHECK: %[[TMP_93:.*]] = mhlo.divide %[[TMP_91]], %[[TMP_92]] // CHECK: %[[TMP_94:.*]] = mhlo.subtract %[[TMP_82]], %[[TMP_93]] // CHECK: %[[TMP_95:.*]] = mhlo.select %[[TMP_1]], %[[TMP_94]], %[[TMP_82]] - // CHECK: %[[TMP_96:.*]] = mhlo.compare LE, %arg0, %[[TMP_6]], NOTYPE + // CHECK: %[[TMP_96:.*]] = mhlo.compare LE, %arg0, %[[TMP_6]] // CHECK: %[[TMP_97:.*]] = mhlo.floor %arg0 - // CHECK: %[[TMP_98:.*]] = mhlo.compare EQ, %arg0, %[[TMP_97]], NOTYPE + // CHECK: %[[TMP_98:.*]] = mhlo.compare EQ, %arg0, %[[TMP_97]] // CHECK: %[[TMP_99:.*]] = mhlo.and %[[TMP_96]], %[[TMP_98]] // CHECK: %[[TMP_100:.*]] = mhlo.constant dense<0x7FC00000> // CHECK: %[[RES:.*]] = mhlo.select %[[TMP_99]], %[[TMP_100]], %[[TMP_95]] @@ -1776,29 +1776,29 @@ func.func @zeta_f16(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: %[[TMP_150:.*]] = mhlo.abs %[[TMP_32]] // CHECK: %[[TMP_151:.*]] = mhlo.constant dense<1.401300e-45> // CHECK: %[[TMP_152:.*]] = mhlo.multiply %[[TMP_150]], %[[TMP_151]] - // CHECK: %[[TMP_153:.*]] = mhlo.compare LT, %[[TMP_149]], %[[TMP_152]], NOTYPE + // CHECK: %[[TMP_153:.*]] = mhlo.compare LT, %[[TMP_149]], %[[TMP_152]] // CHECK: %[[TMP_154:.*]] = mhlo.select %[[TMP_153]], %[[TMP_32]], %[[TMP_148]] // CHECK: %[[TMP_155:.*]] = mhlo.constant dense<0x7FC00000> - // CHECK: %[[TMP_156:.*]] = mhlo.compare LT, %[[TMP_0]], %[[TMP_35]], NOTYPE + // CHECK: %[[TMP_156:.*]] = mhlo.compare LT, %[[TMP_0]], %[[TMP_35]] // CHECK: %[[TMP_157:.*]] = mhlo.select %[[TMP_156]], %[[TMP_155]], %[[TMP_154]] - // CHECK: %[[TMP_158:.*]] = mhlo.compare LE, %[[TMP_1]], %[[TMP_2]], NOTYPE + // CHECK: %[[TMP_158:.*]] = mhlo.compare LE, %[[TMP_1]], %[[TMP_2]] // CHECK: %[[TMP_159:.*]] = mhlo.floor %[[TMP_0]] - // CHECK: %[[TMP_160:.*]] = mhlo.compare NE, %[[TMP_0]], %[[TMP_159]], NOTYPE + // CHECK: %[[TMP_160:.*]] = mhlo.compare NE, %[[TMP_0]], %[[TMP_159]] // CHECK: %[[TMP_161:.*]] = mhlo.and %[[TMP_158]], %[[TMP_160]] : tensor // CHECK: %[[TMP_162:.*]] = mhlo.select %[[TMP_161]], %[[TMP_155]], %[[TMP_157]] // CHECK: %[[TMP_163:.*]] = mhlo.constant dense<0x7F800000> // CHECK: %[[TMP_164:.*]] = mhlo.floor %[[TMP_1]] - // CHECK: %[[TMP_165:.*]] = mhlo.compare EQ, %[[TMP_1]], %[[TMP_164]], NOTYPE + // CHECK: %[[TMP_165:.*]] = mhlo.compare EQ, %[[TMP_1]], %[[TMP_164]] // CHECK: %[[TMP_166:.*]] = mhlo.and %[[TMP_158]], %[[TMP_165]] : tensor // CHECK: %[[TMP_167:.*]] = mhlo.constant dense<2.000000e+00> // CHECK: %[[TMP_168:.*]] = mhlo.floor %[[TMP_0]] - // CHECK: %[[TMP_169:.*]] = mhlo.compare EQ, %[[TMP_0]], %[[TMP_168]], NOTYPE + // CHECK: %[[TMP_169:.*]] = mhlo.compare EQ, %[[TMP_0]], %[[TMP_168]] // CHECK: %[[TMP_170:.*]] = mhlo.remainder %[[TMP_0]], %[[TMP_167]] - // CHECK: %[[TMP_171:.*]] = mhlo.compare EQ, %[[TMP_170]], %[[TMP_2]], NOTYPE + // CHECK: %[[TMP_171:.*]] = mhlo.compare EQ, %[[TMP_170]], %[[TMP_2]] // CHECK: %[[TMP_172:.*]] = mhlo.and %[[TMP_169]], %[[TMP_171]] : tensor // CHECK: %[[TMP_173:.*]] = mhlo.select %[[TMP_172]], %[[TMP_163]], %[[TMP_155]] // CHECK: %[[TMP_174:.*]] = mhlo.select %[[TMP_166]], %[[TMP_173]], %[[TMP_162]] - // CHECK: %[[TMP_175:.*]] = mhlo.compare EQ, %[[TMP_0]], %[[TMP_3]], NOTYPE + // CHECK: %[[TMP_175:.*]] = mhlo.compare EQ, %[[TMP_0]], %[[TMP_3]] // CHECK: %[[TMP_176:.*]] = mhlo.select %[[TMP_175]], %[[TMP_163]], %[[TMP_174]] // CHECK: %[[TMP_177:.*]] = mhlo.convert %[[TMP_176]] : (tensor) -> tensor %0 = chlo.zeta %arg0, %arg1 : tensor, tensor -> tensor @@ -1817,7 +1817,7 @@ func.func @polygamma_f32(%lhs : tensor, %rhs : tensor) -> tensor // CHECK: %[[TMP_4:.*]] = mhlo.subtract %[[TMP_3]], %[[TMP_0]] // CHECK: %[[TMP_5:.*]] = mhlo.add %[[ARG0]], %[[TMP_0]] // CHECK: %[[TMP_6:.*]] = mhlo.constant dense<5.000000e-01> - // CHECK: %[[TMP_7:.*]] = mhlo.compare LT, %[[TMP_5]], %[[TMP_6]], NOTYPE + // CHECK: %[[TMP_7:.*]] = mhlo.compare LT, %[[TMP_5]], %[[TMP_6]] // CHECK: %[[TMP_8:.*]] = mhlo.negate %[[TMP_5]] // CHECK: %[[TMP_9:.*]] = mhlo.constant dense<1.000000e+00> // CHECK: %[[TMP_10:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_9]] @@ -1880,7 +1880,7 @@ func.func @polygamma_f32(%lhs : tensor, %rhs : tensor) -> tensor // CHECK: %[[TMP_67:.*]] = mhlo.abs %[[TMP_5]] // CHECK: %[[TMP_68:.*]] = mhlo.floor %[[TMP_67]] // CHECK: %[[TMP_69:.*]] = mhlo.subtract %[[TMP_67]], %[[TMP_68]] - // CHECK: %[[TMP_70:.*]] = mhlo.compare LT, %[[TMP_6]], %[[TMP_69]], NOTYPE + // CHECK: %[[TMP_70:.*]] = mhlo.compare LT, %[[TMP_6]], %[[TMP_69]] // CHECK: %[[TMP_71:.*]] = mhlo.subtract %[[TMP_9]], %[[TMP_69]] // CHECK: %[[TMP_72:.*]] = mhlo.select %[[TMP_70]], %[[TMP_71]], %[[TMP_69]] // CHECK: %[[TMP_73:.*]] = mhlo.constant dense<3.14159274> @@ -2051,36 +2051,36 @@ func.func @polygamma_f32(%lhs : tensor, %rhs : tensor) -> tensor // CHECK: %[[TMP_238:.*]] = mhlo.abs %[[TMP_120]] // CHECK: %[[TMP_239:.*]] = mhlo.constant dense<1.401300e-45> // CHECK: %[[TMP_240:.*]] = mhlo.multiply %[[TMP_238]], %[[TMP_239]] - // CHECK: %[[TMP_241:.*]] = mhlo.compare LT, %[[TMP_237]], %[[TMP_240]], NOTYPE + // CHECK: %[[TMP_241:.*]] = mhlo.compare LT, %[[TMP_237]], %[[TMP_240]] // CHECK: %[[TMP_242:.*]] = mhlo.select %[[TMP_241]], %[[TMP_120]], %[[TMP_236]] // CHECK: %[[TMP_243:.*]] = mhlo.constant dense<0x7FC00000> - // CHECK: %[[TMP_244:.*]] = mhlo.compare LT, %[[TMP_5]], %[[TMP_123]], NOTYPE + // CHECK: %[[TMP_244:.*]] = mhlo.compare LT, %[[TMP_5]], %[[TMP_123]] // CHECK: %[[TMP_245:.*]] = mhlo.select %[[TMP_244]], %[[TMP_243]], %[[TMP_242]] - // CHECK: %[[TMP_246:.*]] = mhlo.compare LE, %[[ARG1]], %[[TMP_90]], NOTYPE + // CHECK: %[[TMP_246:.*]] = mhlo.compare LE, %[[ARG1]], %[[TMP_90]] // CHECK: %[[TMP_247:.*]] = mhlo.floor %[[TMP_5]] - // CHECK: %[[TMP_248:.*]] = mhlo.compare NE, %[[TMP_5]], %[[TMP_247]], NOTYPE + // CHECK: %[[TMP_248:.*]] = mhlo.compare NE, %[[TMP_5]], %[[TMP_247]] // CHECK: %[[TMP_249:.*]] = mhlo.and %[[TMP_246]], %[[TMP_248]] // CHECK: %[[TMP_250:.*]] = mhlo.select %[[TMP_249]], %[[TMP_243]], %[[TMP_245]] // CHECK: %[[TMP_251:.*]] = mhlo.constant dense<0x7F800000> // CHECK: %[[TMP_252:.*]] = mhlo.floor %[[ARG1]] - // CHECK: %[[TMP_253:.*]] = mhlo.compare EQ, %[[ARG1]], %[[TMP_252]], NOTYPE + // CHECK: %[[TMP_253:.*]] = mhlo.compare EQ, %[[ARG1]], %[[TMP_252]] // CHECK: %[[TMP_254:.*]] = mhlo.and %[[TMP_246]], %[[TMP_253]] // CHECK: %[[TMP_255:.*]] = mhlo.constant dense<2.000000e+00> // CHECK: %[[TMP_256:.*]] = mhlo.floor %[[TMP_5]] - // CHECK: %[[TMP_257:.*]] = mhlo.compare EQ, %[[TMP_5]], %[[TMP_256]], NOTYPE + // CHECK: %[[TMP_257:.*]] = mhlo.compare EQ, %[[TMP_5]], %[[TMP_256]] // CHECK: %[[TMP_258:.*]] = mhlo.remainder %[[TMP_5]], %[[TMP_255]] - // CHECK: %[[TMP_259:.*]] = mhlo.compare EQ, %[[TMP_258]], %[[TMP_90]], NOTYPE + // CHECK: %[[TMP_259:.*]] = mhlo.compare EQ, %[[TMP_258]], %[[TMP_90]] // CHECK: %[[TMP_260:.*]] = mhlo.and %[[TMP_257]], %[[TMP_259]] // CHECK: %[[TMP_261:.*]] = mhlo.select %[[TMP_260]], %[[TMP_251]], %[[TMP_243]] // CHECK: %[[TMP_262:.*]] = mhlo.select %[[TMP_254]], %[[TMP_261]], %[[TMP_250]] - // CHECK: %[[TMP_263:.*]] = mhlo.compare EQ, %[[TMP_5]], %[[TMP_91]], NOTYPE + // CHECK: %[[TMP_263:.*]] = mhlo.compare EQ, %[[TMP_5]], %[[TMP_91]] // CHECK: %[[TMP_264:.*]] = mhlo.select %[[TMP_263]], %[[TMP_251]], %[[TMP_262]] // CHECK: %[[TMP_265:.*]] = mhlo.multiply %[[TMP_4]], %[[TMP_89]] // CHECK: %[[TMP_266:.*]] = mhlo.multiply %[[TMP_265]], %[[TMP_264]] // CHECK: %[[TMP_267:.*]] = mhlo.constant dense<0.000000e+00> - // CHECK: %[[TMP_268:.*]] = mhlo.compare EQ, %[[ARG0]], %[[TMP_267]], NOTYPE + // CHECK: %[[TMP_268:.*]] = mhlo.compare EQ, %[[ARG0]], %[[TMP_267]] // CHECK: %[[TMP_269:.*]] = mhlo.constant dense<5.000000e-01> - // CHECK: %[[TMP_270:.*]] = mhlo.compare LT, %[[ARG1]], %[[TMP_269]], NOTYPE + // CHECK: %[[TMP_270:.*]] = mhlo.compare LT, %[[ARG1]], %[[TMP_269]] // CHECK: %[[TMP_271:.*]] = mhlo.negate %[[ARG1]] // CHECK: %[[TMP_272:.*]] = mhlo.constant dense<1.000000e+00> // CHECK: %[[TMP_273:.*]] = mhlo.subtract %[[ARG1]], %[[TMP_272]] @@ -2175,16 +2175,16 @@ func.func @polygamma_f32(%lhs : tensor, %rhs : tensor) -> tensor // CHECK: %[[TMP_362:.*]] = mhlo.divide %[[TMP_361]], %[[TMP_360]] // CHECK: %[[TMP_363:.*]] = mhlo.subtract %[[TMP_351]], %[[TMP_362]] // CHECK: %[[TMP_364:.*]] = mhlo.select %[[TMP_270]], %[[TMP_363]], %[[TMP_351]] - // CHECK: %[[TMP_365:.*]] = mhlo.compare LE, %[[ARG1]], %[[TMP_275]], NOTYPE + // CHECK: %[[TMP_365:.*]] = mhlo.compare LE, %[[ARG1]], %[[TMP_275]] // CHECK: %[[TMP_366:.*]] = mhlo.floor %[[ARG1]] - // CHECK: %[[TMP_367:.*]] = mhlo.compare EQ, %[[ARG1]], %[[TMP_366]], NOTYPE + // CHECK: %[[TMP_367:.*]] = mhlo.compare EQ, %[[ARG1]], %[[TMP_366]] // CHECK: %[[TMP_368:.*]] = mhlo.and %[[TMP_365]], %[[TMP_367]] // CHECK: %[[TMP_369:.*]] = mhlo.constant dense<0x7FC00000> // CHECK: %[[TMP_370:.*]] = mhlo.select %[[TMP_368]], %[[TMP_369]], %[[TMP_364]] // CHECK: %[[TMP_371:.*]] = mhlo.select %[[TMP_268]], %[[TMP_370]], %[[TMP_266]] // CHECK: %[[TMP_372:.*]] = mhlo.floor %[[ARG0]] - // CHECK: %[[TMP_373:.*]] = mhlo.compare NE, %[[ARG0]], %[[TMP_372]], NOTYPE - // CHECK: %[[TMP_374:.*]] = mhlo.compare LT, %[[ARG0]], %[[TMP_267]], NOTYPE + // CHECK: %[[TMP_373:.*]] = mhlo.compare NE, %[[ARG0]], %[[TMP_372]] + // CHECK: %[[TMP_374:.*]] = mhlo.compare LT, %[[ARG0]], %[[TMP_267]] // CHECK: %[[TMP_375:.*]] = mhlo.or %[[TMP_373]], %[[TMP_374]] // CHECK: %[[TMP_376:.*]] = mhlo.constant dense<0x7FC00000> // CHECK: %[[TMP_377:.*]] = mhlo.select %[[TMP_375]], %[[TMP_376]], %[[TMP_371]] @@ -2204,7 +2204,7 @@ func.func @polygamma_f64(%lhs : tensor, %rhs : tensor) -> tensor // CHECK: %[[TMP_4:.*]] = mhlo.subtract %[[TMP_3]], %[[TMP_0]] // CHECK: %[[TMP_5:.*]] = mhlo.add %[[ARG0]], %[[TMP_0]] // CHECK: %[[TMP_6:.*]] = mhlo.constant dense<5.000000e-01> - // CHECK: %[[TMP_7:.*]] = mhlo.compare LT, %[[TMP_5]], %[[TMP_6]], NOTYPE + // CHECK: %[[TMP_7:.*]] = mhlo.compare LT, %[[TMP_5]], %[[TMP_6]] // CHECK: %[[TMP_8:.*]] = mhlo.negate %[[TMP_5]] // CHECK: %[[TMP_9:.*]] = mhlo.constant dense<1.000000e+00> // CHECK: %[[TMP_10:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_9]] @@ -2267,7 +2267,7 @@ func.func @polygamma_f64(%lhs : tensor, %rhs : tensor) -> tensor // CHECK: %[[TMP_67:.*]] = mhlo.abs %[[TMP_5]] // CHECK: %[[TMP_68:.*]] = mhlo.floor %[[TMP_67]] // CHECK: %[[TMP_69:.*]] = mhlo.subtract %[[TMP_67]], %[[TMP_68]] - // CHECK: %[[TMP_70:.*]] = mhlo.compare LT, %[[TMP_6]], %[[TMP_69]], NOTYPE + // CHECK: %[[TMP_70:.*]] = mhlo.compare LT, %[[TMP_6]], %[[TMP_69]] // CHECK: %[[TMP_71:.*]] = mhlo.subtract %[[TMP_9]], %[[TMP_69]] // CHECK: %[[TMP_72:.*]] = mhlo.select %[[TMP_70]], %[[TMP_71]], %[[TMP_69]] // CHECK: %[[TMP_73:.*]] = mhlo.constant dense<3.1415926535897931> @@ -2438,36 +2438,36 @@ func.func @polygamma_f64(%lhs : tensor, %rhs : tensor) -> tensor // CHECK: %[[TMP_238:.*]] = mhlo.abs %[[TMP_120]] // CHECK: %[[TMP_239:.*]] = mhlo.constant dense<4.940660e-324> // CHECK: %[[TMP_240:.*]] = mhlo.multiply %[[TMP_238]], %[[TMP_239]] - // CHECK: %[[TMP_241:.*]] = mhlo.compare LT, %[[TMP_237]], %[[TMP_240]], NOTYPE + // CHECK: %[[TMP_241:.*]] = mhlo.compare LT, %[[TMP_237]], %[[TMP_240]] // CHECK: %[[TMP_242:.*]] = mhlo.select %[[TMP_241]], %[[TMP_120]], %[[TMP_236]] // CHECK: %[[TMP_243:.*]] = mhlo.constant dense<0x7FF8000000000000> - // CHECK: %[[TMP_244:.*]] = mhlo.compare LT, %[[TMP_5]], %[[TMP_123]], NOTYPE + // CHECK: %[[TMP_244:.*]] = mhlo.compare LT, %[[TMP_5]], %[[TMP_123]] // CHECK: %[[TMP_245:.*]] = mhlo.select %[[TMP_244]], %[[TMP_243]], %[[TMP_242]] - // CHECK: %[[TMP_246:.*]] = mhlo.compare LE, %[[ARG1]], %[[TMP_90]], NOTYPE + // CHECK: %[[TMP_246:.*]] = mhlo.compare LE, %[[ARG1]], %[[TMP_90]] // CHECK: %[[TMP_247:.*]] = mhlo.floor %[[TMP_5]] - // CHECK: %[[TMP_248:.*]] = mhlo.compare NE, %[[TMP_5]], %[[TMP_247]], NOTYPE + // CHECK: %[[TMP_248:.*]] = mhlo.compare NE, %[[TMP_5]], %[[TMP_247]] // CHECK: %[[TMP_249:.*]] = mhlo.and %[[TMP_246]], %[[TMP_248]] // CHECK: %[[TMP_250:.*]] = mhlo.select %[[TMP_249]], %[[TMP_243]], %[[TMP_245]] // CHECK: %[[TMP_251:.*]] = mhlo.constant dense<0x7FF0000000000000> // CHECK: %[[TMP_252:.*]] = mhlo.floor %[[ARG1]] - // CHECK: %[[TMP_253:.*]] = mhlo.compare EQ, %[[ARG1]], %[[TMP_252]], NOTYPE + // CHECK: %[[TMP_253:.*]] = mhlo.compare EQ, %[[ARG1]], %[[TMP_252]] // CHECK: %[[TMP_254:.*]] = mhlo.and %[[TMP_246]], %[[TMP_253]] // CHECK: %[[TMP_255:.*]] = mhlo.constant dense<2.000000e+00> // CHECK: %[[TMP_256:.*]] = mhlo.floor %[[TMP_5]] - // CHECK: %[[TMP_257:.*]] = mhlo.compare EQ, %[[TMP_5]], %[[TMP_256]], NOTYPE + // CHECK: %[[TMP_257:.*]] = mhlo.compare EQ, %[[TMP_5]], %[[TMP_256]] // CHECK: %[[TMP_258:.*]] = mhlo.remainder %[[TMP_5]], %[[TMP_255]] - // CHECK: %[[TMP_259:.*]] = mhlo.compare EQ, %[[TMP_258]], %[[TMP_90]], NOTYPE + // CHECK: %[[TMP_259:.*]] = mhlo.compare EQ, %[[TMP_258]], %[[TMP_90]] // CHECK: %[[TMP_260:.*]] = mhlo.and %[[TMP_257]], %[[TMP_259]] // CHECK: %[[TMP_261:.*]] = mhlo.select %[[TMP_260]], %[[TMP_251]], %[[TMP_243]] // CHECK: %[[TMP_262:.*]] = mhlo.select %[[TMP_254]], %[[TMP_261]], %[[TMP_250]] - // CHECK: %[[TMP_263:.*]] = mhlo.compare EQ, %[[TMP_5]], %[[TMP_91]], NOTYPE + // CHECK: %[[TMP_263:.*]] = mhlo.compare EQ, %[[TMP_5]], %[[TMP_91]] // CHECK: %[[TMP_264:.*]] = mhlo.select %[[TMP_263]], %[[TMP_251]], %[[TMP_262]] // CHECK: %[[TMP_265:.*]] = mhlo.multiply %[[TMP_4]], %[[TMP_89]] // CHECK: %[[TMP_266:.*]] = mhlo.multiply %[[TMP_265]], %[[TMP_264]] // CHECK: %[[TMP_267:.*]] = mhlo.constant dense<0.000000e+00> - // CHECK: %[[TMP_268:.*]] = mhlo.compare EQ, %[[ARG0]], %[[TMP_267]], NOTYPE + // CHECK: %[[TMP_268:.*]] = mhlo.compare EQ, %[[ARG0]], %[[TMP_267]] // CHECK: %[[TMP_269:.*]] = mhlo.constant dense<5.000000e-01> - // CHECK: %[[TMP_270:.*]] = mhlo.compare LT, %[[ARG1]], %[[TMP_269]], NOTYPE + // CHECK: %[[TMP_270:.*]] = mhlo.compare LT, %[[ARG1]], %[[TMP_269]] // CHECK: %[[TMP_271:.*]] = mhlo.negate %[[ARG1]] // CHECK: %[[TMP_272:.*]] = mhlo.constant dense<1.000000e+00> // CHECK: %[[TMP_273:.*]] = mhlo.subtract %[[ARG1]], %[[TMP_272]] @@ -2562,16 +2562,16 @@ func.func @polygamma_f64(%lhs : tensor, %rhs : tensor) -> tensor // CHECK: %[[TMP_362:.*]] = mhlo.divide %[[TMP_361]], %[[TMP_360]] // CHECK: %[[TMP_363:.*]] = mhlo.subtract %[[TMP_351]], %[[TMP_362]] // CHECK: %[[TMP_364:.*]] = mhlo.select %[[TMP_270]], %[[TMP_363]], %[[TMP_351]] - // CHECK: %[[TMP_365:.*]] = mhlo.compare LE, %[[ARG1]], %[[TMP_275]], NOTYPE + // CHECK: %[[TMP_365:.*]] = mhlo.compare LE, %[[ARG1]], %[[TMP_275]] // CHECK: %[[TMP_366:.*]] = mhlo.floor %[[ARG1]] - // CHECK: %[[TMP_367:.*]] = mhlo.compare EQ, %[[ARG1]], %[[TMP_366]], NOTYPE + // CHECK: %[[TMP_367:.*]] = mhlo.compare EQ, %[[ARG1]], %[[TMP_366]] // CHECK: %[[TMP_368:.*]] = mhlo.and %[[TMP_365]], %[[TMP_367]] // CHECK: %[[TMP_369:.*]] = mhlo.constant dense<0x7FF8000000000000> // CHECK: %[[TMP_370:.*]] = mhlo.select %[[TMP_368]], %[[TMP_369]], %[[TMP_364]] // CHECK: %[[TMP_371:.*]] = mhlo.select %[[TMP_268]], %[[TMP_370]], %[[TMP_266]] // CHECK: %[[TMP_372:.*]] = mhlo.floor %[[ARG0]] - // CHECK: %[[TMP_373:.*]] = mhlo.compare NE, %[[ARG0]], %[[TMP_372]], NOTYPE - // CHECK: %[[TMP_374:.*]] = mhlo.compare LT, %[[ARG0]], %[[TMP_267]], NOTYPE + // CHECK: %[[TMP_373:.*]] = mhlo.compare NE, %[[ARG0]], %[[TMP_372]] + // CHECK: %[[TMP_374:.*]] = mhlo.compare LT, %[[ARG0]], %[[TMP_267]] // CHECK: %[[TMP_375:.*]] = mhlo.or %[[TMP_373]], %[[TMP_374]] // CHECK: %[[TMP_376:.*]] = mhlo.constant dense<0x7FF8000000000000> // CHECK: %[[TMP_377:.*]] = mhlo.select %[[TMP_375]], %[[TMP_376]], %[[TMP_371]] @@ -2613,7 +2613,7 @@ func.func @sinh_f32(%x : tensor) -> tensor { // CHECK: %[[SUM:.*]] = mhlo.add %[[EXPM1]], %[[RATIO]] : tensor // CHECK: %[[SMALL_SINH_RESULT:.*]] = mhlo.multiply %[[HALF]], %[[SUM]] : tensor // CHECK: %[[ABS_X:.*]] = mhlo.abs %[[X]] : tensor - // CHECK: %[[ABS_X_LT_ONE:.*]] = mhlo.compare LT, %[[ABS_X]], %[[ONE]], NOTYPE : (tensor, tensor) -> tensor + // CHECK: %[[ABS_X_LT_ONE:.*]] = mhlo.compare LT, %[[ABS_X]], %[[ONE]] : (tensor, tensor) -> tensor // CHECK: %[[RESULT:.*]] = mhlo.select %[[ABS_X_LT_ONE]], %[[SMALL_SINH_RESULT]], %[[LARGE_SINH_RESULT]] : tensor, tensor // CHECK: return %[[RESULT]] : tensor %1 = chlo.sinh %x : tensor -> tensor @@ -2739,8 +2739,8 @@ func.func @atanh_complex_f32(%arg : tensor>) -> tensor func.func @next_after_f32(%x: tensor<2xf32>, %y: tensor<2xf32>) -> tensor<2xf32> { // CHECK: %[[X_AS_INT:.*]] = mhlo.bitcast_convert %[[ARG0]] : (tensor<2xf32>) -> tensor<2xi32> // CHECK: %[[Y_AS_INT:.*]] = mhlo.bitcast_convert %[[ARG1]] : (tensor<2xf32>) -> tensor<2xi32> - // CHECK: %[[X_IS_NAN:.*]] = mhlo.compare NE, %[[ARG0]], %[[ARG0]], NOTYPE : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> - // CHECK: %[[Y_IS_NAN:.*]] = mhlo.compare NE, %[[ARG1]], %[[ARG1]], NOTYPE : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> + // CHECK: %[[X_IS_NAN:.*]] = mhlo.compare NE, %[[ARG0]], %[[ARG0]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> + // CHECK: %[[Y_IS_NAN:.*]] = mhlo.compare NE, %[[ARG1]], %[[ARG1]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> // CHECK: %[[INPUT_IS_NAN:.*]] = mhlo.or %[[X_IS_NAN]], %[[Y_IS_NAN]] : tensor<2xi1> // CHECK: %[[NAN:.*]] = mhlo.constant dense<0x7FC00000> : tensor<2xf32> // CHECK: %[[NAN_AS_INT:.*]] = mhlo.bitcast_convert %[[NAN]] : (tensor<2xf32>) -> tensor<2xi32> @@ -2748,16 +2748,16 @@ func.func @next_after_f32(%x: tensor<2xf32>, %y: tensor<2xf32>) -> tensor<2xf32> // CHECK-DAG: %[[NEGATED_SIGN_MASK:.*]] = mhlo.constant dense<2147483647> : tensor<2xi32> // CHECK: %[[X_ABS:.*]] = mhlo.and %[[X_AS_INT]], %[[NEGATED_SIGN_MASK]] : tensor<2xi32> // CHECK: %[[Y_ABS:.*]] = mhlo.and %[[Y_AS_INT]], %[[NEGATED_SIGN_MASK]] : tensor<2xi32> - // CHECK: %[[X_AND_Y_ARE_EQUAL:.*]] = mhlo.compare EQ, %[[ARG0]], %[[ARG1]], NOTYPE : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> + // CHECK: %[[X_AND_Y_ARE_EQUAL:.*]] = mhlo.compare EQ, %[[ARG0]], %[[ARG1]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<2xi32> - // CHECK: %[[X_ABS_IS_ZERO:.*]] = mhlo.compare EQ, %[[X_ABS]], %[[ZERO]], NOTYPE : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - // CHECK: %[[Y_ABS_IS_ZERO:.*]] = mhlo.compare EQ, %[[Y_ABS]], %[[ZERO]], NOTYPE : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK: %[[X_ABS_IS_ZERO:.*]] = mhlo.compare EQ, %[[X_ABS]], %[[ZERO]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK: %[[Y_ABS_IS_ZERO:.*]] = mhlo.compare EQ, %[[Y_ABS]], %[[ZERO]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> // CHECK: %[[X_SIGN:.*]] = mhlo.and %[[X_AS_INT]], %[[SIGN_MASK]] : tensor<2xi32> // CHECK: %[[Y_SIGN:.*]] = mhlo.and %[[Y_AS_INT]], %[[SIGN_MASK]] : tensor<2xi32> // CHECK: %[[ONE:.*]] = mhlo.constant dense<1> : tensor<2xi32> // CHECK: %[[RESULT_FOR_X_ZERO_Y_NON_ZERO:.*]] = mhlo.or %[[Y_SIGN]], %[[ONE]] : tensor<2xi32> - // CHECK: %[[SIGNS_DISAGREE:.*]] = mhlo.compare NE, %[[X_SIGN]], %[[Y_SIGN]], NOTYPE : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - // CHECK: %[[X_MAGNITUDE_LARGER_THAN_Y:.*]] = mhlo.compare GT, %[[X_ABS]], %[[Y_ABS]], NOTYPE : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK: %[[SIGNS_DISAGREE:.*]] = mhlo.compare NE, %[[X_SIGN]], %[[Y_SIGN]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK: %[[X_MAGNITUDE_LARGER_THAN_Y:.*]] = mhlo.compare GT, %[[X_ABS]], %[[Y_ABS]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> // CHECK: %[[RESULT_HAS_SMALLER_MAGNITUDE:.*]] = mhlo.or %[[X_MAGNITUDE_LARGER_THAN_Y]], %[[SIGNS_DISAGREE]] : tensor<2xi1> // CHECK: %[[MINUS_ONE:.*]] = mhlo.constant dense<-1> : tensor<2xi32> // CHECK: %[[MAGNITUDE_ADJUSTMENT:.*]] = mhlo.select %[[RESULT_HAS_SMALLER_MAGNITUDE]], %[[MINUS_ONE]], %[[ONE]] : tensor<2xi1>, tensor<2xi32> @@ -2956,7 +2956,7 @@ func.func @bessel_i1e_f16(%arg: tensor<16x16xf16>) -> tensor<16x16xf16> { // CHECK-NEXT: %[[TMP_118:.*]] = mhlo.multiply %[[TMP_116]], %[[TMP_117]] : tensor<16x16xf32> // CHECK-NEXT: %[[TMP_119:.*]] = mhlo.sqrt %[[TMP_1]] : tensor<16x16xf32> // CHECK-NEXT: %[[TMP_120:.*]] = mhlo.divide %[[TMP_118]], %[[TMP_119]] : tensor<16x16xf32> - // CHECK-NEXT: %[[TMP_121:.*]] = mhlo.compare LE, %[[TMP_1]], %[[TMP_5]], NOTYPE : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xi1> + // CHECK-NEXT: %[[TMP_121:.*]] = mhlo.compare LE, %[[TMP_1]], %[[TMP_5]] : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xi1> // CHECK-NEXT: %[[TMP_122:.*]] = mhlo.select %[[TMP_121]], %[[TMP_82]], %[[TMP_120]] : tensor<16x16xi1>, tensor<16x16xf32> // CHECK-NEXT: %[[TMP_123:.*]] = mhlo.sign %[[TMP_0]] : tensor<16x16xf32> // CHECK-NEXT: %[[TMP_124:.*]] = mhlo.multiply %[[TMP_123]], %[[TMP_122]] : tensor<16x16xf32> @@ -3091,7 +3091,7 @@ func.func @bessel_i1e_f32(%arg : tensor<16x16xf32>) -> tensor<16x16xf32> { // CHECK-NEXT: %[[TMP_117:.*]] = mhlo.multiply %[[TMP_115]], %[[TMP_116]] : tensor<16x16xf32> // CHECK-NEXT: %[[TMP_118:.*]] = mhlo.sqrt %[[TMP_0]] : tensor<16x16xf32> // CHECK-NEXT: %[[TMP_119:.*]] = mhlo.divide %[[TMP_117]], %[[TMP_118]] : tensor<16x16xf32> - // CHECK-NEXT: %[[TMP_120:.*]] = mhlo.compare LE, %[[TMP_0]], %[[TMP_4]], NOTYPE : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xi1> + // CHECK-NEXT: %[[TMP_120:.*]] = mhlo.compare LE, %[[TMP_0]], %[[TMP_4]] : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xi1> // CHECK-NEXT: %[[TMP_121:.*]] = mhlo.select %[[TMP_120]], %[[TMP_81]], %[[TMP_119]] : tensor<16x16xi1>, tensor<16x16xf32> // CHECK-NEXT: %[[TMP_122:.*]] = mhlo.sign %[[ARG0]] : tensor<16x16xf32> // CHECK-NEXT: %[[TMP_123:.*]] = mhlo.multiply %[[TMP_122]], %[[TMP_121]] : tensor<16x16xf32> @@ -3345,7 +3345,7 @@ func.func @bessel_i1e_f64(%arg : tensor<16x16xf64>) -> tensor<16x16xf64> { // CHECK-NEXT: %[[TMP_237:.*]] = mhlo.multiply %[[TMP_235]], %[[TMP_236]] : tensor<16x16xf64> // CHECK-NEXT: %[[TMP_238:.*]] = mhlo.sqrt %[[TMP_0]] : tensor<16x16xf64> // CHECK-NEXT: %[[TMP_239:.*]] = mhlo.divide %[[TMP_237]], %[[TMP_238]] : tensor<16x16xf64> - // CHECK-NEXT: %[[TMP_240:.*]] = mhlo.compare LE, %[[TMP_0]], %[[TMP_4]], NOTYPE : (tensor<16x16xf64>, tensor<16x16xf64>) -> tensor<16x16xi1> + // CHECK-NEXT: %[[TMP_240:.*]] = mhlo.compare LE, %[[TMP_0]], %[[TMP_4]] : (tensor<16x16xf64>, tensor<16x16xf64>) -> tensor<16x16xi1> // CHECK-NEXT: %[[TMP_241:.*]] = mhlo.select %[[TMP_240]], %[[TMP_129]], %[[TMP_239]] : tensor<16x16xi1>, tensor<16x16xf64> // CHECK-NEXT: %[[TMP_242:.*]] = mhlo.sign %[[ARG0]] : tensor<16x16xf64> // CHECK-NEXT: %[[TMP_243:.*]] = mhlo.multiply %[[TMP_242]], %[[TMP_241]] : tensor<16x16xf64> @@ -3360,7 +3360,7 @@ func.func @bessel_i1e_f64(%arg : tensor<16x16xf64>) -> tensor<16x16xf64> { // CHECK-DAG: [[VAL_2:%.*]] = mhlo.log_plus_one [[VAL_1]] : tensor<16x16xf32> // CHECK-DAG: [[VAL_3:%.*]] = mhlo.negate [[VAL_2]] : tensor<16x16xf32> // CHECK-DAG: [[VAL_4:%.*]] = mhlo.constant dense<5.000000e+00> : tensor<16x16xf32> -// CHECK-DAG: [[VAL_5:%.*]] = mhlo.compare LT, [[VAL_3]], [[VAL_4]], NOTYPE : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xi1> +// CHECK-DAG: [[VAL_5:%.*]] = mhlo.compare LT, [[VAL_3]], [[VAL_4]] : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xi1> // CHECK-DAG: [[VAL_6:%.*]] = mhlo.constant dense<2.500000e+00> : tensor<16x16xf32> // CHECK-DAG: [[VAL_7:%.*]] = mhlo.subtract [[VAL_3]], [[VAL_6]] : tensor<16x16xf32> // CHECK-DAG: [[VAL_8:%.*]] = mhlo.sqrt [[VAL_3]] : tensor<16x16xf32> @@ -3413,7 +3413,7 @@ func.func @bessel_i1e_f64(%arg : tensor<16x16xf64>) -> tensor<16x16xf64> { // CHECK-DAG: [[VAL_55:%.*]] = mhlo.multiply [[VAL_54]], [[ARG_0]] : tensor<16x16xf32> // CHECK-DAG: [[VAL_56:%.*]] = mhlo.abs [[ARG_0]] : tensor<16x16xf32> // CHECK-DAG: [[VAL_57:%.*]] = mhlo.constant dense<1.000000e+00> : tensor<16x16xf32> -// CHECK-DAG: [[VAL_58:%.*]] = mhlo.compare EQ, [[VAL_56]], [[VAL_57]], NOTYPE : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xi1> +// CHECK-DAG: [[VAL_58:%.*]] = mhlo.compare EQ, [[VAL_56]], [[VAL_57]] : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xi1> // CHECK-DAG: [[VAL_59:%.*]] = mhlo.constant dense<0x7F800000> : tensor<16x16xf32> // CHECK-DAG: [[VAL_60:%.*]] = mhlo.multiply [[ARG_0]], [[VAL_59]] : tensor<16x16xf32> // CHECK-DAG: [[VAL_61:%.*]] = mhlo.select [[VAL_58]], [[VAL_60]], [[VAL_55]] : tensor<16x16xi1>, tensor<16x16xf32> @@ -3428,9 +3428,9 @@ func.func @erf_inv(%arg0 : tensor<16x16xf32>) { // CHECK-DAG: [[VAL_2:%.*]] = mhlo.log_plus_one [[VAL_1]] : tensor<16x16xf64> // CHECK-DAG: [[VAL_3:%.*]] = mhlo.negate [[VAL_2]] : tensor<16x16xf64> // CHECK-DAG: [[VAL_4:%.*]] = mhlo.constant dense<6.250000e+00> : tensor<16x16xf64> -// CHECK-DAG: [[VAL_5:%.*]] = mhlo.compare LT, [[VAL_3]], [[VAL_4]], NOTYPE : (tensor<16x16xf64>, tensor<16x16xf64>) -> tensor<16x16xi1> +// CHECK-DAG: [[VAL_5:%.*]] = mhlo.compare LT, [[VAL_3]], [[VAL_4]] : (tensor<16x16xf64>, tensor<16x16xf64>) -> tensor<16x16xi1> // CHECK-DAG: [[VAL_6:%.*]] = mhlo.constant dense<1.600000e+01> : tensor<16x16xf64> -// CHECK-DAG: [[VAL_7:%.*]] = mhlo.compare LT, [[VAL_3]], [[VAL_6]], NOTYPE : (tensor<16x16xf64>, tensor<16x16xf64>) -> tensor<16x16xi1> +// CHECK-DAG: [[VAL_7:%.*]] = mhlo.compare LT, [[VAL_3]], [[VAL_6]] : (tensor<16x16xf64>, tensor<16x16xf64>) -> tensor<16x16xi1> // CHECK-DAG: [[VAL_8:%.*]] = mhlo.sqrt [[VAL_3]] : tensor<16x16xf64> // CHECK-DAG: [[VAL_9:%.*]] = mhlo.constant dense<3.125000e+00> : tensor<16x16xf64> // CHECK-DAG: [[VAL_10:%.*]] = mhlo.subtract [[VAL_3]], [[VAL_9]] : tensor<16x16xf64> @@ -3587,7 +3587,7 @@ func.func @erf_inv(%arg0 : tensor<16x16xf32>) { // CHECK-DAG: [[VAL_161:%.*]] = mhlo.multiply [[VAL_160]], [[ARG_0]] : tensor<16x16xf64> // CHECK-DAG: [[VAL_162:%.*]] = mhlo.abs [[ARG_0]] : tensor<16x16xf64> // CHECK-DAG: [[VAL_163:%.*]] = mhlo.constant dense<1.000000e+00> : tensor<16x16xf64> -// CHECK-DAG: [[VAL_164:%.*]] = mhlo.compare EQ, [[VAL_162]], [[VAL_163]], NOTYPE : (tensor<16x16xf64>, tensor<16x16xf64>) -> tensor<16x16xi1> +// CHECK-DAG: [[VAL_164:%.*]] = mhlo.compare EQ, [[VAL_162]], [[VAL_163]] : (tensor<16x16xf64>, tensor<16x16xf64>) -> tensor<16x16xi1> // CHECK-DAG: [[VAL_165:%.*]] = mhlo.constant dense<0x7FF0000000000000> : tensor<16x16xf64> // CHECK-DAG: [[VAL_166:%.*]] = mhlo.multiply [[ARG_0]], [[VAL_165]] : tensor<16x16xf64> // CHECK-DAG: [[VAL_167:%.*]] = mhlo.select [[VAL_164]], [[VAL_166]], [[VAL_161]] : tensor<16x16xi1>, tensor<16x16xf64> From e2c5578d132e6caa0f4ec53c0fee208723f8e7de Mon Sep 17 00:00:00 2001 From: Yifan Jiang Date: Fri, 19 Jul 2024 16:49:37 -0700 Subject: [PATCH 033/376] Pipe through the GPU core_count attribute to GPU Topology PiperOrigin-RevId: 654166866 --- xla/pjrt/c/pjrt_c_api_gpu_internal.cc | 3 ++- xla/pjrt/distributed/topology_util.cc | 3 +++ xla/pjrt/distributed/topology_util_test.cc | 5 +++++ xla/pjrt/gpu/gpu_topology.cc | 4 +++- xla/pjrt/gpu/gpu_topology.h | 12 +++++++++--- xla/pjrt/gpu/gpu_topology.proto | 4 ++++ xla/pjrt/gpu/se_gpu_pjrt_client_test.cc | 4 +++- xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc | 17 +++++++++-------- 8 files changed, 38 insertions(+), 14 deletions(-) diff --git a/xla/pjrt/c/pjrt_c_api_gpu_internal.cc b/xla/pjrt/c/pjrt_c_api_gpu_internal.cc index 2d593290087719..8a8e423c4e4dc3 100644 --- a/xla/pjrt/c/pjrt_c_api_gpu_internal.cc +++ b/xla/pjrt/c/pjrt_c_api_gpu_internal.cc @@ -193,7 +193,8 @@ PJRT_Error* PJRT_GpuDeviceTopology_Create( device_ids, description.name(), /*num_slices=*/1, /*num_hosts_per_slice=*/1, - /*num_devices_per_host=*/device_ids.size()); + /*num_devices_per_host=*/device_ids.size(), + /*core_count_per_chip=*/description.core_count()); // Determine the platform ID and name based on the platform. xla::PjRtPlatformId platform_id = diff --git a/xla/pjrt/distributed/topology_util.cc b/xla/pjrt/distributed/topology_util.cc index e3926dcb39cd5a..c43186bd8159a8 100644 --- a/xla/pjrt/distributed/topology_util.cc +++ b/xla/pjrt/distributed/topology_util.cc @@ -245,6 +245,9 @@ absl::StatusOr BuildGpuTopology( if (gpu_topology.platform_version().empty()) { gpu_topology.set_platform_version(device.name()); } + if (gpu_topology.core_count_per_chip() == 0) { + gpu_topology.set_core_count_per_chip(device.core_count()); + } slice_id_to_node_ids[device.slice_index()].insert( local_topology.node_id()); device_ids.push_back(device.global_device_id()); diff --git a/xla/pjrt/distributed/topology_util_test.cc b/xla/pjrt/distributed/topology_util_test.cc index f5eaa7952add3c..aaf859c658e157 100644 --- a/xla/pjrt/distributed/topology_util_test.cc +++ b/xla/pjrt/distributed/topology_util_test.cc @@ -99,12 +99,16 @@ TEST(TopologyTest, BuildGpuTopology) { // Adds 2 devices to host 0 and 2 devices to host 1. DeviceProto* d0 = locals[0].add_devices(); d0->set_local_device_ordinal(0); + d0->set_core_count(20); DeviceProto* d1 = locals[0].add_devices(); d1->set_local_device_ordinal(1); + d1->set_core_count(20); DeviceProto* d2 = locals[1].add_devices(); d2->set_local_device_ordinal(0); + d2->set_core_count(20); DeviceProto* d3 = locals[1].add_devices(); d3->set_local_device_ordinal(1); + d3->set_core_count(20); GlobalTopologyProto global = BuildGlobalTopology(absl::Span(locals), @@ -115,6 +119,7 @@ TEST(TopologyTest, BuildGpuTopology) { EXPECT_EQ(gpu_topology.num_slices(), 2); EXPECT_EQ(gpu_topology.num_hosts_per_slice(), 1); EXPECT_EQ(gpu_topology.num_devices_per_host(), 2); + EXPECT_EQ(gpu_topology.core_count_per_chip(), 20); } TEST(TopologyTest, BuildGpuTopologyWithDifferentNumHostsPerSlice) { diff --git a/xla/pjrt/gpu/gpu_topology.cc b/xla/pjrt/gpu/gpu_topology.cc index 600adf98231fcb..3f03de813f42c1 100644 --- a/xla/pjrt/gpu/gpu_topology.cc +++ b/xla/pjrt/gpu/gpu_topology.cc @@ -28,7 +28,8 @@ std::unique_ptr GpuTopology::FromProto( gpu_topology_proto.device_ids().end()}, gpu_topology_proto.platform_version(), gpu_topology_proto.num_slices(), gpu_topology_proto.num_hosts_per_slice(), - gpu_topology_proto.num_devices_per_host()); + gpu_topology_proto.num_devices_per_host(), + gpu_topology_proto.core_count_per_chip()); } GpuTopologyProto GpuTopology::ToProto() const { @@ -38,6 +39,7 @@ GpuTopologyProto GpuTopology::ToProto() const { proto.set_num_slices(num_slices()); proto.set_num_hosts_per_slice(num_hosts_per_slice()); proto.set_num_devices_per_host(num_devices_per_host()); + proto.set_core_count_per_chip(core_count_per_chip()); return proto; } diff --git a/xla/pjrt/gpu/gpu_topology.h b/xla/pjrt/gpu/gpu_topology.h index 957a36001b0968..f21a04bec2c3e6 100644 --- a/xla/pjrt/gpu/gpu_topology.h +++ b/xla/pjrt/gpu/gpu_topology.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_PJRT_GPU_GPU_TOPOLOGY_H_ #define XLA_PJRT_GPU_GPU_TOPOLOGY_H_ +#include #include #include @@ -28,19 +29,22 @@ class GpuTopology { explicit GpuTopology(const std::vector& gpu_device_ids, absl::string_view platform_version, int32_t num_slices, int32_t num_hosts_per_slice, - int32_t num_devices_per_host) + int32_t num_devices_per_host, + int32_t core_count_per_chip) : devices_ids_(gpu_device_ids), platform_version_(platform_version), num_slices_(num_slices), num_hosts_per_slice_(num_hosts_per_slice), - num_devices_per_host_(num_devices_per_host) {} + num_devices_per_host_(num_devices_per_host), + core_count_per_chip_(core_count_per_chip) {} bool operator==(const GpuTopology& other) const { return devices_ids_ == other.devices_ids_ && platform_version_ == other.platform_version_ && num_slices_ == other.num_slices_ && num_hosts_per_slice_ == other.num_hosts_per_slice_ && - num_devices_per_host_ == other.num_devices_per_host_; + num_devices_per_host_ == other.num_devices_per_host_ && + core_count_per_chip_ == other.core_count_per_chip_; } int number_of_devices() const { @@ -60,6 +64,7 @@ class GpuTopology { int32_t num_slices() const { return num_slices_; } int32_t num_hosts_per_slice() const { return num_hosts_per_slice_; } int32_t num_devices_per_host() const { return num_devices_per_host_; } + int32_t core_count_per_chip() const { return core_count_per_chip_; } private: const std::vector devices_ids_; @@ -67,6 +72,7 @@ class GpuTopology { const int32_t num_slices_; const int32_t num_hosts_per_slice_; const int32_t num_devices_per_host_; + const int32_t core_count_per_chip_; bool is_topology_symmetric() const { return num_slices_ != -1 && num_hosts_per_slice_ != -1 && diff --git a/xla/pjrt/gpu/gpu_topology.proto b/xla/pjrt/gpu/gpu_topology.proto index 0bb3c5b34ff62f..bacd3ac4344055 100644 --- a/xla/pjrt/gpu/gpu_topology.proto +++ b/xla/pjrt/gpu/gpu_topology.proto @@ -29,4 +29,8 @@ message GpuTopologyProto { // The number of devices for each host. int32 num_devices_per_host = 6; + + // The number of cores for each device. For Nvidia GPUs, this is the number of + // SMs(Streaming MultiProcessors) on the chip. + int32 core_count_per_chip = 7; } diff --git a/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc b/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc index 9c82f9ec9a5a57..c839312edd0909 100644 --- a/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc +++ b/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc @@ -817,13 +817,15 @@ TEST(GpuTopology, ToProto) { /*platform_version=*/"platform_version", /*num_slices=*/2, /*num_hosts_per_slice=*/1, - /*num_devices_per_host=*/3); + /*num_devices_per_host=*/3, + /*core_count_per_chip=*/10); GpuTopologyProto msg = gpu_topology.ToProto(); EXPECT_THAT(msg.device_ids(), ElementsAre(3, 2, 1)); EXPECT_THAT(msg.platform_version(), "platform_version"); EXPECT_THAT(msg.num_slices(), 2); EXPECT_THAT(msg.num_hosts_per_slice(), 1); EXPECT_THAT(msg.num_devices_per_host(), 3); + EXPECT_THAT(msg.core_count_per_chip(), 10); } TEST(StreamExecutorGpuClientTest, DistributedInit) { diff --git a/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc b/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc index 6d29d51572083e..399056991d5c8d 100644 --- a/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc +++ b/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc @@ -63,16 +63,17 @@ absl::StatusOr GetXlaComputation( std::shared_ptr GetGpuTopology( std::vector device_ids, absl::string_view platform_version, - int num_slices, int num_hosts_per_slice, int num_devices_per_host) { - return std::make_shared(device_ids, platform_version, - num_slices, num_hosts_per_slice, - num_devices_per_host); + int num_slices, int num_hosts_per_slice, int num_devices_per_host, + int core_count_per_chip) { + return std::make_shared( + device_ids, platform_version, num_slices, num_hosts_per_slice, + num_devices_per_host, core_count_per_chip); } TEST(StreamExecutorGpuCompilerTest, NoClientXla) { StreamExecutorGpuCompiler compiler; StreamExecutorGpuTopologyDescription topology( - CudaId(), CudaName(), GetGpuTopology({0, 1}, "Fake_device", 1, 1, 2)); + CudaId(), CudaName(), GetGpuTopology({0, 1}, "Fake_device", 1, 1, 2, 10)); TF_ASSERT_OK_AND_ASSIGN(auto computation, GetXlaComputation(kProgram)); EXPECT_THAT(compiler.Compile(xla::CompileOptions(), computation, topology, @@ -83,7 +84,7 @@ TEST(StreamExecutorGpuCompilerTest, NoClientXla) { TEST(StreamExecutorGpuCompilerTest, TopologyNotSameXla) { StreamExecutorGpuCompiler compiler; StreamExecutorGpuTopologyDescription topology( - CudaId(), CudaName(), GetGpuTopology({0, 1}, "Fake_device", 1, 1, 2)); + CudaId(), CudaName(), GetGpuTopology({0, 1}, "Fake_device", 1, 1, 2, 10)); TF_ASSERT_OK_AND_ASSIGN(auto client, GetStreamExecutorGpuClient(GpuClientOptions())); @@ -129,7 +130,7 @@ TEST(StreamExecutorGpuCompilerTest, NoClientMlir) { mlir::parseSourceString(mlir_str, &context); StreamExecutorGpuTopologyDescription topology( - CudaId(), CudaName(), GetGpuTopology({0, 1}, "Fake_device", 1, 1, 2)); + CudaId(), CudaName(), GetGpuTopology({0, 1}, "Fake_device", 1, 1, 2, 10)); EXPECT_THAT( compiler.Compile(xla::CompileOptions(), mlir_module.get(), topology, @@ -147,7 +148,7 @@ TEST(StreamExecutorGpuCompilerTest, TopologyNotSameMlir) { mlir::parseSourceString(mlir_str, &context); StreamExecutorGpuTopologyDescription topology( - CudaId(), CudaName(), GetGpuTopology({0, 1}, "Fake_device", 1, 1, 2)); + CudaId(), CudaName(), GetGpuTopology({0, 1}, "Fake_device", 1, 1, 2, 10)); TF_ASSERT_OK_AND_ASSIGN(auto client, GetStreamExecutorGpuClient(GpuClientOptions())); From 4a3d464ce0c5fec3a7e48e2230b6478c854b6aaa Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Fri, 19 Jul 2024 17:08:55 -0700 Subject: [PATCH 034/376] Move macros in `third_party/mkl_dnn/build_defs.bzl` to `tsl/mkl/build_defs.bzl` PiperOrigin-RevId: 654171036 --- third_party/tsl/opensource_only.files | 1 - third_party/tsl/third_party/mkl_dnn/BUILD | 7 ---- .../tsl/third_party/mkl_dnn/build_defs.bzl | 34 ------------------- .../tsl/third_party/mkl_dnn/mkldnn_v1.BUILD | 3 +- xla/tsl/BUILD | 1 - xla/tsl/mkl/build_defs.bzl | 33 +++++++++++++++++- xla/tsl/tsl.bzl | 9 ++--- 7 files changed, 36 insertions(+), 52 deletions(-) delete mode 100644 third_party/tsl/third_party/mkl_dnn/build_defs.bzl diff --git a/third_party/tsl/opensource_only.files b/third_party/tsl/opensource_only.files index 821ebe49968eea..300ae95c10aec2 100644 --- a/third_party/tsl/opensource_only.files +++ b/third_party/tsl/opensource_only.files @@ -59,7 +59,6 @@ third_party/llvm_openmp/expand_cmake_vars:.py third_party/llvm_openmp/openmp.bzl: third_party/mkl/BUILD: third_party/mkl_dnn/LICENSE: -third_party/mkl_dnn/build_defs.bzl: third_party/mkl_dnn/mkldnn_acl.BUILD: third_party/mkl_dnn/mkldnn_v1.BUILD: third_party/nccl/BUILD: diff --git a/third_party/tsl/third_party/mkl_dnn/BUILD b/third_party/tsl/third_party/mkl_dnn/BUILD index d346ce45cba1b6..99ee10bb8354a2 100644 --- a/third_party/tsl/third_party/mkl_dnn/BUILD +++ b/third_party/tsl/third_party/mkl_dnn/BUILD @@ -1,5 +1,3 @@ -load("@bazel_skylib//:bzl_library.bzl", "bzl_library") - package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:public"], @@ -10,8 +8,3 @@ exports_files( ["LICENSE"], visibility = ["//visibility:public"], ) - -bzl_library( - name = "build_defs_bzl", - srcs = ["build_defs.bzl"], -) diff --git a/third_party/tsl/third_party/mkl_dnn/build_defs.bzl b/third_party/tsl/third_party/mkl_dnn/build_defs.bzl deleted file mode 100644 index c741c850741b49..00000000000000 --- a/third_party/tsl/third_party/mkl_dnn/build_defs.bzl +++ /dev/null @@ -1,34 +0,0 @@ -"""Starlark macros for oneDNN. - -if_mkldnn_openmp checks if we are building x86 backend with OpenMP. -if_mkldnn_aarch64_acl checks if we are building with Arm Compute Library. -if_mkldnn_aarch64_acl_openmp checks if we are building ACL with OpenMP. -""" - -def if_mkldnn_openmp(if_true, if_false = []): - """Returns `if_true` if OpenMP is used with oneDNN. - - Shorthand for select()'ing on whether we're building with - oneDNN open source library only with openmp - - Returns a select statement which evaluates to if_true if we're building - with oneDNN open source library only with OpenMP. Otherwise, the - select statement evaluates to if_false. - - """ - return select({ - "@xla//xla/tsl/mkl:build_with_mkldnn_openmp": if_true, - "//conditions:default": if_false, - }) - -def if_mkldnn_aarch64_acl(if_true, if_false = []): - return select({ - "@xla//xla/tsl/mkl:build_with_mkl_aarch64": if_true, - "//conditions:default": if_false, - }) - -def if_mkldnn_aarch64_acl_openmp(if_true, if_false = []): - return select({ - "@xla//xla/tsl/mkl:build_with_mkl_aarch64_openmp": if_true, - "//conditions:default": if_false, - }) diff --git a/third_party/tsl/third_party/mkl_dnn/mkldnn_v1.BUILD b/third_party/tsl/third_party/mkl_dnn/mkldnn_v1.BUILD index ed6d04c282d00c..bb881769156d88 100644 --- a/third_party/tsl/third_party/mkl_dnn/mkldnn_v1.BUILD +++ b/third_party/tsl/third_party/mkl_dnn/mkldnn_v1.BUILD @@ -1,7 +1,6 @@ load("@bazel_skylib//rules:expand_template.bzl", "expand_template") -load("@tsl//third_party/mkl_dnn:build_defs.bzl", "if_mkldnn_openmp") load("@xla//xla/tsl:tsl.bzl", "tf_openmp_copts") -load("@xla//xla/tsl/mkl:build_defs.bzl", "if_mkl", "if_mkl_ml") +load("@xla//xla/tsl/mkl:build_defs.bzl", "if_mkl", "if_mkl_ml", "if_mkldnn_openmp") exports_files(["LICENSE"]) diff --git a/xla/tsl/BUILD b/xla/tsl/BUILD index 3c78837b9b1133..9d47da8526a1fb 100644 --- a/xla/tsl/BUILD +++ b/xla/tsl/BUILD @@ -518,7 +518,6 @@ bzl_library( "@local_config_cuda//cuda:build_defs_bzl", "@local_config_rocm//rocm:build_defs_bzl", "@local_config_tensorrt//:build_defs_bzl", - "@tsl//third_party/mkl_dnn:build_defs_bzl", "@tsl//tsl/platform:rules_cc_bzl", ], ) diff --git a/xla/tsl/mkl/build_defs.bzl b/xla/tsl/mkl/build_defs.bzl index 56a19e4713da24..4054a6b21add3a 100644 --- a/xla/tsl/mkl/build_defs.bzl +++ b/xla/tsl/mkl/build_defs.bzl @@ -1,10 +1,13 @@ -"""Starlark macros for MKL. +"""Starlark macros for MKL and oneDNN. if_mkl is a conditional to check if we are building with MKL. if_mkl_ml is a conditional to check if we are building with MKL-ML. if_mkl_ml_only is a conditional to check for MKL-ML-only (no MKL-DNN) mode. if_mkl_lnx_x64 is a conditional to check for MKL if_enable_mkl is a conditional to check if building with MKL and MKL is enabled. +if_mkldnn_openmp checks if we are building x86 backend with OpenMP. +if_mkldnn_aarch64_acl checks if we are building with Arm Compute Library. +if_mkldnn_aarch64_acl_openmp checks if we are building ACL with OpenMP. mkl_repository is a repository rule for creating MKL repository rule that can be pointed to either a local folder, or download it from the internet. @@ -121,6 +124,34 @@ def onednn_v3_define(): "//conditions:default": [], }) +def if_mkldnn_openmp(if_true, if_false = []): + """Returns `if_true` if OpenMP is used with oneDNN. + + Shorthand for select()'ing on whether we're building with + oneDNN open source library only with openmp + + Returns a select statement which evaluates to if_true if we're building + with oneDNN open source library only with OpenMP. Otherwise, the + select statement evaluates to if_false. + + """ + return select({ + "@xla//xla/tsl/mkl:build_with_mkldnn_openmp": if_true, + "//conditions:default": if_false, + }) + +def if_mkldnn_aarch64_acl(if_true, if_false = []): + return select({ + "@xla//xla/tsl/mkl:build_with_mkl_aarch64": if_true, + "//conditions:default": if_false, + }) + +def if_mkldnn_aarch64_acl_openmp(if_true, if_false = []): + return select({ + "@xla//xla/tsl/mkl:build_with_mkl_aarch64_openmp": if_true, + "//conditions:default": if_false, + }) + def _enable_local_mkl(repository_ctx): return _TF_MKL_ROOT in repository_ctx.os.environ diff --git a/xla/tsl/tsl.bzl b/xla/tsl/tsl.bzl index 42fc27e357ec48..f81dd65aa065a5 100644 --- a/xla/tsl/tsl.bzl +++ b/xla/tsl/tsl.bzl @@ -9,18 +9,15 @@ load( "//xla/tsl/mkl:build_defs.bzl", "if_enable_mkl", "if_mkl", + "if_mkldnn_aarch64_acl", + "if_mkldnn_aarch64_acl_openmp", + "if_mkldnn_openmp", "onednn_v3_define", ) load( "//third_party/compute_library:build_defs.bzl", "if_enable_acl", ) -load( - "@tsl//third_party/mkl_dnn:build_defs.bzl", - "if_mkldnn_aarch64_acl", - "if_mkldnn_aarch64_acl_openmp", - "if_mkldnn_openmp", -) load( "@local_config_rocm//rocm:build_defs.bzl", "if_rocm", From d309580d16ef74602118c3c810c5d115c017e433 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Fri, 19 Jul 2024 18:13:49 -0700 Subject: [PATCH 035/376] Fix acos decomposition for non-complex arguments The previous decomposition was incorrect for x == -1, which should return pi. PiperOrigin-RevId: 654183025 --- third_party/stablehlo/temporary.patch | 89 +++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index e304a133e9af3a..a6da7f82c42f12 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -10,4 +10,93 @@ diff --ruN a/stablehlo/stablehlo/reference/Tensor.cpp b/stablehlo/stablehlo/refe os << "}"; } +diff --ruN a/stablehlo/stablehlo/tests/math/acos_limits.mlir b/stablehlo/stablehlo/tests/math/acos_limits.mlir +--- stablehlo/stablehlo/tests/math/acos_limits.mlir ++++ stablehlo/stablehlo/tests/math/acos_limits.mlir +@@ -0,0 +1,14 @@ ++// RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s | stablehlo-translate --interpret ++ ++func.func @main() -> (tensor, tensor>) { ++ %cst = stablehlo.constant dense<-1.000000e+00> : tensor ++ %cst_0 = stablehlo.constant dense<(-1.000000e+00,0.000000e+00)> : tensor> ++ %zero = stablehlo.constant dense<0.0> : tensor ++ %pi = stablehlo.constant dense<3.1415926535897931> : tensor ++ %complex_pi = stablehlo.complex %pi, %zero : tensor> ++ %0 = chlo.acos %cst : tensor -> tensor ++ %1 = chlo.acos %cst_0 : tensor> -> tensor> ++ check.expect_close %0, %pi, max_ulp_difference = 1 : tensor, tensor ++ check.expect_close %1, %complex_pi, max_ulp_difference = 1 : tensor>, tensor> ++ return %0, %1 : tensor, tensor> ++} +diff --ruN a/stablehlo/stablehlo/transforms/ChloDecompositionPatterns.td b/stablehlo/stablehlo/transforms/ChloDecompositionPatterns.td +--- stablehlo/stablehlo/transforms/ChloDecompositionPatterns.td ++++ stablehlo/stablehlo/transforms/ChloDecompositionPatterns.td +@@ -45,6 +45,37 @@ + //===----------------------------------------------------------------------===// + // Unary op patterns. + //===----------------------------------------------------------------------===// ++ ++// Expand acos for non-complex arguments to MHLO dialect as follows: ++// acos(x) = 2 * atan2(sqrt(1 - x^2), (1 + x)) if x != -1 ++// = pi if x == -1 ++// ++// Note: Complex decomposition is in ChloDecompositionPatternsMath.td ++def : Pat<(CHLO_AcosOp NonComplexElementType:$input), ++ (StableHLO_SelectOp ++ (StableHLO_CompareOp ++ $input, ++ (StableHLO_ConstantLike<"-1"> $input), ++ StableHLO_ComparisonDirectionValue<"NE">, ++ (STABLEHLO_DEFAULT_COMPARISON_TYPE) ++ ), ++ (StableHLO_MulOp ++ (StableHLO_ConstantLike<"2"> $input), ++ (StableHLO_Atan2Op ++ (StableHLO_SqrtOp ++ (StableHLO_SubtractOp ++ (StableHLO_ConstantLike<"1"> $input), ++ (StableHLO_MulOp $input, $input) ++ ) ++ ), ++ (StableHLO_AddOp ++ (StableHLO_ConstantLike<"1"> $input), ++ $input ++ ) ++ ) ++ ), ++ (StableHLO_ConstantLike<"M_PI"> $input) ++ )>; + + // Express `atan` as + // atan(x) = atan2(x, 1) +diff --ruN a/stablehlo/stablehlo/transforms/ChloDecompositionPatternsMath.td b/stablehlo/stablehlo/transforms/ChloDecompositionPatternsMath.td +--- stablehlo/stablehlo/transforms/ChloDecompositionPatternsMath.td ++++ stablehlo/stablehlo/transforms/ChloDecompositionPatternsMath.td +@@ -634,26 +634,6 @@ + (StableHLO_Log1pOp + (StableHLO_AddOp $am1, $sq)))), + (StableHLO_NegOp $imag)))>; +- +-// Arcus cosine on real input: +-// +-// arccos(x) = 2 * arctan2(sqrt(1 - x * x), 1 + x) +-// +-// To avoid cancellation errors at abs(x) close to 1, we'll use +-// +-// 1 - x * x == (1 - x) * (1 + x) +-// +-def : Pat<(CHLO_AcosOp NonComplexElementType:$x), +- (StableHLO_MulOp +- (StableHLO_ConstantLike<"2"> $x), +- (StableHLO_Atan2Op +- (StableHLO_SqrtOp +- (StableHLO_MulOp +- (StableHLO_SubtractOp +- (StableHLO_ConstantLike<"1">:$one $x), +- $x), +- (StableHLO_AddOp:$add_one_x $one, $x))), +- $add_one_x))>; + + // Inverse hyperbolic cosine on complex input: + // From 518c087b32e105cc7343693debd3d4d389b9b743 Mon Sep 17 00:00:00 2001 From: Amit Sabne Date: Fri, 19 Jul 2024 20:58:57 -0700 Subject: [PATCH 036/376] Improve int2 support PiperOrigin-RevId: 654218772 --- xla/hlo/evaluator/BUILD | 1 + .../evaluator/hlo_evaluator_typed_visitor.h | 2 ++ .../hlo_evaluator_typed_visitor_int2.cc | 23 +++++++++++++++++++ 3 files changed, 26 insertions(+) create mode 100644 xla/hlo/evaluator/hlo_evaluator_typed_visitor_int2.cc diff --git a/xla/hlo/evaluator/BUILD b/xla/hlo/evaluator/BUILD index c622fdc4686c4e..dc7407cf9ce4ce 100644 --- a/xla/hlo/evaluator/BUILD +++ b/xla/hlo/evaluator/BUILD @@ -31,6 +31,7 @@ cc_library( "hlo_evaluator_typed_visitor_float8.cc", "hlo_evaluator_typed_visitor_half.cc", "hlo_evaluator_typed_visitor_int16.cc", + "hlo_evaluator_typed_visitor_int2.cc", "hlo_evaluator_typed_visitor_int32.cc", "hlo_evaluator_typed_visitor_int4.cc", "hlo_evaluator_typed_visitor_int64.cc", diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h b/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h index 78ac5b7e7c0947..f9f460bfae8df3 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h @@ -1710,11 +1710,13 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { // instantiating it. We explicitly instantiate this class in the various // hlo_evaluator_typed_visitor*.cc files. extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int2.cc b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int2.cc new file mode 100644 index 00000000000000..24ba8714f5922a --- /dev/null +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int2.cc @@ -0,0 +1,23 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/evaluator/hlo_evaluator.h" +#include "xla/hlo/evaluator/hlo_evaluator_typed_visitor.h" +#include "xla/types.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +template class HloEvaluatorTypedVisitor; +} // namespace xla From 0c848bb9ed03d8c2c1c19d5eef851c71ad3b3cfd Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 19 Jul 2024 21:57:01 -0700 Subject: [PATCH 037/376] [XLA] Introduce selective resources in latency hiding scheduler. Introduce a new resource hazard type in latency hiding scheduler that informs the scheduler of async instructions that can only make progress during the execution of a limited set of instructions. PiperOrigin-RevId: 654227259 --- xla/service/BUILD | 1 + xla/service/latency_hiding_scheduler.cc | 45 +++++++ xla/service/latency_hiding_scheduler.h | 32 ++++- xla/service/latency_hiding_scheduler_test.cc | 133 ++++++++++++++++++- 4 files changed, 206 insertions(+), 5 deletions(-) diff --git a/xla/service/BUILD b/xla/service/BUILD index 73db9ba1a81a21..8857ab34bb12c9 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -1464,6 +1464,7 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", diff --git a/xla/service/latency_hiding_scheduler.cc b/xla/service/latency_hiding_scheduler.cc index 8b7b43322e1dc3..c7bdae2f620db4 100644 --- a/xla/service/latency_hiding_scheduler.cc +++ b/xla/service/latency_hiding_scheduler.cc @@ -477,6 +477,15 @@ AsyncTracker::GetReleasedNonextendableResourcesFromVector( return {}; } +bool AsyncTracker::ReleasesSelectiveResource(const HloGraphNode* node) const { + return absl::c_any_of( + node->GetResources(), [&](const ResourcePair& resource) { + return resource.second == ResourceUsageType::kResourceRelease && + GetResourceHazardType(resource.first) == + ResourceHazardType::kSelective; + }); +} + BufferInfoTracker::BufferInfoTracker( const HloModule* module, const HloAliasAnalysis* alias_analysis, const HloCostAnalysis::ShapeSizeFunction& shape_size_bytes) { @@ -1403,6 +1412,33 @@ absl::StatusOr DefaultSchedulerCore::ScheduleNode( sched_state->new_sequence_reversed.push_back( const_cast(&n->GetInstr())); n->SetScheduled(); + + // Remove scheduled node from selective_resources_releasers if it + // was there. + if (sched_state->config.enable_selective_resources && + n->ReleasesSelectiveResource()) { + auto it = std::find(sched_state->selective_resource_releasers.begin(), + sched_state->selective_resource_releasers.end(), n); + // Perform sanity check node was in selective_resources_releasers. + if (it == sched_state->selective_resource_releasers.end()) { + LOG(WARNING) << "Selective resource releasers list does not contain node " + "that releases a selective resource: " + << n->ToString(); + } else { + sched_state->selective_resource_releasers.erase(it); + } + } + + // If scheduled node cannot overlap with nodes that hold selective resources, + // we increment the ready time of all nodes that release a selective resource + // with the cost of the scheduled node. + if (sched_state->config.enable_selective_resources && + !n->GetValuableForSelectiveOverlap()) { + for (HloGraphNode* node : sched_state->selective_resource_releasers) { + node->SetReadyTime(node->GetReadyTime() + n->GetCost()); + } + } + // If this node is an async start/done handle the increase/decrease the number // of outstanding async ops. for (auto& resource : n->GetResources()) { @@ -1541,6 +1577,13 @@ absl::StatusOr DefaultSchedulerCore::ScheduleNode( std::push_heap(sched_state->next_ready_stack.begin(), sched_state->next_ready_stack.end(), ready_time_cmp); } + + // If the node we added to ready set releases a selective resource, add + // it to selective_resources_releasers. + if (sched_state->config.enable_selective_resources && + edge.Target().ReleasesSelectiveResource()) { + sched_state->selective_resource_releasers.push_back(&edge.Target()); + } } ++sched_state->scheduled_count; for (auto& resource : n->GetResources()) { @@ -1619,6 +1662,8 @@ HloScheduleGraph::HloScheduleGraph( new_node_it->second->occupied_shareable_resources_ = async_tracker->GetOccupiedShareableResourcesFromVector( new_node_it->second->GetResources()); + new_node_it->second->releases_selective_resource_ = + async_tracker->ReleasesSelectiveResource(new_node_it->second.get()); // Gather while instructions for subsequent send-done dependency checks. if (instr->opcode() == HloOpcode::kWhile) { while_instrs.push_back(instr); diff --git a/xla/service/latency_hiding_scheduler.h b/xla/service/latency_hiding_scheduler.h index b04c746280833e..76ce8b307f7184 100644 --- a/xla/service/latency_hiding_scheduler.h +++ b/xla/service/latency_hiding_scheduler.h @@ -94,7 +94,10 @@ enum class ResourceHazardType { // past. This hazard type is useful to prevent increasing such ops' overlaps // more than necessary. kNonextendable = 2, - kUnshareable = 3, + // Ops holding this resource can only have their latency/cost covered by + // ops that are valuable for selective overlap. + kSelective = 3, + kUnshareable = 4, }; constexpr int64_t ResourceTypeToIndex(ResourceType resource_type) { @@ -133,6 +136,7 @@ struct SchedulerConfig { bool resource_sharing = false; bool resource_serializing = false; bool depth_based_memory_pressure_reduction = false; + bool enable_selective_resources = false; int64_t rerun = 0; }; @@ -277,6 +281,9 @@ class AsyncTracker { GetReleasedNonextendableResourcesFromVector( const ResourcesVector& resources) const; + // Returns whether the provided node releases a selective resource. + bool ReleasesSelectiveResource(const HloGraphNode* node) const; + inline CanonicalAsyncOp GetCanonicalAsyncOp(const HloInstruction& hlo) const { return get_canonical_async_op_(hlo); } @@ -284,16 +291,16 @@ class AsyncTracker { explicit AsyncTracker( const SchedulerConfig& config, GetCanonicalAsyncOpFunc func = DefaultGetCanonicalAsyncOp) - : config_(config), get_canonical_async_op_(func) {} + : get_canonical_async_op_(std::move(func)), config_(config) {} private: - const SchedulerConfig config_; mutable absl::flat_hash_map> async_in_computation_cache_; GetCanonicalAsyncOpFunc get_canonical_async_op_; protected: + const SchedulerConfig config_; mutable absl::flat_hash_map resources_cache_; }; @@ -370,6 +377,16 @@ class HloGraphNode { void SetForceDelay(bool force_delay) { force_delay_ = force_delay; } bool GetForceEarly() const { return force_early_; } void SetForceEarly(bool force_early) { force_early_ = force_early; } + bool GetValuableForSelectiveOverlap() const { + return valuable_for_selective_overlap_; + } + void SetValuableForSelectiveOverlap(bool valuable_for_selective_overlap) { + valuable_for_selective_overlap_ = valuable_for_selective_overlap; + } + bool ReleasesSelectiveResource() const { + return releases_selective_resource_; + } + ResourcesVector GetResources() const { return resources_; } bool DoesOccupyAnyResource() const { return absl::c_any_of(resources_, [](const ResourcePair& resource) { @@ -503,6 +520,11 @@ class HloGraphNode { absl::InlinedVector released_shareable_resources_; // Shareable resources occupied by this node. absl::InlinedVector occupied_shareable_resources_; + // Whether this node can be overlapped with (can cover the latency/cost of) + // edges occupying selective resources. + bool valuable_for_selective_overlap_ = true; + // Whether this node releases a selective resource. + bool releases_selective_resource_ = false; }; // Schedule graph that can be used to drive scheduling @@ -833,6 +855,9 @@ class DefaultSchedulerCore : public SchedulerCore { absl::flat_hash_map< int64_t, std::vector>> shareable_resource_occupiers; + // List of the graph nodes that release selective resources. + std::vector selective_resource_releasers; + // Reference to this scheduler run configuration. const SchedulerConfig& config; SchedulingState(const HloInstructionSequence* instr_sequence, @@ -895,6 +920,7 @@ class DefaultSchedulerCore : public SchedulerCore { virtual absl::StatusOr FindAndExtractBestNodeAvailable( SchedulingState& sched_state, DefaultSchedulerCore::ShouldSkipNodeFunction should_skip_node); + bool DoesNodeReleaseSelectiveResource(const HloGraphNode* node) const; void DumpLatencyHidingSchedule( const HloComputation* computation, const HloScheduleGraph& schedule_graph, const std::vector& instructions, diff --git a/xla/service/latency_hiding_scheduler_test.cc b/xla/service/latency_hiding_scheduler_test.cc index f73f6470519510..60b00e642eb10a 100644 --- a/xla/service/latency_hiding_scheduler_test.cc +++ b/xla/service/latency_hiding_scheduler_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/statusor.h" @@ -141,7 +142,8 @@ class TestLatencyEstimator : public LatencyEstimator { absl::StatusOr RunScheduler( HloModule* module, SchedulerConfig sched_config = GetDefaultSchedConfig(), std::unique_ptr latency_estimator = - std::make_unique()) { + std::make_unique(), + std::unique_ptr async_tracker = nullptr) { AsyncCollectiveCreator::CollectiveCreatorConfig config{ /*convert_all_reduce=*/HloPredicateTrue, /*convert_all_gather=*/HloPredicateTrue, @@ -160,7 +162,9 @@ absl::StatusOr RunScheduler( } return ShapeUtil::ByteSizeOfElements(shape); }; - auto async_tracker = std::make_unique(sched_config); + if (!async_tracker) { + async_tracker = std::make_unique(sched_config); + } auto scheduler_core = std::make_unique( shape_size_bytes, async_tracker.get(), latency_estimator.get(), sched_config); @@ -3110,4 +3114,129 @@ ENTRY %entry { GetIndex(new_instruction_sequence, "while")); } +// This test simulates a sample target where all-gathers contain non-extendable +// and selective resources. +TEST_F(LatencyHidingSchedulerTest, AllGatherWithSelectiveOverlap) { + absl::string_view hlo_string = R"( +HloModule module, is_scheduled=true + +ENTRY %module { + %constant.19 = u32[] constant(0) + %replica_id = u32[]{:T(128)} replica-id() + %convert = f32[]{:T(128)} convert(u32[]{:T(128)} %replica_id) + %color_operand.1 = f32[8,256,256]{2,1,0:T(8,128)} broadcast( + f32[]{:T(128)} %convert), dimensions={} + %ag-start = (f32[8,256,256], f32[16,256,256]) all-gather-start( + f32[8,256,256] %color_operand.1), replica_groups={{0,1}}, dimensions={0}, + metadata={op_type="AllGather" op_name="ag0"} + %ag-done = f32[16,256,256] all-gather-done( + (f32[8,256,256], f32[16,256,256]) %ag-start), + metadata={op_type="AllGather" op_name="ag0"} + p0 = f32[16,64,256]{2,1,0} parameter(0) + p1 = f32[16,64,256]{2,1,0} parameter(1) + p2 = f32[16,256,256]{2,1,0} parameter(2) + p3 = f32[16,256,256]{2,1,0} parameter(3) + c0 = f32[16,256,256]{2,1,0} convolution(p0, p1), + window={size=16 stride=15 lhs_dilate=16}, dim_labels=0fb_0io->0fb + c1 = f32[16,256,256]{2,1,0} convolution(p0, p1), + window={size=16 stride=15 lhs_dilate=16}, dim_labels=0fb_0io->0fb + c2 = f32[16,256,256]{2,1,0} convolution(p0, p1), + window={size=16 stride=15 lhs_dilate=16}, dim_labels=0fb_0io->0fb + ROOT a2 = f32[16,256,256]{2,1,0} add(%ag-done, c0) +} +)"; + + // Extend AsyncTracker for a fake target where all-gather contains + // non-extendable and selective resources. + class SelectiveOverlapAsyncTracker : public AsyncTracker { + public: + explicit SelectiveOverlapAsyncTracker(const SchedulerConfig& sched_config) + : AsyncTracker(sched_config) {} + + ResourceHazardType GetResourceHazardType( + int64_t resource_type) const override { + if (resource_type == ResourceTypeToIndex(ResourceType::kAllGather)) { + return ResourceHazardType::kSelective; + } + // The first target defined resource is defined as non-extendable. + if (resource_type == AsyncTracker::GetFirstTargetDefinedResource()) { + return ResourceHazardType::kNonextendable; + } + return AsyncTracker::GetResourceHazardType(resource_type); + } + + ResourcesVector GetResourcesFromInstruction( + const HloInstruction& hlo) const override { + ResourcesVector result = AsyncTracker::GetResourcesFromInstruction(hlo); + // There is only one target defined resource (which is non-extendable). + if (hlo.opcode() == HloOpcode::kAllGatherStart) { + result.push_back({AsyncTracker::GetFirstTargetDefinedResource(), + ResourceUsageType::kResourceRelease}); + } + return result; + } + + absl::InlinedVector GetReleasedNonextendableResourcesFromVector( + const ResourcesVector& resources) const override { + absl::InlinedVector non_extendable_resources; + for (const ResourcePair& resource : resources) { + if (GetResourceHazardType(resource.first) == + ResourceHazardType::kNonextendable) { + non_extendable_resources.push_back({resource.first}); + } + } + return non_extendable_resources; + } + + void PostProcessScheduleGraph( + HloScheduleGraph* schedule_graph, + const LatencyEstimator* latency_estimator) const override { + // Mark c2 as not valuable for selective overlap. + for (const HloInstruction* instr : + schedule_graph->GetOriginalInstrList()) { + if (instr->name() == "c2") { + schedule_graph->GetNode(instr).SetValuableForSelectiveOverlap(false); + } + } + } + }; + SchedulerConfig sched_config = GetDefaultSchedConfig(); + sched_config.enable_selective_resources = true; + std::unique_ptr async_tracker = + std::make_unique(sched_config); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloText(hlo_string)); + HloSchedule& module_schedule = hlo_module->schedule(); + EXPECT_TRUE(hlo_module->has_entry_computation()); + HloComputation* entry_computation = hlo_module->entry_computation(); + std::vector original_instruction_sequence = + module_schedule.sequence(entry_computation).instructions(); + + EXPECT_TRUE(RunScheduler(hlo_module.get(), sched_config, + std::make_unique(), + std::move(async_tracker)) + .ok()); + std::vector new_instruction_sequence = + module_schedule.sequence(entry_computation).instructions(); + + // Without selective async tracker, we would expect all-gather to only overlap + // with c2 as c2 has a cost of 5000 which fully covers the latency of + // all-gather. However, as c2 is not valuable for selective overlap, we expect + // all-gather overlap with c1 and c2 (c2 is effectively ignored from a latency + // hiding perspective). + if (VLOG_IS_ON(1)) { + for (auto* new_i : new_instruction_sequence) { + VLOG(1) << new_i->ToString(); + } + } + int c0_index = GetIndex(new_instruction_sequence, "c0"); + int c1_index = GetIndex(new_instruction_sequence, "c1"); + int c2_index = GetIndex(new_instruction_sequence, "c2"); + int ag_start_index = GetIndex(new_instruction_sequence, "ag-start"); + int ag_done_index = GetIndex(new_instruction_sequence, "ag-done"); + EXPECT_LT(c0_index, ag_start_index); + EXPECT_LT(ag_start_index, c1_index); + EXPECT_LT(c1_index, c2_index); + EXPECT_LT(c2_index, ag_done_index); +} + } // namespace xla From d669e2bcd6015b37ec158dfd14beb06ecbb3f19e Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Sat, 20 Jul 2024 02:44:41 -0700 Subject: [PATCH 038/376] [XLA:GPU] Add A100-80 autotuning entries for the gpu_compiler_test. PiperOrigin-RevId: 654271449 --- .../gpu_compiler_test_autotune_db.textproto | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/xla/service/gpu/gpu_compiler_test_autotune_db.textproto b/xla/service/gpu/gpu_compiler_test_autotune_db.textproto index ecdc8e089ca80a..3549c95f7fdccd 100644 --- a/xla/service/gpu/gpu_compiler_test_autotune_db.textproto +++ b/xla/service/gpu/gpu_compiler_test_autotune_db.textproto @@ -23,6 +23,30 @@ results { } } } +results { + device: "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: 2039 GB/s, L2 cache: 40 MB" + hlo: "{\n tmp_0 = bf16[1,4,32,1024,1024]{4,3,2,1,0} parameter(0)\n tmp_1 = f32[1,4,32,1024,1024]{4,3,2,1,0} convert(bf16[1,4,32,1024,1024]{4,3,2,1,0} tmp_0)\n tmp_2 = bf16[] constant({...})\n tmp_3 = bf16[1,4,32,1024,1024]{4,3,2,1,0} broadcast(bf16[] tmp_2), dimensions={}\n tmp_4 = f32[1,4,32,1024,1024]{4,3,2,1,0} convert(bf16[1,4,32,1024,1024]{4,3,2,1,0} tmp_3)\n tmp_5 = f32[1,4,32,1024,1024]{4,3,2,1,0} multiply(f32[1,4,32,1024,1024]{4,3,2,1,0} tmp_1, f32[1,4,32,1024,1024]{4,3,2,1,0} tmp_4)\n tmp_6 = bf16[1,4,32,1024,1024]{4,3,2,1,0} convert(f32[1,4,32,1024,1024]{4,3,2,1,0} tmp_5)\n tmp_7 = bf16[4,32,1024,1024]{3,2,1,0} bitcast(bf16[1,4,32,1024,1024]{4,3,2,1,0} tmp_6)\n tmp_8 = bf16[4,32,1024,1024]{3,2,1,0} transpose(bf16[4,32,1024,1024]{3,2,1,0} tmp_7), dimensions={0,1,3,2}\n tmp_9 = bf16[128,1024,1024]{2,1,0} bitcast(bf16[4,32,1024,1024]{3,2,1,0} tmp_8)\n tmp_10 = bf16[1,4,32,1024,1024]{4,3,2,1,0} parameter(1)\n tmp_11 = bf16[128,1024,1024]{2,1,0} bitcast(bf16[1,4,32,1024,1024]{4,3,2,1,0} tmp_10)\n tmp_12 = bf16[128,1024,1024]{2,1,0} dot(bf16[128,1024,1024]{2,1,0} tmp_9, bf16[128,1024,1024]{2,1,0} tmp_11), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}\n ROOT tmp_13 = bf16[4,32,1024,1024]{3,2,1,0} bitcast(bf16[128,1024,1024]{2,1,0} tmp_12)\n}" + result { + gemm { + algorithm: -1 + } + run_time { + nanos: 1 + } + } +} +results { + device: "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: 2039 GB/s, L2 cache: 40 MB" + hlo: "(bf16[128,1024,1024]{2,1,0}, s8[4194304]{0}) custom-call(bf16[128,1024,1024]{2,1,0}, bf16[128,1024,1024]{2,1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[],\"gemm_backend_config\":{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"2\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[\"0\"]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"],\"algorithm\":\"ALG_UNSET\"},\"epilogue\":\"DEFAULT\",\"lhs_stride\":\"1048576\",\"rhs_stride\":\"1048576\",\"grad_x\":false,\"grad_y\":false,\"damax_output\":false},\"force_earliest_schedule\":false}" + result { + run_time { + nanos: 1 + } + gemm { + algorithm: -1 + } + } +} results { device: "CUDA: 9.0, Cores: 132, GPU clock: 1.98 GHz, Memory bandwidth: 3352 GB/s, L2 cache: 50 MB" hlo: "{\n tmp_0 = bf16[1,4,32,1024,1024]{4,3,2,1,0} parameter(0)\n tmp_1 = bf16[] constant({...})\n tmp_2 = bf16[1,4,32,1024,1024]{4,3,2,1,0} broadcast(bf16[] tmp_1), dimensions={}\n tmp_3 = bf16[1,4,32,1024,1024]{4,3,2,1,0} multiply(bf16[1,4,32,1024,1024]{4,3,2,1,0} tmp_0, bf16[1,4,32,1024,1024]{4,3,2,1,0} tmp_2)\n tmp_4 = bf16[4,32,1024,1024]{3,2,1,0} bitcast(bf16[1,4,32,1024,1024]{4,3,2,1,0} tmp_3)\n tmp_5 = bf16[4,32,1024,1024]{3,2,1,0} transpose(bf16[4,32,1024,1024]{3,2,1,0} tmp_4), dimensions={0,1,3,2}\n tmp_6 = bf16[128,1024,1024]{2,1,0} bitcast(bf16[4,32,1024,1024]{3,2,1,0} tmp_5)\n tmp_7 = bf16[1,4,32,1024,1024]{4,3,2,1,0} parameter(1)\n tmp_8 = bf16[128,1024,1024]{2,1,0} bitcast(bf16[1,4,32,1024,1024]{4,3,2,1,0} tmp_7)\n tmp_9 = bf16[128,1024,1024]{2,1,0} dot(bf16[128,1024,1024]{2,1,0} tmp_6, bf16[128,1024,1024]{2,1,0} tmp_8), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}\n ROOT tmp_10 = bf16[4,32,1024,1024]{3,2,1,0} bitcast(bf16[128,1024,1024]{2,1,0} tmp_9)\n}" From a01a3db2e73f1477ad1d7c4d9c92863c52538ef5 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Sat, 20 Jul 2024 10:14:48 -0700 Subject: [PATCH 039/376] Use new Stream::CreateEventBasedTimer interface to create GpuTimers in stream_executor directories. PiperOrigin-RevId: 654326402 --- xla/stream_executor/cuda/BUILD | 4 +- xla/stream_executor/cuda/cuda_blas.cc | 32 ++++++----- xla/stream_executor/cuda/cuda_blas_lt.cc | 9 ++-- xla/stream_executor/cuda/cuda_dnn.cc | 68 +++++++++++------------- xla/stream_executor/rocm/BUILD | 6 +-- xla/stream_executor/rocm/hip_blas_lt.cc | 9 ++-- xla/stream_executor/rocm/rocm_blas.cc | 20 +++---- xla/stream_executor/rocm/rocm_dnn.cc | 42 +++++++-------- 8 files changed, 89 insertions(+), 101 deletions(-) diff --git a/xla/stream_executor/cuda/BUILD b/xla/stream_executor/cuda/BUILD index 03b5db71e0deeb..161c13f2641bb9 100644 --- a/xla/stream_executor/cuda/BUILD +++ b/xla/stream_executor/cuda/BUILD @@ -292,6 +292,7 @@ cuda_only_cc_library( "//xla/stream_executor", "//xla/stream_executor:blas", "//xla/stream_executor:device_memory", + "//xla/stream_executor:event_based_timer", "//xla/stream_executor:host_or_device_scalar", "//xla/stream_executor:numeric_options", "//xla/stream_executor:plugin_registry", @@ -301,7 +302,6 @@ cuda_only_cc_library( "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_helpers_header", "//xla/stream_executor/gpu:gpu_stream_header", - "//xla/stream_executor/gpu:gpu_timer", "//xla/stream_executor/gpu:gpu_types_header", "//xla/stream_executor/platform", "//xla/tsl/cuda:cublas", @@ -409,6 +409,7 @@ cuda_only_cc_library( "//xla/stream_executor", "//xla/stream_executor:data_type", "//xla/stream_executor:dnn", + "//xla/stream_executor:event_based_timer", "//xla/stream_executor:numeric_options", "//xla/stream_executor:plugin_registry", "//xla/stream_executor:scratch_allocator", @@ -418,7 +419,6 @@ cuda_only_cc_library( "//xla/stream_executor/gpu:gpu_driver_header", "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_stream", - "//xla/stream_executor/gpu:gpu_timer", "//xla/stream_executor/platform", "//xla/tsl/cuda:cudnn", "//xla/tsl/util:env_var", diff --git a/xla/stream_executor/cuda/cuda_blas.cc b/xla/stream_executor/cuda/cuda_blas.cc index 79395f1858a205..c516829cbfff40 100644 --- a/xla/stream_executor/cuda/cuda_blas.cc +++ b/xla/stream_executor/cuda/cuda_blas.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include #include -#include +#include #include #include @@ -43,11 +43,11 @@ limitations under the License. #include "xla/stream_executor/cuda/cuda_platform_id.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/gpu/gpu_activation.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_helpers.h" #include "xla/stream_executor/gpu/gpu_stream.h" -#include "xla/stream_executor/gpu/gpu_timer.h" #include "xla/stream_executor/numeric_options.h" #include "xla/stream_executor/platform/initialize.h" #include "xla/stream_executor/platform/port.h" @@ -63,11 +63,9 @@ limitations under the License. namespace stream_executor { namespace cuda { -using gpu::AsGpuStream; using gpu::AsGpuStreamValue; using gpu::GpuMemory; using gpu::GpuMemoryMutable; -using gpu::GpuTimer; // cuBLAS has interfaces that permit pointers to be passed from either the host // memory space or the device memory space; however, you must instruct it as to @@ -694,7 +692,7 @@ static absl::StatusOr GetMathTypeForGemmEx( } static absl::Status PopulateProfileFromTimer( - std::optional &timer, blas::AlgorithmType algorithm, + EventBasedTimer *timer, blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) { if (output_profile_result) { TF_ASSIGN_OR_RETURN(absl::Duration duration, timer->GetElapsedDuration()); @@ -718,11 +716,11 @@ absl::Status CUDABlas::DoBlasGemmWithAlgorithm( cublasMath_t math_type, GetMathTypeForGemmEx(stream, algorithm, type_a, type_b, numeric_options)); - std::optional timer = std::nullopt; + std::unique_ptr timer; if (output_profile_result != nullptr) { - TF_ASSIGN_OR_RETURN( - timer, - GpuTimer::Create(stream, output_profile_result->warmup_run_executed())); + TF_ASSIGN_OR_RETURN(timer, + stream->CreateEventBasedTimer( + output_profile_result->warmup_run_executed())); } // Since we are converting 'algorithm' to cublasGemmAlgo_t by static_cast, @@ -737,7 +735,7 @@ absl::Status CUDABlas::DoBlasGemmWithAlgorithm( ldc, AsCublasComputeType(computation_type), static_cast(algorithm))); TF_RETURN_IF_ERROR( - PopulateProfileFromTimer(timer, algorithm, output_profile_result)); + PopulateProfileFromTimer(timer.get(), algorithm, output_profile_result)); return absl::OkStatus(); } @@ -753,11 +751,11 @@ absl::Status CUDABlas::DoBlasGemmStridedBatchedWithAlgorithm( TF_ASSIGN_OR_RETURN( cublasMath_t math_type, GetMathTypeForGemmEx(stream, algorithm, type_a, type_b, numeric_options)); - std::optional timer = std::nullopt; + std::unique_ptr timer; if (output_profile_result != nullptr) { - TF_ASSIGN_OR_RETURN( - timer, - GpuTimer::Create(stream, output_profile_result->warmup_run_executed())); + TF_ASSIGN_OR_RETURN(timer, + stream->CreateEventBasedTimer( + output_profile_result->warmup_run_executed())); } cudaDataType_t cuda_in_type = AsCudaDataType(type_a); @@ -799,8 +797,8 @@ absl::Status CUDABlas::DoBlasGemmStridedBatchedWithAlgorithm( blas::DataTypeString(type_a), blas::DataTypeString(type_c))); } } - TF_RETURN_IF_ERROR( - PopulateProfileFromTimer(timer, algorithm, output_profile_result)); + TF_RETURN_IF_ERROR(PopulateProfileFromTimer(timer.get(), algorithm, + output_profile_result)); return absl::OkStatus(); } #endif @@ -813,7 +811,7 @@ absl::Status CUDABlas::DoBlasGemmStridedBatchedWithAlgorithm( batch_count, AsCublasComputeType(computation_type), static_cast(algorithm))); TF_RETURN_IF_ERROR( - PopulateProfileFromTimer(timer, algorithm, output_profile_result)); + PopulateProfileFromTimer(timer.get(), algorithm, output_profile_result)); return absl::OkStatus(); } diff --git a/xla/stream_executor/cuda/cuda_blas_lt.cc b/xla/stream_executor/cuda/cuda_blas_lt.cc index b3099604996c92..a4337dfe60e497 100644 --- a/xla/stream_executor/cuda/cuda_blas_lt.cc +++ b/xla/stream_executor/cuda/cuda_blas_lt.cc @@ -41,11 +41,11 @@ limitations under the License. #include "xla/stream_executor/blas.h" #include "xla/stream_executor/cuda/cuda_blas.h" #include "xla/stream_executor/cuda/cuda_blas_utils.h" +#include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/gpu/gpu_activation.h" #include "xla/stream_executor/gpu/gpu_blas_lt.h" #include "xla/stream_executor/gpu/gpu_helpers.h" #include "xla/stream_executor/gpu/gpu_stream.h" -#include "xla/stream_executor/gpu/gpu_timer.h" #include "xla/stream_executor/stream.h" #include "xla/types.h" #include "xla/util.h" @@ -406,11 +406,10 @@ absl::Status BlasLt::MatmulPlan::DoMatmul( std::optional workspace, std::optional scratch_allocator, blas::ProfileResult* profile_result = nullptr) const { - std::optional timer = std::nullopt; + std::unique_ptr timer; if (profile_result != nullptr) { - TF_ASSIGN_OR_RETURN( - timer, - gpu::GpuTimer::Create(stream, profile_result->warmup_run_executed())); + TF_ASSIGN_OR_RETURN(timer, stream->CreateEventBasedTimer( + profile_result->warmup_run_executed())); } void* workspace_addr; diff --git a/xla/stream_executor/cuda/cuda_dnn.cc b/xla/stream_executor/cuda/cuda_dnn.cc index 2cb21cc1aab572..b7dd5fdd4085c4 100644 --- a/xla/stream_executor/cuda/cuda_dnn.cc +++ b/xla/stream_executor/cuda/cuda_dnn.cc @@ -56,12 +56,12 @@ limitations under the License. #include "xla/stream_executor/data_type.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/dnn.h" +#include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/gpu/gpu_activation.h" #include "xla/stream_executor/gpu/gpu_diagnostics.h" #include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_stream.h" -#include "xla/stream_executor/gpu/gpu_timer.h" #include "xla/stream_executor/numeric_options.h" #include "xla/stream_executor/platform/initialize.h" #include "xla/stream_executor/plugin_registry.h" @@ -2255,7 +2255,7 @@ absl::StatusOr> CreateBatchNormBackwardWorkspace( // Populates the profile result if not empty. static absl::Status PopulateProfileFromTimer( - std::optional& timer, const dnn::AlgorithmDesc& algorithm, + EventBasedTimer* timer, const dnn::AlgorithmDesc& algorithm, dnn::ProfileResult* profile_result, std::optional scratch_size = std::nullopt) { if (profile_result) { @@ -2306,11 +2306,11 @@ absl::Status CudnnSupport::DoRnnForwardImpl( stream, cudnn, rnn_desc, model_dims, input_desc, workspace_allocator, reserve_space_allocator, is_training, &workspace, &reserve_space)); - std::optional timer = std::nullopt; + std::unique_ptr timer; if (output_profile_result != nullptr) { - TF_ASSIGN_OR_RETURN( - timer, - GpuTimer::Create(stream, output_profile_result->warmup_run_executed())); + TF_ASSIGN_OR_RETURN(timer, + stream->CreateEventBasedTimer( + output_profile_result->warmup_run_executed())); } if (input_desc.is_var_seq_lengths()) { @@ -2408,9 +2408,9 @@ absl::Status CudnnSupport::DoRnnForwardImpl( #endif // CUDNN_VERSION >= 90000 } - if (timer.has_value()) { + if (timer != nullptr) { TF_RETURN_IF_ERROR(PopulateProfileFromTimer( - timer, *rnn_desc.algorithm_config().algorithm(), + timer.get(), *rnn_desc.algorithm_config().algorithm(), output_profile_result)); } @@ -2459,11 +2459,11 @@ absl::Status CudnnSupport::DoRnnBackwardImpl( input_desc, workspace_allocator, nullptr, true, &workspace, nullptr)); - std::optional timer; + std::unique_ptr timer; if (output_profile_result != nullptr) { - TF_ASSIGN_OR_RETURN( - timer, - GpuTimer::Create(stream, output_profile_result->warmup_run_executed())); + TF_ASSIGN_OR_RETURN(timer, + stream->CreateEventBasedTimer( + output_profile_result->warmup_run_executed())); } if (input_desc.is_var_seq_lengths()) { @@ -2602,9 +2602,9 @@ absl::Status CudnnSupport::DoRnnBackwardImpl( #endif // CUDNN_VERSION >= 90000 } - if (timer.has_value()) { + if (timer != nullptr) { TF_RETURN_IF_ERROR(PopulateProfileFromTimer( - timer, *rnn_desc.algorithm_config().algorithm(), + timer.get(), *rnn_desc.algorithm_config().algorithm(), output_profile_result)); } @@ -5587,12 +5587,10 @@ class CudnnLegacyConvRunner : public dnn::ConvRunner { ? static_cast(&dbeta) : static_cast(&fbeta); - std::optional timer = std::nullopt; - + std::unique_ptr timer; if (profile_result != nullptr) { - TF_ASSIGN_OR_RETURN( - timer, - GpuTimer::Create(stream, profile_result->warmup_run_executed())); + TF_ASSIGN_OR_RETURN(timer, stream->CreateEventBasedTimer( + profile_result->warmup_run_executed())); } const auto get_fwd_bugs = [&]() -> absl::Status { @@ -5670,9 +5668,9 @@ class CudnnLegacyConvRunner : public dnn::ConvRunner { static_cast(kind_)); } - if (timer.has_value()) { - TF_RETURN_IF_ERROR(PopulateProfileFromTimer(timer, algo, profile_result, - scratch_memory.size())); + if (timer != nullptr) { + TF_RETURN_IF_ERROR(PopulateProfileFromTimer( + timer.get(), algo, profile_result, scratch_memory.size())); } return absl::OkStatus(); @@ -6034,21 +6032,20 @@ class CudnnExecutionPlanRunner << "\nWorkspace size in bytes: " << workspace_size << "\nVariantPack: " << variantPack.describe(); - std::optional timer = std::nullopt; + std::unique_ptr timer; if (profile_result != nullptr) { - TF_ASSIGN_OR_RETURN( - timer, - GpuTimer::Create(stream, profile_result->warmup_run_executed())); + TF_ASSIGN_OR_RETURN(timer, stream->CreateEventBasedTimer( + profile_result->warmup_run_executed())); } cudnnStatus_t status = cudnnBackendExecute( cudnn.handle(), plan_.get_raw_desc(), variantPack.get_raw_desc()); RETURN_IF_CUDNN_ERROR(status); - if (timer.has_value()) { + if (timer != nullptr) { TF_ASSIGN_OR_RETURN(auto desc, ToAlgorithmDesc()); - TF_RETURN_IF_ERROR(PopulateProfileFromTimer(timer, desc, profile_result, - scratch_memory.size())); + TF_RETURN_IF_ERROR(PopulateProfileFromTimer( + timer.get(), desc, profile_result, scratch_memory.size())); VLOG(4) << "cudnn op with plan " << plan_.getTag() << ", workspace_size=" << workspace_size << " -> " @@ -6615,12 +6612,11 @@ class CudnnLegacyFusedConvRunner : public dnn::FusedConvRunner { } auto algo = MakeAlgorithmDesc(); - std::optional timer = std::nullopt; + std::unique_ptr timer; if (profile_result != nullptr) { - TF_ASSIGN_OR_RETURN( - timer, - GpuTimer::Create(stream, profile_result->warmup_run_executed())); + TF_ASSIGN_OR_RETURN(timer, stream->CreateEventBasedTimer( + profile_result->warmup_run_executed())); } auto side_input_data_ptr = (side_input_scale_ == 0) ? output_data.opaque() @@ -6675,9 +6671,9 @@ class CudnnLegacyFusedConvRunner : public dnn::FusedConvRunner { } RETURN_IF_CUDNN_ERROR(status); - if (timer.has_value()) { - TF_RETURN_IF_ERROR(PopulateProfileFromTimer(timer, algo, profile_result, - scratch_memory.size())); + if (timer != nullptr) { + TF_RETURN_IF_ERROR(PopulateProfileFromTimer( + timer.get(), algo, profile_result, scratch_memory.size())); VLOG(4) << "conv with algorithm " << ToConvForwardAlgo(algo) << ", tensor_ops_enabled=" << tensor_ops_enabled_ << ", workspace_size=" << scratch_memory.size() << " -> " diff --git a/xla/stream_executor/rocm/BUILD b/xla/stream_executor/rocm/BUILD index 456b63a07168e3..fc6af2700c1f7f 100644 --- a/xla/stream_executor/rocm/BUILD +++ b/xla/stream_executor/rocm/BUILD @@ -292,6 +292,7 @@ cc_library( ":rocm_platform_id", "//xla/stream_executor", "//xla/stream_executor:blas", + "//xla/stream_executor:event_based_timer", "//xla/stream_executor:host_or_device_scalar", "//xla/stream_executor:plugin_registry", "//xla/stream_executor:stream_executor_h", @@ -300,7 +301,6 @@ cc_library( "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_helpers_header", "//xla/stream_executor/gpu:gpu_stream_header", - "//xla/stream_executor/gpu:gpu_timer", "//xla/stream_executor/platform", "//xla/stream_executor/platform:dso_loader", "//xla/tsl/util:determinism_hdr_lib", @@ -392,12 +392,12 @@ cc_library( "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:dnn", + "//xla/stream_executor:event_based_timer", "//xla/stream_executor:plugin_registry", "//xla/stream_executor/gpu:gpu_activation", "//xla/stream_executor/gpu:gpu_driver_header", "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_stream_header", - "//xla/stream_executor/gpu:gpu_timer", "//xla/stream_executor/gpu:gpu_types_header", "//xla/stream_executor/platform", "//xla/stream_executor/platform:dso_loader", @@ -551,12 +551,12 @@ cc_library( "//xla:util", "//xla/stream_executor", "//xla/stream_executor:blas", + "//xla/stream_executor:event_based_timer", "//xla/stream_executor:host_or_device_scalar", "//xla/stream_executor/gpu:gpu_activation", "//xla/stream_executor/gpu:gpu_blas_lt", "//xla/stream_executor/gpu:gpu_helpers_header", "//xla/stream_executor/gpu:gpu_stream_header", - "//xla/stream_executor/gpu:gpu_timer", "//xla/stream_executor/platform", "//xla/stream_executor/platform:dso_loader", "@com_google_absl//absl/status", diff --git a/xla/stream_executor/rocm/hip_blas_lt.cc b/xla/stream_executor/rocm/hip_blas_lt.cc index a1a010ecf12e7c..ce38712b8b56f5 100644 --- a/xla/stream_executor/rocm/hip_blas_lt.cc +++ b/xla/stream_executor/rocm/hip_blas_lt.cc @@ -24,10 +24,10 @@ limitations under the License. #include "xla/util.h" #if TF_HIPBLASLT +#include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/gpu/gpu_activation.h" #include "xla/stream_executor/gpu/gpu_helpers.h" #include "xla/stream_executor/gpu/gpu_stream.h" -#include "xla/stream_executor/gpu/gpu_timer.h" #include "xla/stream_executor/rocm/hip_blas_lt.h" #include "xla/stream_executor/rocm/rocm_blas.h" #include "xla/stream_executor/scratch_allocator.h" @@ -395,12 +395,11 @@ absl::Status BlasLt::MatmulPlan::DoMatmul( blas_lt_ref_.parent_->RecordApiTrace(StreamExecutor::GemmCallTrace{ StreamExecutor::GemmCallTrace::GemmType::kBlasLt, 0, a.size(), b.size()}); - std::optional timer = std::nullopt; + std::unique_ptr timer; if (profile_result != nullptr) { - TF_ASSIGN_OR_RETURN( - timer, - gpu::GpuTimer::Create(stream, profile_result->warmup_run_executed())); + TF_ASSIGN_OR_RETURN(timer, stream->CreateEventBasedTimer( + profile_result->warmup_run_executed())); } void* workspace_addr = nullptr; diff --git a/xla/stream_executor/rocm/rocm_blas.cc b/xla/stream_executor/rocm/rocm_blas.cc index 4cc979d2f93f9a..7f08f46342b0e4 100644 --- a/xla/stream_executor/rocm/rocm_blas.cc +++ b/xla/stream_executor/rocm/rocm_blas.cc @@ -28,11 +28,11 @@ limitations under the License. #include "unsupported/Eigen/CXX11/Tensor" #include "rocm/rocm_config.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/gpu/gpu_activation.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_helpers.h" #include "xla/stream_executor/gpu/gpu_stream.h" -#include "xla/stream_executor/gpu/gpu_timer.h" #include "xla/stream_executor/platform/dso_loader.h" #include "xla/stream_executor/platform/initialize.h" #include "xla/stream_executor/platform/port.h" @@ -329,7 +329,7 @@ uint32_t GemmFloat16Flags(blas::DataType dtype, blas::CallContext context, } absl::Status PopulateProfileFromTimer( - std::optional &timer, blas::AlgorithmType algorithm, + EventBasedTimer *timer, blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) { if (output_profile_result) { TF_ASSIGN_OR_RETURN(absl::Duration duration, timer->GetElapsedDuration()); @@ -544,10 +544,10 @@ absl::Status ROCMBlas::DoBlasGemmWithAlgorithm( "datatypes for the inputs a (%d) and b (%d) are unsupported", static_cast(type_a), static_cast(type_b))); } - std::optional timer = std::nullopt; + std::unique_ptr timer; if (profile_result != nullptr) { - TF_ASSIGN_OR_RETURN( - timer, GpuTimer::Create(stream, profile_result->warmup_run_executed())); + TF_ASSIGN_OR_RETURN(timer, stream->CreateEventBasedTimer( + profile_result->warmup_run_executed())); } // fall back to the default implementation @@ -585,7 +585,7 @@ absl::Status ROCMBlas::DoBlasGemmWithAlgorithm( algorithm, GemmFloat16Flags(type_a, context, use_hgemm_alt_impl_))); } TF_RETURN_IF_ERROR( - PopulateProfileFromTimer(timer, algorithm, profile_result)); + PopulateProfileFromTimer(timer.get(), algorithm, profile_result)); return absl::OkStatus(); } @@ -605,10 +605,10 @@ absl::Status ROCMBlas::DoBlasGemmStridedBatchedWithAlgorithm( "datatypes for the inputs a (%d) and b (%d) are unsupported", static_cast(type_a), static_cast(type_b))); } - std::optional timer = std::nullopt; + std::unique_ptr timer; if (profile_result != nullptr) { - TF_ASSIGN_OR_RETURN( - timer, GpuTimer::Create(stream, profile_result->warmup_run_executed())); + TF_ASSIGN_OR_RETURN(timer, stream->CreateEventBasedTimer( + profile_result->warmup_run_executed())); } // fall back to the default implementation @@ -648,7 +648,7 @@ absl::Status ROCMBlas::DoBlasGemmStridedBatchedWithAlgorithm( GemmFloat16Flags(type_a, context, use_hgemm_alt_impl_))); } TF_RETURN_IF_ERROR( - PopulateProfileFromTimer(timer, algorithm, profile_result)); + PopulateProfileFromTimer(timer.get(), algorithm, profile_result)); return absl::OkStatus(); } diff --git a/xla/stream_executor/rocm/rocm_dnn.cc b/xla/stream_executor/rocm/rocm_dnn.cc index 5b8f22b32e4fd8..4c80d3142856cd 100644 --- a/xla/stream_executor/rocm/rocm_dnn.cc +++ b/xla/stream_executor/rocm/rocm_dnn.cc @@ -30,11 +30,11 @@ limitations under the License. #include "rocm/include/miopen/miopen.h" #include "rocm/rocm_config.h" #include "xla/stream_executor/dnn.h" +#include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/gpu/gpu_activation.h" #include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_stream.h" -#include "xla/stream_executor/gpu/gpu_timer.h" #include "xla/stream_executor/platform/dso_loader.h" #include "xla/stream_executor/platform/initialize.h" #include "xla/stream_executor/plugin_registry.h" @@ -79,7 +79,7 @@ namespace gpu { // Populates the profile result if not empty. static absl::Status PopulateProfileFromTimer( - std::optional& timer, const dnn::AlgorithmDesc& algorithm, + EventBasedTimer* timer, const dnn::AlgorithmDesc& algorithm, dnn::ProfileResult* profile_result, std::optional scratch_size = std::nullopt) { if (profile_result) { @@ -2494,12 +2494,12 @@ absl::Status MIOpenSupport::DoRnnForwardImpl( } const bool is_profiling = output_profile_result != nullptr; - std::optional timer = std::nullopt; + std::unique_ptr timer; if (is_profiling) { - TF_ASSIGN_OR_RETURN( - timer, - GpuTimer::Create(stream, output_profile_result->warmup_run_executed())); + TF_ASSIGN_OR_RETURN(timer, + stream->CreateEventBasedTimer( + output_profile_result->warmup_run_executed())); } // make the forward call @@ -2544,7 +2544,7 @@ absl::Status MIOpenSupport::DoRnnForwardImpl( if (is_profiling) { TF_RETURN_IF_ERROR(PopulateProfileFromTimer( - timer, *rnn_desc.algorithm_config().algorithm(), + timer.get(), *rnn_desc.algorithm_config().algorithm(), output_profile_result)); } @@ -2626,12 +2626,12 @@ absl::Status MIOpenSupport::DoRnnBackwardImpl( stream->MemZero(input_c_backprop_data, size_data * type_size)); const bool is_profiling = output_profile_result != nullptr; - std::optional timer = std::nullopt; + std::unique_ptr timer; if (is_profiling) { - TF_ASSIGN_OR_RETURN( - timer, - GpuTimer::Create(stream, output_profile_result->warmup_run_executed())); + TF_ASSIGN_OR_RETURN(timer, + stream->CreateEventBasedTimer( + output_profile_result->warmup_run_executed())); } // make the backward data call @@ -2683,7 +2683,7 @@ absl::Status MIOpenSupport::DoRnnBackwardImpl( if (is_profiling) { TF_RETURN_IF_ERROR(PopulateProfileFromTimer( - timer, *rnn_desc.algorithm_config().algorithm(), + timer.get(), *rnn_desc.algorithm_config().algorithm(), output_profile_result)); } @@ -3326,11 +3326,11 @@ class RocmConvRunner : public dnn::ConvRunner { float beta = 0.0; const bool is_profiling = output_profile_result != nullptr; - std::optional timer = std::nullopt; + std::unique_ptr timer; if (is_profiling) { - TF_ASSIGN_OR_RETURN( - timer, GpuTimer::Create( - stream, output_profile_result->warmup_run_executed())); + TF_ASSIGN_OR_RETURN(timer, + stream->CreateEventBasedTimer( + output_profile_result->warmup_run_executed())); } miopenStatus_t status = miopenStatusSuccess; @@ -4850,14 +4850,10 @@ class RocmFusedConvRunner : public dnn::FusedConvRunner { if (activation_desc_.miopen_activation_mode_ != miopenActivationPASTHRU) fusion_plan_.SetActivationForwardArgs(activation_desc_); - std::optional timer; + std::unique_ptr timer; if (profile_result) { - auto timer_or_status = GpuTimer::Create(AsGpuStream(stream)); - if (!timer_or_status.ok()) { - LOG(ERROR) << "Failed to create timer"; - return absl::InternalError("Failed to start timer"); - } - timer.emplace(std::move(*timer_or_status)); + TF_ASSIGN_OR_RETURN(timer, stream->CreateEventBasedTimer( + /* use_delay_kernel=*/false)); } miopenStatus_t status; From 8df0da6a60da640cadb731b41db6b32e3b453656 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Sat, 20 Jul 2024 15:26:17 -0700 Subject: [PATCH 040/376] Use Stream::CreateEventBasedTimer instead of old GpuTimer::Create function. PiperOrigin-RevId: 654363342 --- xla/service/gpu/BUILD | 4 +- xla/service/gpu/gpu_executable.cc | 49 +++++++-------------- xla/service/gpu/kernels/BUILD | 2 +- xla/service/gpu/kernels/topk_kernel_test.cc | 8 ++-- 4 files changed, 22 insertions(+), 41 deletions(-) diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 404386ce015a55..e25d816e3416e1 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -717,13 +717,13 @@ cc_library( "//xla/stream_executor:device_description", "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:event_based_timer", "//xla/stream_executor:module_spec", "//xla/stream_executor:scoped_module_handle", "//xla/stream_executor/cuda:cuda_platform_id", "//xla/stream_executor/gpu:gpu_activation", "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_stream_header", - "//xla/stream_executor/gpu:gpu_timer", "//xla/stream_executor/rocm:rocm_platform_id", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -5673,9 +5673,9 @@ xla_test( "//xla/hlo/ir:hlo", "//xla/service/gpu/tests:gpu_codegen_test", "//xla/stream_executor:device_description", - "//xla/stream_executor/gpu:gpu_timer", "//xla/tests:hlo_test_base", "//xla/tests:literal_test_util", + "//xla/stream_executor/gpu:gpu_timer", "//xla/tests:test_utils", "@tsl//tsl/platform:statusor", ], diff --git a/xla/service/gpu/gpu_executable.cc b/xla/service/gpu/gpu_executable.cc index 8b75187d214c9d..bf9774711fcfd6 100644 --- a/xla/service/gpu/gpu_executable.cc +++ b/xla/service/gpu/gpu_executable.cc @@ -76,6 +76,11 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/event_based_timer.h" +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#include "xla/stream_executor/gpu/gpu_activation.h" +#include "xla/stream_executor/gpu/gpu_executor.h" +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "xla/stream_executor/module_spec.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" @@ -85,26 +90,12 @@ limitations under the License. #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" +#include "tsl/platform/random.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" #include "tsl/profiler/lib/scoped_annotation.h" #include "tsl/profiler/lib/traceme.h" -#if TENSORFLOW_USE_ROCM -#include "tsl/platform/random.h" -#endif - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "xla/stream_executor/gpu/gpu_activation.h" -#include "xla/stream_executor/gpu/gpu_executor.h" -#include "xla/stream_executor/gpu/gpu_stream.h" -#include "xla/stream_executor/gpu/gpu_timer.h" -#else -namespace stream_executor::gpu { -class GpuTimer {}; -} // namespace stream_executor::gpu -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - namespace xla { namespace gpu { @@ -355,10 +346,9 @@ class ResourceRequests : public Thunk::ResourceRequests { absl::flat_hash_map cliques_; }; -absl::Status MaybeSyncAndProfile( - const ServiceExecutableRunOptions* run_options, - std::optional execution_timer, - se::Stream* stream_to_sync); +absl::Status MaybeSyncAndProfile(const ServiceExecutableRunOptions* run_options, + se::EventBasedTimer* execution_timer, + se::Stream* stream_to_sync); absl::Status RendezvousAfterInitialization( const ServiceExecutableRunOptions* run_options); @@ -436,16 +426,13 @@ absl::Status ExecuteThunks( [&] { return absl::StrCat(module_name, ":XLA GPU module"); }, tsl::profiler::TraceMeLevel::kInfo); - std::optional execution_timer; -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + std::unique_ptr execution_timer; if (ExecutionProfile* profile = run_options->run_options().execution_profile(); profile) { - TF_ASSIGN_OR_RETURN( - execution_timer, - se::gpu::GpuTimer::Create(main_stream, profile->warmup_run_executed())); + TF_ASSIGN_OR_RETURN(execution_timer, main_stream->CreateEventBasedTimer( + profile->warmup_run_executed())); } -#endif // Parameters for executing collective operations. TF_ASSIGN_OR_RETURN(Thunk::CollectiveExecuteParams collective_params, @@ -502,7 +489,7 @@ absl::Status ExecuteThunks( TF_RETURN_IF_ERROR(thunk_sequence.ExecuteOnStream(execute_params)); - return MaybeSyncAndProfile(run_options, std::move(execution_timer), + return MaybeSyncAndProfile(run_options, execution_timer.get(), block_host_until_done ? main_stream : nullptr); } @@ -582,23 +569,19 @@ absl::Status RendezvousAfterInitialization( return absl::OkStatus(); } -absl::Status MaybeSyncAndProfile( - const ServiceExecutableRunOptions* run_options, - std::optional execution_timer, - se::Stream* stream_to_sync = nullptr) { -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +absl::Status MaybeSyncAndProfile(const ServiceExecutableRunOptions* run_options, + se::EventBasedTimer* execution_timer, + se::Stream* stream_to_sync = nullptr) { // If we're measuring the execution time then it's important to queue the // stop event before triggering any synchronization. if (ExecutionProfile* profile = run_options->run_options().execution_profile(); profile) { - CHECK(execution_timer.has_value()); TF_ASSIGN_OR_RETURN(absl::Duration elapsed, execution_timer->GetElapsedDuration()); profile->set_compute_time_ns( std::max(absl::ToDoubleNanoseconds(elapsed), 1.0)); } -#endif // Make sure kernels are completed before deallocating temporary buffers or // the profiler state. diff --git a/xla/service/gpu/kernels/BUILD b/xla/service/gpu/kernels/BUILD index 9008e4f23e7b4a..9d04094c5fd7cc 100644 --- a/xla/service/gpu/kernels/BUILD +++ b/xla/service/gpu/kernels/BUILD @@ -184,13 +184,13 @@ xla_test( "//xla/stream_executor:device_memory_handle", "//xla/stream_executor:platform_manager", "//xla/stream_executor/gpu:gpu_init", - "//xla/stream_executor/gpu:gpu_timer", "//xla/stream_executor/gpu:gpu_types_header", "//xla/stream_executor/host:host_platform", "@com_google_absl//absl/log:check", "@com_google_absl//absl/random", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_benchmark", "@tsl//tsl/platform:test_main", diff --git a/xla/service/gpu/kernels/topk_kernel_test.cc b/xla/service/gpu/kernels/topk_kernel_test.cc index cd37489930cc45..48bcaf07d06cd0 100644 --- a/xla/service/gpu/kernels/topk_kernel_test.cc +++ b/xla/service/gpu/kernels/topk_kernel_test.cc @@ -29,13 +29,13 @@ limitations under the License. #include "absl/time/time.h" #include "xla/stream_executor/device_memory_handle.h" #include "xla/stream_executor/gpu/gpu_init.h" -#include "xla/stream_executor/gpu/gpu_timer.h" #include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" #include "xla/types.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" #include "tsl/platform/test_benchmark.h" @@ -226,13 +226,11 @@ void BM_SmallTopk(benchmark::State& state) { CHECK_OK(RunTopk(stream.get(), Get(T()), input_buffer.memory(), n, output_values.memory(), output_indices.memory(), k, batch_size)); - auto timer = se::gpu::GpuTimer::Create(stream.get(), - true /* warmup run was executed */); - CHECK_OK(timer.status()); + TF_ASSERT_OK_AND_ASSIGN(auto timer, stream->CreateEventBasedTimer(true)); CHECK_OK(RunTopk(stream.get(), Get(T()), input_buffer.memory(), n, output_values.memory(), output_indices.memory(), k, batch_size)); - auto timer_duration = timer.value().GetElapsedDuration(); + auto timer_duration = timer->GetElapsedDuration(); CHECK_OK(timer_duration.status()); state.SetIterationTime(absl::ToDoubleSeconds(timer_duration.value())); } From 1d9074cac029d49c5efd67ac989686928291ec80 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 20 Jul 2024 21:13:56 -0700 Subject: [PATCH 041/376] Automated Code Change PiperOrigin-RevId: 654415065 --- xla/ffi/BUILD | 1 + xla/ffi/call_frame_test.cc | 1 + 2 files changed, 2 insertions(+) diff --git a/xla/ffi/BUILD b/xla/ffi/BUILD index 96ffa8f8561584..b676e56fc67fdc 100644 --- a/xla/ffi/BUILD +++ b/xla/ffi/BUILD @@ -42,6 +42,7 @@ xla_cc_test( srcs = ["call_frame_test.cc"], deps = [ ":call_frame", + "//xla:xla_data_proto_cc", "//xla/ffi/api:c_api", "//xla/stream_executor:device_memory", "@com_google_absl//absl/status", diff --git a/xla/ffi/call_frame_test.cc b/xla/ffi/call_frame_test.cc index eb331fd8260c4f..2937b53bb5d997 100644 --- a/xla/ffi/call_frame_test.cc +++ b/xla/ffi/call_frame_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "xla/ffi/api/c_api.h" #include "xla/stream_executor/device_memory.h" +#include "xla/xla_data.pb.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/test.h" #include "tsl/platform/test_benchmark.h" From 37119250b65281c8766e8ee8c421d01d599888da Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 21 Jul 2024 14:53:26 -0700 Subject: [PATCH 042/376] Update base test case for circular pipelining to include matmul operation PiperOrigin-RevId: 654545975 --- xla/tests/collective_ops_test.cc | 47 +++++++++++++++++--------------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/xla/tests/collective_ops_test.cc b/xla/tests/collective_ops_test.cc index 2ea6d8ecfcbaab..ab7e7a40898bf1 100644 --- a/xla/tests/collective_ops_test.cc +++ b/xla/tests/collective_ops_test.cc @@ -733,37 +733,37 @@ XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Simple) { results[3])); } -// TODO: b/351064128 - add more complex test cases for circular pipelining XLA_TEST_F(CollectiveOpsTest, CollectivePermute_CircularPipelinePreOptimization) { const absl::string_view kModuleStr = R"( HloModule test while_cond { - param = (u32[], f32[]) parameter(0) + param = (u32[], f32[2,2], f32[2,2]) parameter(0) iter = u32[] get-tuple-element(param), index=0 max_iter = u32[] constant(3) ROOT cmp = pred[] compare(iter, max_iter), direction=LT } while_body { - param = (u32[], f32[]) parameter(0) + param = (u32[], f32[2,2], f32[2,2]) parameter(0) iter = u32[] get-tuple-element(param), index=0 - data = f32[] get-tuple-element(param), index=1 - ten = f32[] constant(10) - sum = f32[] add(data, ten) - cp = f32[] collective-permute(sum), source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}} + data = f32[2,2] get-tuple-element(param), index=1 + weights = f32[2,2] get-tuple-element(param), index=2 + matmul = f32[2,2] dot(weights, data), lhs_contracting_dims={1}, rhs_contracting_dims={0} + cp = f32[2,2] collective-permute(matmul), source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}} iter_increment = u32[] constant(1) next_iter = u32[] add(iter, iter_increment) - ROOT result = (u32[], f32[]) tuple(next_iter, cp) + ROOT result = (u32[], f32[2,2], f32[2,2]) tuple(next_iter, cp, weights) } ENTRY test_computation { iter = u32[] constant(0) - data = f32[] parameter(0) - input = (u32[], f32[]) tuple(iter, data) - while_res = (u32[], f32[]) while(input), condition=while_cond, body=while_body - ROOT data_out = f32[] get-tuple-element(while_res), index=1 + data = f32[2,2] parameter(0) + weights = f32[2,2] parameter(1) + input = (u32[], f32[2,2], f32[2,2]) tuple(iter, data, weights) + while_res = (u32[], f32[2,2], f32[2,2]) while(input), condition=while_cond, body=while_body + ROOT data_out = f32[2,2] get-tuple-element(while_res), index=1 } )"; const int64_t kNumReplicas = 4; @@ -775,10 +775,13 @@ XLA_TEST_F(CollectiveOpsTest, TF_ASSERT_OK_AND_ASSIGN(module, ParseAndReturnVerifiedModule(kModuleStr, config)); - constexpr std::array input_values = {3, 1, 0, 4}; + // input for replica i is + // {{i, i}, + // {i, i}} std::vector replica_inputs; - for (float value : input_values) { - replica_inputs.push_back(LiteralUtil::CreateR0(value)); + for (float i = 1; i < kNumReplicas + 1; ++i) { + replica_inputs.push_back({LiteralUtil::CreateR2({{i, i}, {i, i}})}); + replica_inputs.push_back(LiteralUtil::CreateR2({{0, 0}, {0, 1}})); } TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr executable, test_runner_.CreateExecutable( @@ -788,17 +791,17 @@ XLA_TEST_F(CollectiveOpsTest, std::vector results, ExecuteReplicated( /*executable_provider=*/[&](int64_t) { return executable.get(); }, - /*argument_count_provider=*/[](int64_t) { return 1; }, + /*argument_count_provider=*/[](int64_t) { return 2; }, /*argument_provider=*/ - [&](int64_t replica, int64_t) -> const Literal* { - return &replica_inputs[replica]; + [&](int64_t replica, int64_t index) -> const Literal* { + return &replica_inputs[replica * 2 + index]; }, kNumReplicas, /*run_hlo_passes=*/true, /*device_assignment=*/nullptr)); - LiteralTestUtil::ExpectR0Equal(31, results[0]); - LiteralTestUtil::ExpectR0Equal(30, results[1]); - LiteralTestUtil::ExpectR0Equal(34, results[2]); - LiteralTestUtil::ExpectR0Equal(33, results[3]); + LiteralTestUtil::ExpectR2Equal({{0, 0}, {2, 2}}, results[0]); + LiteralTestUtil::ExpectR2Equal({{0, 0}, {3, 3}}, results[1]); + LiteralTestUtil::ExpectR2Equal({{0, 0}, {4, 4}}, results[2]); + LiteralTestUtil::ExpectR2Equal({{0, 0}, {1, 1}}, results[3]); } XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Degenerate) { From 3e8d59080a9fb602c55940be04ec15f643fe7621 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Sun, 21 Jul 2024 22:05:21 -0700 Subject: [PATCH 043/376] [xla:cpu] Add sort operation with non-aliasing args and results buffers PiperOrigin-RevId: 654609187 --- xla/service/cpu/cpu_compiler.cc | 7 ++++ xla/service/cpu/ir_emitter2.cc | 3 +- xla/service/cpu/thunk_emitter.cc | 34 +++++++++++----- xla/tests/BUILD | 13 +++++++ xla/tests/sort_test.cc | 67 ++++++++++++++++++++++++++++++++ 5 files changed, 113 insertions(+), 11 deletions(-) create mode 100644 xla/tests/sort_test.cc diff --git a/xla/service/cpu/cpu_compiler.cc b/xla/service/cpu/cpu_compiler.cc index 41013a91ad3735..1454fbc93cdc37 100644 --- a/xla/service/cpu/cpu_compiler.cc +++ b/xla/service/cpu/cpu_compiler.cc @@ -1762,6 +1762,13 @@ CpuExecutableAotCompilationResult::LoadExecutable( } } + for (const auto& comparator : ir_emitter2.comparators()) { + if (auto sym = (*jit)->FindCompiledSymbol(comparator.name); !sym) { + return Internal("Failed to find compiled symbol for comparator %s", + comparator.name); + } + } + // Create constant allocations from the buffer assignment. TF_ASSIGN_OR_RETURN( std::vector constants, diff --git a/xla/service/cpu/ir_emitter2.cc b/xla/service/cpu/ir_emitter2.cc index 86363fdc2a2dd4..be02951ddcfd7e 100644 --- a/xla/service/cpu/ir_emitter2.cc +++ b/xla/service/cpu/ir_emitter2.cc @@ -527,7 +527,8 @@ absl::StatusOr IrEmitter2::EmitSortComparator( /*is_top_level_computation=*/true, schedule, /*allow_reassociation=*/false)); - return ComparatorInfo{comparator_function->getName().str()}; + return comparators_.emplace_back( + ComparatorInfo{comparator_function->getName().str()}); } //===----------------------------------------------------------------------===// diff --git a/xla/service/cpu/thunk_emitter.cc b/xla/service/cpu/thunk_emitter.cc index 2d6f7f68a71d8e..86f07e38922552 100644 --- a/xla/service/cpu/thunk_emitter.cc +++ b/xla/service/cpu/thunk_emitter.cc @@ -976,25 +976,39 @@ absl::StatusOr ThunkEmitter::EmitSortThunk( TF_ASSIGN_OR_RETURN(auto comparator, ir_emitter_.EmitSortComparator(sort)); TF_ASSIGN_OR_RETURN(auto buffers, GetHostKernelAllocationSlices(sort)); - if (!absl::c_equal(buffers.arguments, buffers.results)) { + if (buffers.arguments.size() != buffers.results.size()) { return Internal( - "Sort operation expected to be performed inplace and all arguments " - "must alias with results"); + "Sort operation expects the same number of operands and results"); } + ThunkSequence thunks; + std::vector inputs; inputs.reserve(sort->operand_count()); for (size_t i = 0; i < sort->operand_count(); ++i) { - inputs.push_back(SortThunk::Input{ - buffers.arguments[i], - sort->operand(i)->shape(), - }); + const Shape& shape = sort->operand(i)->shape(); + + BufferAllocation::Slice arg = buffers.arguments[i]; + BufferAllocation::Slice result = buffers.results[i]; + + // Copy argument to result if they are not the same buffer. + if (arg != result) { + TF_ASSIGN_OR_RETURN( + thunks.emplace_back(), + CopyThunk::Create(ThunkInfo(instruction), arg, shape, result, shape)); + } + + // Add sort thunk input to sort result buffer inplace. + inputs.push_back(SortThunk::Input{result, shape}); } - return ThunkSequence::Of(ThunkInfo(instruction), inputs, - sort->sort_dimension(), sort->is_stable(), - comparator.name); + TF_ASSIGN_OR_RETURN( + thunks.emplace_back(), + SortThunk::Create(ThunkInfo(instruction), inputs, sort->sort_dimension(), + sort->is_stable(), comparator.name)); + + return thunks; } absl::StatusOr diff --git a/xla/tests/BUILD b/xla/tests/BUILD index 5e07cdb79001ff..d0ab586bc85335 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -1779,6 +1779,19 @@ xla_test( ], ) +xla_test( + name = "sort_test", + srcs = ["sort_test.cc"], + tags = ["test_xla_cpu_thunks"], + deps = [ + ":hlo_test_base", + ":test_macros_header", + ":xla_internal_test_main", + "//xla:error_spec", + "@com_google_googletest//:gtest_main", + ], +) + xla_test( name = "topk_test", srcs = ["topk_test.cc"], diff --git a/xla/tests/sort_test.cc b/xla/tests/sort_test.cc new file mode 100644 index 00000000000000..a6926c2e6e487e --- /dev/null +++ b/xla/tests/sort_test.cc @@ -0,0 +1,67 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include "xla/error_spec.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tests/test_macros.h" + +namespace xla { +namespace { + +class SortTest : public HloTestBase {}; + +XLA_TEST_F(SortTest, SortDim0) { + std::string_view hlo_text_module = R"( + HloModule sort + + compare { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT lt = pred[] compare(p0, p1), direction=LT + } + + ENTRY e { + x = f32[32,64] parameter(0) + ROOT sort = f32[32,64] sort(x), dimensions={0}, to_apply=compare + } + )"; + + EXPECT_TRUE(RunAndCompare(hlo_text_module, ErrorSpec{1e-5, 1e-5})); +} + +XLA_TEST_F(SortTest, SortDim1) { + std::string_view hlo_text_module = R"( + HloModule sort + + compare { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT lt = pred[] compare(p0, p1), direction=LT + } + + ENTRY e { + x = f32[32,64] parameter(0) + ROOT sort = f32[32,64] sort(x), dimensions={1}, to_apply=compare + } + )"; + + EXPECT_TRUE(RunAndCompare(hlo_text_module, ErrorSpec{1e-5, 1e-5})); +} + +} // namespace +} // namespace xla From 91b9e781c065979d7891beca01c72012bc5cc78a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Bana=C5=9B?= Date: Mon, 22 Jul 2024 01:57:58 -0700 Subject: [PATCH 044/376] [XLA:CPU] Add support for `transpose` to thunk runtime. In old runtime, if `transpose` is not rewritten by any HLO pass, it is handled by the elemental generator. This commit introduces the same behavior to thunks runtime, also adds test case verifying that case (was missing). Thunk runtime already supports all other ops to which transpose is rewritten, so no further changes are required. Turned on `transpose` tests for thunks runtime. PiperOrigin-RevId: 654656874 --- xla/service/cpu/thunk_emitter.cc | 1 + xla/tests/BUILD | 4 +++- xla/tests/hlo_test_base.cc | 10 +++++----- xla/tests/hlo_test_base.h | 3 ++- xla/tests/transpose_test.cc | 30 ++++++++++++++++++++++++++++-- 5 files changed, 39 insertions(+), 9 deletions(-) diff --git a/xla/service/cpu/thunk_emitter.cc b/xla/service/cpu/thunk_emitter.cc index 86f07e38922552..028b30c4e1dd33 100644 --- a/xla/service/cpu/thunk_emitter.cc +++ b/xla/service/cpu/thunk_emitter.cc @@ -231,6 +231,7 @@ absl::StatusOr ThunkEmitter::EmitHloInstruction( case HloOpcode::kSin: case HloOpcode::kSqrt: case HloOpcode::kSubtract: + case HloOpcode::kTranspose: case HloOpcode::kTan: case HloOpcode::kTanh: case HloOpcode::kXor: diff --git a/xla/tests/BUILD b/xla/tests/BUILD index d0ab586bc85335..230eb6c8772cd0 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -1107,6 +1107,7 @@ xla_test( xla_test( name = "transpose_test", srcs = ["transpose_test.cc"], + tags = ["test_xla_cpu_thunks"], deps = [ ":client_library_test_base", ":hlo_test_base", @@ -1114,10 +1115,11 @@ xla_test( ":test_macros_header", ":xla_internal_test_main", "//xla:array2d", + "//xla:literal_util", "//xla:reference_util", "//xla:util", - "//xla/client:local_client", "//xla/client:xla_builder", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", ], ) diff --git a/xla/tests/hlo_test_base.cc b/xla/tests/hlo_test_base.cc index fe162df793da24..560a486a378987 100644 --- a/xla/tests/hlo_test_base.cc +++ b/xla/tests/hlo_test_base.cc @@ -294,15 +294,15 @@ void HloTestBase::RunAndFilecheckHloModuleGroupRewrite( } absl::StatusOr HloTestBase::Execute( - std::unique_ptr module, absl::Span arguments) { - return runner_->Execute(std::move(module), arguments); + std::unique_ptr module, absl::Span arguments, + bool run_hlo_passes) { + return runner_->Execute(std::move(module), arguments, run_hlo_passes); } Literal HloTestBase::ExecuteNoHloPasses(std::unique_ptr module, absl::Span arguments) { - return runner_ - ->Execute(std::move(module), arguments, - /*run_hlo_passes=*/false) + return Execute(std::move(module), arguments, + /*run_hlo_passes=*/false) .value(); } diff --git a/xla/tests/hlo_test_base.h b/xla/tests/hlo_test_base.h index 9ac251ce9b9719..c9f88f237e9b00 100644 --- a/xla/tests/hlo_test_base.h +++ b/xla/tests/hlo_test_base.h @@ -206,7 +206,8 @@ class HloTestBase : public ManifestCheckingTest { // Executes the given module and return the result as a Literal. absl::StatusOr Execute(std::unique_ptr module, - absl::Span arguments); + absl::Span arguments, + bool run_hlo_passes = true); // Same as above, except the module will be executed without running any HLO // passes on it. diff --git a/xla/tests/transpose_test.cc b/xla/tests/transpose_test.cc index b936693ea52abe..e52fccfcc69a6d 100644 --- a/xla/tests/transpose_test.cc +++ b/xla/tests/transpose_test.cc @@ -13,18 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include +#include #include #include "xla/array2d.h" -#include "xla/client/local_client.h" #include "xla/client/xla_builder.h" +#include "xla/literal_util.h" #include "xla/reference_util.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/util.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { @@ -199,5 +201,29 @@ TEST_F(TransposeTest, TransposeConstant210_DegenerateDim) { TestTransposeConstant({20, 30, 1}, {2, 1, 0}); } +using HloTransposeTest = HloTestBase; + +// Disable HLO passes to verify the default behavior +XLA_TEST_F(HloTransposeTest, DISABLED_ON_INTERPRETER(DISABLED_ON_GPU( + DISABLED_ON_TPU(HloPassesDisabled)))) { + const char* const kModuleStr = R"( + HloModule Transpose + + ENTRY Transpose { + constant = s32[2,3] constant({ { 1, 2, 3 }, { 4, 5, 6 } }) + ROOT transpose = s32[3,2] transpose(constant), dimensions={1,0} + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + TF_ASSERT_OK_AND_ASSIGN( + auto result, Execute(std::move(module), {}, /*run_hlo_passes=*/false)); + Array2D array({{1, 4}, {2, 5}, {3, 6}}); + auto expected = LiteralUtil::CreateR2FromArray2D(array); + + EXPECT_EQ(result, expected); +} + } // namespace } // namespace xla From 37f67be38f7eb76be3113bb8333f6f03d3988e98 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 22 Jul 2024 02:07:21 -0700 Subject: [PATCH 045/376] Automated Code Change PiperOrigin-RevId: 654659681 --- xla/service/spmd/shardy/sdy_round_trip/test_utils/BUILD | 1 + .../spmd/shardy/sdy_round_trip/test_utils/mhlo_to_hlo_to_mhlo.cc | 1 + 2 files changed, 2 insertions(+) diff --git a/xla/service/spmd/shardy/sdy_round_trip/test_utils/BUILD b/xla/service/spmd/shardy/sdy_round_trip/test_utils/BUILD index 3fcdad891880a7..d500f3cfe633cf 100644 --- a/xla/service/spmd/shardy/sdy_round_trip/test_utils/BUILD +++ b/xla/service/spmd/shardy/sdy_round_trip/test_utils/BUILD @@ -25,6 +25,7 @@ cc_library( "//xla/mlir_hlo", "//xla/mlir_hlo:mhlo_passes", "//xla/service:hlo_module_config", + "//xla/service:hlo_proto_cc", "//xla/translate/hlo_to_mhlo:hlo_to_mlir_hlo", "//xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo", "@com_google_absl//absl/status", diff --git a/xla/service/spmd/shardy/sdy_round_trip/test_utils/mhlo_to_hlo_to_mhlo.cc b/xla/service/spmd/shardy/sdy_round_trip/test_utils/mhlo_to_hlo_to_mhlo.cc index d5f86b39c610d0..b9c55aebcdbf6b 100644 --- a/xla/service/spmd/shardy/sdy_round_trip/test_utils/mhlo_to_hlo_to_mhlo.cc +++ b/xla/service/spmd/shardy/sdy_round_trip/test_utils/mhlo_to_hlo_to_mhlo.cc @@ -38,6 +38,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" +#include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" From 50c284e9f9473e7050cc103f456345b43d201278 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Mon, 22 Jul 2024 02:19:51 -0700 Subject: [PATCH 046/376] Fix variadic multi-output reductions. In the existing test, the variadic reduction was the last one, so the output index and the root index were the same. If the variadic reduction is the first one, or there is more than one, the current logic is broken. PiperOrigin-RevId: 654662844 --- xla/service/gpu/fusions/reduction_mlir.cc | 14 +++++--- .../gpu/fusions/reduction_mlir_test.cc | 36 +++++++++++++++++++ 2 files changed, 45 insertions(+), 5 deletions(-) diff --git a/xla/service/gpu/fusions/reduction_mlir.cc b/xla/service/gpu/fusions/reduction_mlir.cc index a8697f259dbc63..8250c408edf619 100644 --- a/xla/service/gpu/fusions/reduction_mlir.cc +++ b/xla/service/gpu/fusions/reduction_mlir.cc @@ -173,10 +173,13 @@ struct MlirReductionFusion::EmitterState { builder(entry_function.getLoc(), entry_function), computation(computations.FindPartitionedComputation( fusion.fused_instructions_computation())) { - int index = 0; - for (const auto& root : owner.analysis_.fusion_roots()) { - fusion_result_index_starts[&root.instruction()] = index; - index += root.shape().IsTuple() ? root.shape().tuple_shapes_size() : 1; + int output_index = 0; + for (const auto& [root_index, root] : + llvm::enumerate(owner.analysis_.fusion_roots())) { + root_indices[&root.instruction()] = root_index; + fusion_result_index_starts[&root.instruction()] = output_index; + output_index += + root.shape().IsTuple() ? root.shape().tuple_shapes_size() : 1; } } @@ -225,6 +228,7 @@ struct MlirReductionFusion::EmitterState { ImplicitLocOpBuilder builder; const mlir_converter::PartitionedComputation& computation; absl::flat_hash_map fusion_result_index_starts; + absl::flat_hash_map root_indices; SmallVector thread_and_block_ids; }; @@ -624,7 +628,7 @@ SmallVector MlirReductionFusion::EvaluateEpilogue( auto values = EmitEpilogue(group_id, state.computations, state.entry_function, results, epilogue_input_indices, b); - int first_root_index = state.OutputIndex(epilogue.roots.front(), 0); + int first_root_index = state.root_indices[epilogue.roots.front()]; auto thread_has_output = mlir_converter::CheckConstraints( *ComputeThreadIdToOutputIndexing(first_root_index, b.getContext()), state.thread_and_block_ids, symbol_values, b); diff --git a/xla/service/gpu/fusions/reduction_mlir_test.cc b/xla/service/gpu/fusions/reduction_mlir_test.cc index 9fac8cc5f34d24..8fc0ba6af05bb7 100644 --- a/xla/service/gpu/fusions/reduction_mlir_test.cc +++ b/xla/service/gpu/fusions/reduction_mlir_test.cc @@ -879,6 +879,42 @@ TEST_F(MlirRowReductionTest, LargeToUnit) { EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); } +TEST_F(MlirRowReductionTest, MOFTwoVariadic) { + // Regression test for a compilation crash with a MOF with two variadic + // reductions. + constexpr auto kHloString = R"( + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + p2 = f32[] parameter(2) + p3 = f32[] parameter(3) + a = f32[] add(p0, p2) + b = f32[] add(p1, p3) + ROOT out = (f32[], f32[]) tuple(a, b) + } + + fused_reduce { + p0 = f32[3,2] parameter(0) + p1 = f32[3,2] parameter(1) + c0 = f32[] constant(0) + iota0 = f32[3,2] iota(), iota_dimension=1 + iota1 = f32[3,2] iota(), iota_dimension=1 + reduce0 = (f32[3], f32[3]) reduce(p0, iota0, c0, c0), dimensions={1}, + to_apply=add + reduce1 = (f32[3], f32[3]) reduce(p1, iota1, c0, c0), dimensions={1}, + to_apply=add + ROOT tuple = ((f32[3], f32[3]), (f32[3], f32[3])) tuple(reduce0, %reduce1) + } + + ENTRY main { + p0 = f32[3,2] parameter(0) + p1 = f32[3,2] parameter(1) + ROOT fusion = ((f32[3], f32[3]), (f32[3], f32[3])) fusion(p0, p1), + kind=kInput, calls=fused_reduce + })"; + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + } // namespace } // namespace gpu } // namespace xla From 3a7d7bc5b8fbb2eb66c4685d9c25178d8407704f Mon Sep 17 00:00:00 2001 From: Yunlong Liu Date: Mon, 22 Jul 2024 02:34:17 -0700 Subject: [PATCH 047/376] PR #15081: Adds a convenient python binding for obtaining instruction costs. Imported from GitHub PR https://github.com/openxla/xla/pull/15081 The python binding being added in this PR can help quick analysis on the cost of the HLO instructions with easy python code. The binding is tested with a real tensorboard dir: ``` import jax.lib.xla_bridge as xb client = xb.xla_client costs = client.profiler.get_instructions_profile(tensorboard_dir) for name, cost in costs: print(name, cost) ``` works. Copybara import of the project: -- e890d5ca56eaa212d63e2f578140758ffe230923 by Yunlong Liu : in memory python representation of profiled instructions Merging this change closes #15081 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/15081 from yliu120:main e890d5ca56eaa212d63e2f578140758ffe230923 PiperOrigin-RevId: 654666728 --- xla/python/profiler.cc | 18 ++++++++++++++++++ xla/python/xla_extension/profiler.pyi | 3 ++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/xla/python/profiler.cc b/xla/python/profiler.cc index aa5e6208bd2422..d413fc5892fdd1 100644 --- a/xla/python/profiler.cc +++ b/xla/python/profiler.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "nanobind/nanobind.h" +#include "nanobind/stl/pair.h" // IWYU pragma: keep #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep @@ -269,6 +270,23 @@ void BuildProfilerSubmodule(nb::module_& m) { }, nb::arg("tensorboard_dir")); + profiler.def( + "get_instructions_profile", + [](const std::string& tensorboard_dir) + -> std::vector> { + tensorflow::profiler::ProfiledInstructionsProto profile_proto; + xla::ThrowIfError( + xla::ConvertXplaneUnderLogdirToProfiledInstructionsProto( + tensorboard_dir, &profile_proto)); + std::vector> results; + results.reserve(profile_proto.costs().size()); + for (const auto& c : profile_proto.costs()) { + results.emplace_back(c.name(), c.cost_us()); + } + return results; + }, + nb::arg("tensorboard_dir")); + profiler.def("get_fdo_profile", [](nb::bytes xspace, bool as_textproto = false) -> nb::object { std::string out = GetFdoProfile( diff --git a/xla/python/xla_extension/profiler.pyi b/xla/python/xla_extension/profiler.pyi index 5adc9c5111f1ae..3c1fc5bd35f03d 100644 --- a/xla/python/xla_extension/profiler.pyi +++ b/xla/python/xla_extension/profiler.pyi @@ -14,7 +14,7 @@ # ============================================================================== from types import TracebackType -from typing import Any, Optional, Type, Union, List +from typing import Any, Optional, Type, Union, List, Tuple _Status = Any @@ -24,6 +24,7 @@ def start_server(port: int) -> ProfilerServer: ... def register_plugin_profiler(c_api: Any) -> None: ... def get_profiled_instructions_proto(tensorboard_dir: str) -> bytes: ... +def get_instructins_profile(tensorboard_dir: str) -> List[Tuple[str, float]]: ... def get_fdo_profile( xspace: bytes, as_textproto: bool = ... ) -> Union[bytes, str]: ... From 0b1fdbd11eb8591a6df3451a1092c0d2ee042145 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Mon, 22 Jul 2024 03:00:38 -0700 Subject: [PATCH 048/376] Fix indexing maps for broadcasting elementwise. PiperOrigin-RevId: 654672618 --- xla/service/gpu/model/indexing_analysis.cc | 13 ++++++++- .../gpu/model/indexing_analysis_test.cc | 29 +++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/xla/service/gpu/model/indexing_analysis.cc b/xla/service/gpu/model/indexing_analysis.cc index 18e7b22968532f..89124182909aca 100644 --- a/xla/service/gpu/model/indexing_analysis.cc +++ b/xla/service/gpu/model/indexing_analysis.cc @@ -80,12 +80,23 @@ HloInstructionIndexing CreateUnknownIndexing(int64_t count = 1) { HloInstructionIndexing ComputeOutputToInputCwiseOpIndexing( const HloInstruction* instr, MLIRContext* mlir_context) { IndexingMap identity_map = CreateIdentityMap(instr->shape(), mlir_context); + IndexingMap unit_map( + mlir::AffineMap::get(identity_map.GetAffineMap().getNumDims(), + /*symbolCount=*/0, mlir_context), + identity_map.GetDimVars(), /*range_vars=*/{}, /*rt_vars=*/{}); HloInstructionIndexing instr_indexing; instr_indexing.indexing_maps.resize(instr->operand_count()); int64_t operand_count = instr->operand_count(); for (int64_t operand_id = 0; operand_id < operand_count; ++operand_id) { - instr_indexing.indexing_maps[operand_id].insert(identity_map); + // Select allows implicit broadcasting in the predicate. We just handle it + // generically here. + auto* operand = instr->operand(operand_id); + if (operand->shape().rank() == 0 && instr->shape().rank() > 0) { + instr_indexing.indexing_maps[operand_id].insert(unit_map); + } else { + instr_indexing.indexing_maps[operand_id].insert(identity_map); + } } return instr_indexing; } diff --git a/xla/service/gpu/model/indexing_analysis_test.cc b/xla/service/gpu/model/indexing_analysis_test.cc index 709a980e03fcea..122461fd255cbc 100644 --- a/xla/service/gpu/model/indexing_analysis_test.cc +++ b/xla/service/gpu/model/indexing_analysis_test.cc @@ -2654,6 +2654,35 @@ TEST_F(IndexingAnalysisTest, EpilogueIndexing_NoEpilogue) { )")); } +TEST_F(IndexingAnalysisTest, BroadcastingElementwise) { + auto root = ParseAndGetRoot(R"( + HloModule m + ENTRY e { + p0 = pred[] parameter(0) + p1 = f32[1000, 1000] parameter(1) + p2 = f32[1000, 1000] parameter(2) + ROOT select = f32[1000, 1000] select(p0, p1, p2) + } + )"); + auto input_indexing = GetOutputToInputIndexing(root); + + EXPECT_THAT(GetOutputToInputIndexing(root).ToString(), MatchIndexingString(R"( + operand id = 0 + (d0, d1) -> () + domain: + d0 in [0, 1000) + d1 in [0, 1000) + operand id = 1 (d0, d1) -> (d0, d1) + domain: + d0 in [0, 1000) + d1 in [0, 1000) + operand id = 2 (d0, d1) -> (d0, d1) + domain: + d0 in [0, 1000) + d1 in [0, 1000) + )")); +} + } // namespace } // namespace gpu } // namespace xla From eb12203e8c876056762a23fd5a397fbe6e2fc8c4 Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Mon, 22 Jul 2024 03:02:54 -0700 Subject: [PATCH 049/376] [XLA:GPU] Support RTVars in symbolic tile analysis. PiperOrigin-RevId: 654673152 --- .../gpu/model/symbolic_tile_analysis.cc | 21 +++--- .../gpu/model/symbolic_tile_analysis_test.cc | 69 +++++++++++++++++++ 2 files changed, 82 insertions(+), 8 deletions(-) diff --git a/xla/service/gpu/model/symbolic_tile_analysis.cc b/xla/service/gpu/model/symbolic_tile_analysis.cc index d72590debe3d2e..ccf2fd642fad9f 100644 --- a/xla/service/gpu/model/symbolic_tile_analysis.cc +++ b/xla/service/gpu/model/symbolic_tile_analysis.cc @@ -135,18 +135,23 @@ absl::StatusOr ComputeTileOffsetIndexing( } std::vector symbol_lower_bounds( - tile_offset_indexing.GetSymbolCount(), + tile_offset_indexing.GetRangeVarsCount(), mlir::getAffineConstantExpr(0, mlir_context)); + symbol_lower_bounds.reserve(tile_offset_indexing.GetSymbolCount()); + for (int i = 0; i < tile_offset_indexing.GetRTVarsCount(); ++i) { + symbol_lower_bounds.push_back(mlir::getAffineSymbolExpr(i, mlir_context)); + } mlir::AffineMap simplified_affine_map = tile_offset_indexing.GetAffineMap().replaceDimsAndSymbols( - /*dimReplacements=*/{}, symbol_lower_bounds, - tile_offset_indexing.GetDimVarsCount(), - /*numResultSyms=*/tile_offset_indexing.GetRangeVarsCount()); - - IndexingMap simplified_indexing_map = IndexingMap{ - simplified_affine_map, tile_offset_indexing.GetDimVars(), - tile_offset_indexing.GetRangeVars(), tile_offset_indexing.GetRTVars()}; + /*dimReplacements=*/{}, + /*symReplacements=*/symbol_lower_bounds, + /*numResultDims=*/tile_offset_indexing.GetDimVarsCount(), + /*numResultSyms=*/tile_offset_indexing.GetRTVarsCount()); + + IndexingMap simplified_indexing_map = + IndexingMap{simplified_affine_map, tile_offset_indexing.GetDimVars(), + /*range_vars=*/{}, tile_offset_indexing.GetRTVars()}; simplified_indexing_map.Simplify(); simplified_indexing_map.RescaleSymbols(); diff --git a/xla/service/gpu/model/symbolic_tile_analysis_test.cc b/xla/service/gpu/model/symbolic_tile_analysis_test.cc index e73f04d980e01e..a0e964ff77f4f4 100644 --- a/xla/service/gpu/model/symbolic_tile_analysis_test.cc +++ b/xla/service/gpu/model/symbolic_tile_analysis_test.cc @@ -795,6 +795,75 @@ ENTRY main { )")); } +TEST_F(SymbolicTileAnalysisTest, CanComputeTiledHloInstructionsWithRTVars) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule m + +max_computation { + param_0 = s32[] parameter(0) + param_1 = s32[] parameter(1) + ROOT maximum = s32[] maximum(param_0, param_1) +} + +fused_computation { + src = s32[2,2,258] parameter(0) + of1 = s32[] parameter(1) + of2 = s32[] parameter(2) + of3 = s32[] parameter(3) + ds = s32[1,2,32] dynamic-slice(s32[2,2,258] src, s32[] of1, s32[] of2, s32[] of3), + dynamic_slice_sizes={1, 2, 32} + c0 = s32[] constant(0) + ROOT reduce = s32[1,2] reduce(ds, c0), dimensions={2}, to_apply=max_computation +} + +ENTRY main { + param_0 = s32[2,2,258] parameter(0) + param_1 = s32[] parameter(1) + param_2 = s32[] parameter(2) + param_3 = s32[] parameter(3) + ROOT fusion = s32[1,2] fusion(param_0, param_1, param_2, param_3), kind=kLoop, calls=fused_computation +} +)")); + + std::optional analysis = TryAnalyzeModule(module.get()); + ASSERT_TRUE(analysis.has_value()); + + TF_ASSERT_OK_AND_ASSIGN( + TiledHloComputation tiled_hlo_computation, + analysis->ComputeTiledHloInstructions(/*tile_parameters=*/{1, 1})); + + const TiledHloInstruction* dynamic_slice = + tiled_hlo_computation.GetRoot()->operand(0); + const TiledHloInstruction* param_0_tile = dynamic_slice->operand(0); + + EXPECT_THAT(*dynamic_slice, MatchTiledHloInstruction( + /*tile_sizes=*/{1, 1, 32}, + /*tile_strides=*/{0, 1, 1}, + /*tile_offsets_indexing=*/R"( + (d0, d1) -> (0, d1, 0) + domain: + d0 in [0, 1) + d1 in [0, 2) + )")); + + EXPECT_THAT(*param_0_tile, MatchTiledHloInstruction( + /*tile_sizes=*/{1, 1, 32}, + /*tile_strides=*/{0, 1, 1}, + /*tile_offsets_indexing=*/R"( + (d0, d1)[s0, s1] -> (s0, d1, s1) + domain: + d0 in [0, 1) + d1 in [0, 2) + s0 in [0, 2) + hlo: %of1 = s32[] parameter(1) + (d0, d1, d2) -> () + s1 in [0, 227) + hlo: %of3 = s32[] parameter(3) + (d0, d1, d2) -> () + )")); +} + } // namespace } // namespace gpu } // namespace xla From fecc1dde0cada703d3e48343918d0cd465b3d3e6 Mon Sep 17 00:00:00 2001 From: Tori Baker Date: Mon, 22 Jul 2024 03:10:08 -0700 Subject: [PATCH 050/376] Add IndexingMapAttr to XLA GPU Dialect I will create an mlir test for the parser/printer & add it to ApplyIndexingOp in subsequent cls. PiperOrigin-RevId: 654674947 --- xla/service/gpu/fusions/mlir/ir/BUILD | 35 ++++++++-- .../gpu/fusions/mlir/ir/xla_gpu_attrs.cc | 54 +++++++++++++++ .../gpu/fusions/mlir/ir/xla_gpu_attrs.h | 52 ++++++++++++++ .../gpu/fusions/mlir/ir/xla_gpu_attrs.td | 68 +++++++++++++++++++ .../gpu/fusions/mlir/ir/xla_gpu_dialect.td | 1 + .../gpu/fusions/mlir/ir/xla_gpu_ops.cc | 10 +++ xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h | 9 ++- xla/service/gpu/model/indexing_map.h | 8 +++ 8 files changed, 229 insertions(+), 8 deletions(-) create mode 100644 xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc create mode 100644 xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h create mode 100644 xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.td diff --git a/xla/service/gpu/fusions/mlir/ir/BUILD b/xla/service/gpu/fusions/mlir/ir/BUILD index 7cd10ce19d62b2..e3db7145981a16 100644 --- a/xla/service/gpu/fusions/mlir/ir/BUILD +++ b/xla/service/gpu/fusions/mlir/ir/BUILD @@ -14,7 +14,7 @@ package_group( ) td_library( - name = "xla_gpu_ops_td_files", + name = "xla_gpu_td_files", srcs = glob(["*.td"]), includes = ["."], deps = [ @@ -40,7 +40,7 @@ gentbl_cc_library( ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "xla_gpu_dialect.td", - deps = [":xla_gpu_ops_td_files"], + deps = [":xla_gpu_td_files"], ) gentbl_cc_library( @@ -58,14 +58,39 @@ gentbl_cc_library( ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "xla_gpu_ops.td", - deps = [":xla_gpu_ops_td_files"], + deps = [":xla_gpu_td_files"], +) + +gentbl_cc_library( + name = "xla_gpu_attrs_inc_gen", + strip_include_prefix = ".", + tbl_outs = [ + ( + ["-gen-attrdef-decls"], + "xla_gpu_attrs.h.inc", + ), + ( + ["-gen-attrdef-defs"], + "xla_gpu_attrs.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "xla_gpu_attrs.td", + deps = [":xla_gpu_td_files"], ) cc_library( name = "xla_gpu", - srcs = ["xla_gpu_ops.cc"], - hdrs = ["xla_gpu_ops.h"], + srcs = [ + "xla_gpu_attrs.cc", + "xla_gpu_ops.cc", + ], + hdrs = [ + "xla_gpu_attrs.h", + "xla_gpu_ops.h", + ], deps = [ + ":xla_gpu_attrs_inc_gen", ":xla_gpu_dialect_inc_gen", ":xla_gpu_ops_inc_gen", "//xla/service/gpu/model:indexing_analysis", diff --git a/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc b/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc new file mode 100644 index 00000000000000..6feb0eadfaade9 --- /dev/null +++ b/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc @@ -0,0 +1,54 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h" + +#include + +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Support/LLVM.h" +#include "xla/service/gpu/model/indexing_map.h" + +namespace xla { +namespace gpu { + +void PrintDimVars(mlir::AsmPrinter& p, llvm::ArrayRef dim_vars) {} + +mlir::FailureOr> ParseDimVars( + mlir::AsmParser& parser) { + return mlir::failure(); +} + +void PrintRangeVars(mlir::AsmPrinter& p, llvm::ArrayRef range_vars) {} + +mlir::FailureOr> ParseRangeVars( + mlir::AsmParser& parser) { + return mlir::failure(); +} + +void PrintConstraints( + mlir::AsmPrinter& p, + mlir::ArrayRef> + range_vars) {} + +mlir::FailureOr< + llvm::SmallVector>> +ParseConstraints(mlir::AsmParser& parser) { + return mlir::failure(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h b/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h new file mode 100644 index 00000000000000..bd6cf0424b1db7 --- /dev/null +++ b/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h @@ -0,0 +1,52 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_IR_XLA_GPU_ATTRS_H_ +#define XLA_SERVICE_GPU_FUSIONS_MLIR_IR_XLA_GPU_ATTRS_H_ + +#include "mlir/IR/OpImplementation.h" +#include "mlir/Support/LLVM.h" +#include "xla/service/gpu/model/indexing_map.h" + +namespace xla { +namespace gpu { + +// Custom printer to print an array of DimVar. +void PrintDimVars(mlir::AsmPrinter& p, mlir::ArrayRef dim_vars); + +// Custom parser to parse an array of DimVar. +mlir::FailureOr> ParseDimVars( + mlir::AsmParser& parser); + +// Custom printer to print an array of RangeVar. +void PrintRangeVars(mlir::AsmPrinter& p, mlir::ArrayRef range_vars); + +// Custom parser to parse an array of RangeVar. +mlir::FailureOr> ParseRangeVars( + mlir::AsmParser& parser); + +// Custom printer to print constraints. +void PrintConstraints( + mlir::AsmPrinter& p, + mlir::ArrayRef<::std::pair<::mlir::AffineExpr, Interval>> range_vars); + +// Custom parser to parse constraints. +mlir::FailureOr>> +ParseConstraints(mlir::AsmParser& parser); + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_IR_XLA_GPU_ATTRS_H_ diff --git a/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.td b/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.td new file mode 100644 index 00000000000000..51910d27c5a3cc --- /dev/null +++ b/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.td @@ -0,0 +1,68 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_ATTRS +#define XLA_SERVICE_GPU_FUSIONS_MLIR_ATTRS + +include "mlir/IR/AttrTypeBase.td" +include "xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.td" + +class XLAGPU_Attr traits = []> : + AttrDef { +} + +def XLAGPU_AffineMapParameter : + AttrOrTypeParameter<"::mlir::AffineMap", ""> { +} + +def XLAGPU_DimVarsParameter : ArrayRefParameter<"::xla::gpu::DimVar", + "DimVarArray"> { + let parser = "ParseDimVars($_parser)"; + let printer = "PrintDimVars($_printer, $_self)"; +} + +def XLAGPU_RangeVarsParameter : ArrayRefParameter<"::xla::gpu::RangeVar", + "RangeVarArray"> { + let parser = "ParseRangeVars($_parser)"; + let printer = "PrintRangeVars($_printer, $_self)"; +} + +def XLAGPU_ConstraintsParameter : + ArrayRefParameter<"::std::pair<::mlir::AffineExpr, ::xla::gpu::Interval>", + "ContraintsArray"> { + let parser = "ParseConstraints($_parser)"; + let printer = "PrintConstraints($_printer, $_self)"; +} + +def XLAGPU_IndexingMapAttr : XLAGPU_Attr<"IndexingMap"> { + let summary = "An Attribute representing an indexing map."; + let mnemonic = "indexing_map"; + let description = [{This attribute stores an indexing map. See + https://openxla.org/xla/indexing for more details. + }]; + let parameters = (ins XLAGPU_AffineMapParameter:$map, + XLAGPU_DimVarsParameter:$dim_vars, + XLAGPU_RangeVarsParameter:$range_vars, + XLAGPU_ConstraintsParameter:$constraints); + + let assemblyFormat = [{ + `<` `map` `=` $map `,` + `dim_vars` `=` $dim_vars`,` + `range_vars` `=` $range_vars `,` + `constraints` `=` $constraints `>` + }]; +} + +#endif // MLIR_HLO_DIALECT_MHLO_IR_HLO_ATTRS diff --git a/xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.td b/xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.td index bb599174a2f3d5..4400747923cb6e 100644 --- a/xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.td +++ b/xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.td @@ -26,6 +26,7 @@ def XlaGpuDialect : Dialect { }]; let cppNamespace = "::xla::gpu"; + let useDefaultAttributePrinterParser = 1; } #endif // MLIR_HLO_DIALECT_MHLO_IR_HLO_DIALECT diff --git a/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc b/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc index b9cdcb8efcc908..8bdc4c95985c9e 100644 --- a/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc +++ b/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc @@ -31,6 +31,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Builders.h" // IWYU pragma: keep +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep #include "mlir/IR/MLIRContext.h" // IWYU pragma: keep #include "mlir/IR/OpDefinition.h" @@ -47,6 +48,10 @@ limitations under the License. #include "xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.cc.inc" #include "xla/service/gpu/model/indexing_map.h" +#define GET_ATTRDEF_CLASSES +#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc.inc" +#undef GET_ATTRDEF_CLASSES + namespace xla { namespace gpu { namespace { @@ -151,6 +156,11 @@ void XlaGpuDialect::initialize() { #include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc.inc" #undef GET_OP_LIST >(); + addAttributes< +#define GET_ATTRDEF_LIST +#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc.inc" + >(); +#undef GET_ATTRDEF_LIST addInterfaces(); } diff --git a/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h b/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h index 6f46ef4ce7940b..02604c1ea99db7 100644 --- a/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h +++ b/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h @@ -24,13 +24,16 @@ limitations under the License. #include "mlir/IR/OpDefinition.h" // IWYU pragma: keep #include "mlir/IR/OpImplementation.h" // IWYU pragma: keep #include "mlir/Interfaces/CallInterfaces.h" // IWYU pragma: keep -#include "mlir/Interfaces/InferTypeOpInterface.h" // IWYU pragma : keep -#include "mlir/Interfaces/SideEffectInterfaces.h" // IWYU pragma : keep -#include "xla/service/gpu/model/indexing_map.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" // IWYU pragma: keep +#include "mlir/Interfaces/SideEffectInterfaces.h" // IWYU pragma: keep +#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h" // IWYU pragma: keep #define GET_OP_CLASSES #include "xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.h.inc" #include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h.inc" #undef GET_OP_CLASSES +#define GET_ATTRDEF_CLASSES +#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h.inc" +#undef GET_ATTRDEF_CLASSES #endif // XLA_SERVICE_GPU_FUSIONS_MLIR_IR_XLA_GPU_OPS_H_ diff --git a/xla/service/gpu/model/indexing_map.h b/xla/service/gpu/model/indexing_map.h index d13f698a8430c9..e38e68c1e76179 100644 --- a/xla/service/gpu/model/indexing_map.h +++ b/xla/service/gpu/model/indexing_map.h @@ -197,6 +197,10 @@ H AbslHashValue(H h, const DimVar& dimension) { return H::combine(std::move(h), dimension.bounds); } +inline size_t hash_value(const DimVar& dim_var) { + return llvm::hash_combine(dim_var.bounds); +} + // RangeSymbol variable represents a range of values, e.g. to compute a single // element of the reduction's result we need a range of values from the input // tensor. RangeSymbol variables correspond to the front portion of the @@ -214,6 +218,10 @@ H AbslHashValue(H h, const RangeVar& range_var) { return H::combine(std::move(h), range_var.range); } +inline size_t hash_value(const RangeVar& range_var) { + return llvm::hash_combine(range_var.range); +} + // RTSymbol variable represents a runtime symbol, e.g. a dynamic offset in // HLO dynamic-update-slice op. RTSymbol variables correspond to the back // portion of the symbols in `affine_map_`. From 0d7aab4298530a472e571dc82d666d2dfc8318d7 Mon Sep 17 00:00:00 2001 From: Kanvi Khanna Date: Mon, 22 Jul 2024 03:59:37 -0700 Subject: [PATCH 051/376] PR #15153: [XLA:CPU][oneDNN] Revert PR 13527 causing accuracy issue Imported from GitHub PR https://github.com/openxla/xla/pull/15153 Revert "PR #13527: [XLA:CPU][oneDNN] Enable mm-bias-add fusion" This reverts commit 4ac9fdaae77256df4531eca38683d70777abf434 as we are seeing accuracy issues with some workloads. Copybara import of the project: -- 003d71c185e8b0e2f3f486f80839a2c20a7410dd by Kanvi Khanna : Revert "PR #13527: [XLA:CPU][oneDNN] Enable mm-bias-add fusion" This reverts commit 4ac9fdaae77256df4531eca38683d70777abf434. Merging this change closes #15153 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/15153 from Intel-tensorflow:kanvi/accuracy-fix 003d71c185e8b0e2f3f486f80839a2c20a7410dd PiperOrigin-RevId: 654686110 --- xla/service/cpu/onednn_matmul_rewriter.cc | 68 ++++---------- xla/tests/onednn_matmul_test.cc | 106 ---------------------- 2 files changed, 17 insertions(+), 157 deletions(-) diff --git a/xla/service/cpu/onednn_matmul_rewriter.cc b/xla/service/cpu/onednn_matmul_rewriter.cc index 779c6cf6a2fe12..45c6bc17a41b20 100644 --- a/xla/service/cpu/onednn_matmul_rewriter.cc +++ b/xla/service/cpu/onednn_matmul_rewriter.cc @@ -324,20 +324,6 @@ absl::StatusOr AdjustBiasShape(const HloInstruction* broadcast_instr, return new_shape; }; -// Compute new shape for the binary operand when dot's outer dims -// are flattened/unflattened with respect to the binary operand dims. -// Adjusting the operand shape to the dot's shape enables fusion in oneDNN. -absl::StatusOr AdjustBinaryOperandShape( - const HloInstruction* operand_instr, const Shape& dot_shape) { - if (ShapeUtil::ElementsIn(operand_instr->shape()) != - ShapeUtil::ElementsIn(dot_shape)) { - return absl::CancelledError( - "Number of elements in operand and dot instruction do not match."); - } - Shape new_shape = dot_shape; - return new_shape; -}; - inline bool IsOperandFusible(HloInstruction* operand, HloInstruction* dot) { // Check if the operand's shape is compatible with matmul for fusion. // An operand is fusable if @@ -367,19 +353,11 @@ inline auto OptionalConvertAndBitcast(HloInstruction** optional_convert, // 1. pattern-root -> bf16/f16-to-fp32 convert -> bitcast // 2. pattern-root -> bf16/f16-to-fp32 convert // 3. pattern-root -> bitcast - // 4. pattern-root -> bitcast -> bf16-to-fp32 convert - // 5. pattern-root + // 4. pattern-root auto common = m::AnyOf( - pu::SupportedConvert(optional_convert, - std::move(pattern).WithOneUser()) - .WithElementType(PrimitiveType::F32), - std::move(pattern).WithOneUser(), - pu::SupportedConvert( - optional_convert, - BitcastWithReshapeSemantics( - optional_bitcast, std::move(pattern).WithOneUser())) - .WithElementType(PrimitiveType::F32)) - .WithOneUser(); + pu::SupportedConvert(optional_convert, std::move(pattern).WithOneUser()) + .WithElementType(PrimitiveType::F32), + std::move(pattern).WithOneUser()); return m::AnyOf( BitcastWithReshapeSemantics(optional_bitcast, common), common); } @@ -521,6 +499,19 @@ class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor { if (Match(instr, pattern)) { if (!IsSupportedType(dot->shape().element_type())) return absl::OkStatus(); + // TODO(intel-tf): Remove the condition below when the fusion Dot + + // Add(bias) + Add(e.g., residual) is enabled. + if (!dot->backend_config() + ->mutable_onednn_matmul_config() + ->mutable_fusions() + ->ops() + .empty() && + dot->backend_config() + ->mutable_onednn_matmul_config() + ->mutable_fusions() + ->ops(0) == OneDnnFusionConfig::BIAS) { + return absl::OkStatus(); + } std::vector new_operands; for (auto operand : dot->operands()) { new_operands.push_back(operand); @@ -559,31 +550,6 @@ class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor { return absl::OkStatus(); } } - // For cases where the dot is followed by a reshape, the binary operands - // shape can be adjusted, making sure the number of elements match, to - // enable the fusion. For example: - // dot = f32[6304,3072] dot(...) - // reshape = f32[32,197,3072] reshape(dot) - // constant = f32[32,197,3072] constant(..) - // add = f32[32,197,3072] add(reshape, constant) - // can become - // dot = f32[6304,3072] dot(...) - // constant = f32[32,197,3072] constant(..) - // reshape1 = f32[6304,3072] reshape(constant) - // add = f32[6304,3072] add(dot, reshape1) - // and be replaced with the fusion - // fused = f32[6304,3072] custom-call(..) - // bitcast = f32[32,197,3072] bitcast(fused) - // clang-format on - auto addend_dims = addend->shape().dimensions(); - auto dot_dims = dot->shape().dimensions(); - if (optional_dot_bitcast && addend_dims.size() != dot_dims.size()) { - auto new_addend_shape = AdjustBinaryOperandShape(addend, dot->shape()); - if (new_addend_shape.ok()) { - addend = addend->AddInstruction( - HloInstruction::CreateBitcast(new_addend_shape.value(), addend)); - } - } // Validate addend for fusion. if (IsSupportedType(addend->shape().element_type()) && diff --git a/xla/tests/onednn_matmul_test.cc b/xla/tests/onednn_matmul_test.cc index 6cf2f77e38a18f..389716c4ddef95 100644 --- a/xla/tests/onednn_matmul_test.cc +++ b/xla/tests/onednn_matmul_test.cc @@ -131,15 +131,6 @@ class MatmulTest : public HloTestBase { ; CHECK-DAG: } ; CHECK: } )"; - const char* fused_matmul_bias_add_str_ = R"( - ; CHECK: custom_call_target="__onednn$matmul", - ; CHECK: backend_config={ - ; CHECK-DAG: "outer_dimension_partitions":[], - ; CHECK-DAG: "onednn_matmul_config":{ - ; CHECK-DAG: "fused_ops":["BIAS","BINARY_ADD"] - ; CHECK-DAG: } - ; CHECK: } - )"; }; TEST_F(MatmulTest, SimpleTestF32) { @@ -1548,103 +1539,6 @@ TEST_F(MatmulTest, ConsecutiveBinaryAdd) { EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); } -TEST_F(MatmulTest, SimpleTestF32WithBiasAndAddFusion) { - const char* matmul_module_str = R"( - HloModule matmul.bias.add.test.f32 - ENTRY matmul.bias.add.test.f32 { - arg0.1 = f32[32,32,40,30] parameter(0), parameter_replication={false} - arg0.2 = f32[32,32,30,40]parameter(1), parameter_replication={false} - dot.7 = f32[32,32,40,40] dot(arg0.1, arg0.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} - const.0 = f32[40] constant(15) - bcast.1 = f32[32,32,40,40] broadcast(const.0), dimensions={3} - add.0 = f32[32,32,40,40] add(dot.7,bcast.1) - const.1 = f32[32,32,40,40] constant(0.65) - add.1 = f32[32,32,40,40] add(add.0, const.1) - tuple.12 = (f32[32,32,40,40]) tuple(add.1) - ROOT get-tuple-element.13 = f32[32,32,40,40] get-tuple-element(tuple.12), index=0 - })"; - - EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); - MatchOptimizedHlo(matmul_module_str, fused_matmul_bias_add_str_); -} - -TEST_F(MatmulTest, SimpleTestF32WithBiasAndAddFusion2) { - const char* matmul_module_str = R"( - HloModule matmul.test.f32 - ENTRY matmul.test.f32 { - arg.0 = f32[6304,768] parameter(0), parameter_replication={false} - arg.1 = f32[768,3072] parameter(1), parameter_replication={false} - dot.378 = f32[6304,3072] dot(arg.0, arg.1), lhs_contracting_dims={1}, rhs_contracting_dims={0} - reshape.11 = f32[32,197,3072] reshape(dot.378) - constant.381 = f32[3072] constant(0.3) - broadcast.382 = f32[32,197,3072] broadcast(constant.381), dimensions={2} - add.0 = f32[32,197,3072] add(reshape.11, broadcast.382) - const.1 = f32[32,197,3072] constant(0.65) - add.1 = f32[32,197,3072] add(add.0, const.1) - ROOT out = f32[6304,3072] reshape(add.1) - })"; - - EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); - MatchOptimizedHlo(matmul_module_str, fused_matmul_bias_add_str_); -} - -TEST_F(MatmulTest, SimpleTestF32WithAddFusion) { - const char* matmul_module_str = R"( - HloModule matmul.test.f32 - ENTRY matmul.test.f32 { - arg.0 = f32[6304,768] parameter(0), parameter_replication={false} - arg.1 = f32[768,3072] parameter(1), parameter_replication={false} - dot.378 = f32[6304,3072] dot(arg.0, arg.1), lhs_contracting_dims={1}, rhs_contracting_dims={0} - reshape.11 = f32[32,197,3072] reshape(dot.378) - const.1 = f32[32,197,3072] constant(0.65) - add.1 = f32[32,197,3072] add(reshape.11, const.1) - ROOT out = f32[6304,3072] reshape(add.1) - })"; - - EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); - MatchOptimizedHlo(matmul_module_str, - R"( - ; CHECK: custom_call_target="__onednn$matmul", - ; CHECK: backend_config={ - ; CHECK-DAG: "outer_dimension_partitions":[], - ; CHECK-DAG: "onednn_matmul_config":{ - ; CHECK-DAG: "fused_ops":["BINARY_ADD"] - ; CHECK-DAG: } - ; CHECK: } - )"); -} - -TEST_F(MatmulTest, SimpleTestF32WithAddFusion_2) { - // Only the first Bias should get fused as Bias - const char* matmul_module_str = R"( - HloModule matmul.add.test.f32 - ENTRY matmul.add.test.f32 { - arg0.1 = f32[32,32,40,30] parameter(0), parameter_replication={false} - arg0.2 = f32[32,32,30,40]parameter(1), parameter_replication={false} - dot.7 = f32[32,32,40,40] dot(arg0.1, arg0.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} - const.0 = f32[40] constant(15) - bcast.1 = f32[32,32,40,40] broadcast(const.0), dimensions={3} - add.0 = f32[32,32,40,40] add(dot.7,bcast.1) - const.1 = f32[40] constant(0.65) - bcast.2 = f32[32,32,40,40] broadcast(const.1), dimensions={3} - add.1 = f32[32,32,40,40] add(add.0, bcast.2) - tuple.12 = (f32[32,32,40,40]) tuple(add.1) - ROOT get-tuple-element.13 = f32[32,32,40,40] get-tuple-element(tuple.12), index=0 - })"; - - EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); - MatchOptimizedHlo(matmul_module_str, - R"( - ; CHECK: custom_call_target="__onednn$matmul", - ; CHECK: backend_config={ - ; CHECK-DAG: "outer_dimension_partitions":[], - ; CHECK-DAG: "onednn_matmul_config":{ - ; CHECK-DAG: "fused_ops":["BIAS","BINARY_ADD"] - ; CHECK-DAG: } - ; CHECK: } - )"); -} - TEST_F(MatmulTest, BroadcastedAddAfterFusion) { const char* matmul_module_str = R"( HloModule matmul.nonscalar.test.1 From 45f3de680fb195d39b90050338056db39254ec47 Mon Sep 17 00:00:00 2001 From: Jaroslav Sevcik Date: Mon, 22 Jul 2024 04:09:55 -0700 Subject: [PATCH 052/376] PR #15029: Skip emitting Triton kernel when deserializing from cache Imported from GitHub PR https://github.com/openxla/xla/pull/15029 This change avoids running final part of the Triton kernel emission when deserializing form cache. This can make a 0.5-1s difference on larger Pallas kernels (we see ~600ms/2x improvement in deserialization time of a step/update function with Pallas attention kernel). Copybara import of the project: -- 700c1704ff124042185a5b3e8dba82b5eca6bc34 by Jaroslav Sevcik : Skip emitting Triton kernel when deserializing -- 69505da4cd81a35c390f6301f4a474eb2d0c0c67 by Jaroslav Sevcik : Address reviewer comments Merging this change closes #15029 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/15029 from jaro-sevcik:avoid-triton-compilation-on-deserialization 69505da4cd81a35c390f6301f4a474eb2d0c0c67 PiperOrigin-RevId: 654689073 --- xla/service/gpu/BUILD | 2 + xla/service/gpu/fusions/fusion_emitter.cc | 31 +++-- xla/service/gpu/fusions/fusion_emitter.h | 14 ++ .../fusions/triton/triton_fusion_emitter.cc | 47 ++++--- .../fusions/triton/triton_fusion_emitter.h | 5 +- xla/service/gpu/gpu_aot_compilation_test.cc | 126 ++++++++++++++++++ xla/service/gpu/ir_emitter_unnested.cc | 75 ++++++----- 7 files changed, 238 insertions(+), 62 deletions(-) diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index e25d816e3416e1..0d8dd38d751cbe 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -3627,6 +3627,7 @@ xla_cc_test( "//xla/service:executable", "//xla/service:gpu_plugin", "//xla/service:platform_util", + "//xla/service/gpu/fusions/triton:triton_support", "//xla/stream_executor", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", @@ -3634,6 +3635,7 @@ xla_cc_test( "//xla/tests:xla_internal_test_main", # build_cleaner: keep "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", + "@llvm-project//mlir:IR", "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/service/gpu/fusions/fusion_emitter.cc b/xla/service/gpu/fusions/fusion_emitter.cc index 348f2286842781..432d600701d1ab 100644 --- a/xla/service/gpu/fusions/fusion_emitter.cc +++ b/xla/service/gpu/fusions/fusion_emitter.cc @@ -193,6 +193,12 @@ IndexingMap KernelFusionInterface::GetDefaultThreadIdIndexingMap( return indexing_map; } +std::string GetSanitizedUniqueName(IrEmitterContext& ir_emitter_context, + const std::string& suggested_name) { + return ir_emitter_context.name_uniquer()->GetUniqueName( + llvm_ir::SanitizeFunctionName(suggested_name)); +} + absl::StatusOr, std::vector>> BuildKernelPrototype(IrEmitterContext& ir_emitter_context, @@ -201,6 +207,20 @@ BuildKernelPrototype(IrEmitterContext& ir_emitter_context, size_t num_inputs, const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* builder) { + return BuildKernelPrototypeFromUniqueName( + ir_emitter_context, + GetSanitizedUniqueName(ir_emitter_context, suggested_name), arguments, + num_inputs, launch_dimensions, builder); +} + +absl::StatusOr, + std::vector>> +BuildKernelPrototypeFromUniqueName(IrEmitterContext& ir_emitter_context, + const std::string& unique_kernel_name, + absl::Span arguments, + size_t num_inputs, + const LaunchDimensions& launch_dimensions, + llvm::IRBuilder<>* builder) { // If some arguments have the same buffer, we will pass them only once. llvm::SmallVector to_llvm_arg_no(arguments.size()); llvm::SmallVector to_arg_no; @@ -217,11 +237,6 @@ BuildKernelPrototype(IrEmitterContext& ir_emitter_context, } const int kNumLlvmArgs = to_arg_no.size(); - // Compute the kernel name. The opcode string may contain "-" which cannot be - // in a PTX function name, so sanitize the name before uniquifying it. - std::string kernel_name = ir_emitter_context.name_uniquer()->GetUniqueName( - llvm_ir::SanitizeFunctionName(suggested_name)); - // Create the kernel and add it to the module. auto* llvm_module = ir_emitter_context.llvm_module(); llvm::LLVMContext& context = llvm_module->getContext(); @@ -233,12 +248,12 @@ BuildKernelPrototype(IrEmitterContext& ir_emitter_context, /*isVarArg=*/false); llvm::Function* kernel = llvm::Function::Create(kernel_type, llvm::GlobalValue::ExternalLinkage, - kernel_name, llvm_module); + unique_kernel_name, llvm_module); AnnotateFunctionAsGpuKernel(llvm_module, kernel, builder); TF_RETURN_IF_ERROR(AnnotateKernelLaunchDimensions( - ir_emitter_context.gpu_device_info(), launch_dimensions, kernel_name, - llvm_module)); + ir_emitter_context.gpu_device_info(), launch_dimensions, + unique_kernel_name, llvm_module)); // TODO(b/65380986): Investigate if adding fast math flags for generated // kernels makes sense. diff --git a/xla/service/gpu/fusions/fusion_emitter.h b/xla/service/gpu/fusions/fusion_emitter.h index 13dbd034e33314..f929db1860e970 100644 --- a/xla/service/gpu/fusions/fusion_emitter.h +++ b/xla/service/gpu/fusions/fusion_emitter.h @@ -134,6 +134,20 @@ BuildKernelPrototype(IrEmitterContext& ir_emitter_context, size_t num_inputs, const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* builder); +absl::StatusOr< + std::tuple, + std::vector /*outputs*/>> +BuildKernelPrototypeFromUniqueName(IrEmitterContext& ir_emitter_context, + const std::string& unique_name, + absl::Span arguments, + size_t num_inputs, + const LaunchDimensions& launch_dimensions, + llvm::IRBuilder<>* builder); + +// Compute the kernel name. The opcode string may contain "-" which cannot be +// in a PTX function name, so sanitize the name before uniquifying it. +std::string GetSanitizedUniqueName(IrEmitterContext& ir_emitter_context, + const std::string& suggested_name); absl::Status AnnotateKernelLaunchDimensions( const se::DeviceDescription& device_info, diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc index 13bf68ed1789df..430fbca28a4389 100644 --- a/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc +++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc @@ -2840,7 +2840,7 @@ absl::StatusOr CompileTritonToLLVM( const se::DeviceDescription& device_info, const BlockLevelParameters& block_level_parameters, mlir::ModuleOp triton_module, llvm::Module* llvm_module, - mlir::MLIRContext& mlir_context) { + mlir::MLIRContext& mlir_context, bool emit_kernel) { if (std::holds_alternative(cc)) { auto ccCuda = std::get(cc); if (!ccCuda.IsAtLeastAmpere()) { @@ -2935,27 +2935,30 @@ absl::StatusOr CompileTritonToLLVM( shared_mem_bytes, device_info.shared_memory_per_block_optin())); } - TF_ASSIGN_OR_RETURN( - std::unique_ptr ll_triton_module, - TranslateLLVMToLLVMIR(&llvm_module->getContext(), triton_module, - GetLibdevicePath(hlo_config, device_info))); - VLogModule(5, *ll_triton_module); - if (should_verify) { - VerifyModule(*ll_triton_module); - } - - // Integrate LLVM matmul kernel into XLA's LLVM module. - ll_triton_module->eraseNamedMDNode( - ll_triton_module->getNamedMetadata("nvvm.annotations")); - ll_triton_module->setDataLayout(llvm_module->getDataLayout()); - ll_triton_module->setTargetTriple(llvm_module->getTargetTriple()); - // Use override flag because libdevice functions can be present in both. - TF_RET_CHECK( - !llvm::Linker::linkModules(*llvm_module, std::move(ll_triton_module), - llvm::Linker::Flags::OverrideFromSrc)); - VLogModule(5, *llvm_module); - if (should_verify) { - VerifyModule(*llvm_module); + if (emit_kernel) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr ll_triton_module, + TranslateLLVMToLLVMIR(&llvm_module->getContext(), triton_module, + GetLibdevicePath(hlo_config, device_info))); + VLogModule(5, *ll_triton_module); + if (should_verify) { + VerifyModule(*ll_triton_module); + } + + // Integrate LLVM matmul kernel into XLA's LLVM module. + ll_triton_module->eraseNamedMDNode( + ll_triton_module->getNamedMetadata("nvvm.annotations")); + ll_triton_module->setDataLayout(llvm_module->getDataLayout()); + ll_triton_module->setTargetTriple(llvm_module->getTargetTriple()); + // Use override flag because libdevice functions can be present in both. + TF_RET_CHECK( + !llvm::Linker::linkModules(*llvm_module, std::move(ll_triton_module), + llvm::Linker::Flags::OverrideFromSrc)); + + VLogModule(5, *llvm_module); + if (should_verify) { + VerifyModule(*llvm_module); + } } // `cluster_info` must be read after pm.run(). diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter.h b/xla/service/gpu/fusions/triton/triton_fusion_emitter.h index d28101c66d9dc4..eb9644808fc6b6 100644 --- a/xla/service/gpu/fusions/triton/triton_fusion_emitter.h +++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter.h @@ -121,13 +121,16 @@ absl::StatusOr> CreateTritonModule( mlir::MLIRContext& mlir_context); // Compiles a given Triton module to LLVM IR. +// If `emit_kernels` is false, then the function skips emitting +// the kernels, but it still returns correctly filled TritonWrapperResult. +// That is useful when deserializing from the compilation cache. absl::StatusOr CompileTritonToLLVM( const HloModuleConfig& hlo_config, absl::string_view hlo_module_name, const se::GpuComputeCapability& cc, const se::DeviceDescription& device_info, const BlockLevelParameters& block_level_parameters, mlir::ModuleOp triton_module, llvm::Module* llvm_module, - mlir::MLIRContext& mlir_context); + mlir::MLIRContext& mlir_context, bool emit_kernel = true); // Create Triton pipeline. // diff --git a/xla/service/gpu/gpu_aot_compilation_test.cc b/xla/service/gpu/gpu_aot_compilation_test.cc index aad47e75728e27..945f63a1f87c0d 100644 --- a/xla/service/gpu/gpu_aot_compilation_test.cc +++ b/xla/service/gpu/gpu_aot_compilation_test.cc @@ -21,10 +21,12 @@ limitations under the License. #include #include "absl/strings/ascii.h" #include "absl/strings/string_view.h" +#include "mlir/IR/Builders.h" // from @llvm-project #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_module_group.h" #include "xla/service/compiler.h" #include "xla/service/executable.h" +#include "xla/service/gpu/fusions/triton/triton_support.h" #include "xla/service/platform_util.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" @@ -121,5 +123,129 @@ ENTRY main { aot_result->LoadExecutable(compiler, stream_exec)); } +namespace { + +using ::mlir::ArrayRef; +using ::mlir::NamedAttribute; + +std::string CreateTritonCustomCallBackendConfig() { + mlir::MLIRContext context_; + mlir::Builder builder(&context_); + + // Create the backend_config for the triton custom call. + const std::string kMLIRText = R"( + module { + tt.func public @add_one(%arg0: !tt.ptr {tt.divisibility = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 32 : i32}, %arg2: !tt.ptr {tt.divisibility = 32 : i32}, %arg3: !tt.ptr {tt.divisibility = 32 : i32}) { + %0 = tt.get_program_id x : i32 + %1 = tt.load %arg0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr + %2 = tt.load %arg1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr + %cst = arith.constant 1.000000e+00 : f32 + %3 = arith.addf %1, %cst : f32 + tt.store %arg2, %3 {cache = 1 : i32, evict = 1 : i32} : !tt.ptr + tt.store %arg3, %2 {cache = 1 : i32, evict = 1 : i32} : !tt.ptr + tt.return + } + } + )"; + + NamedAttribute name = + builder.getNamedAttr("name", builder.getStringAttr("add_one")); + NamedAttribute ir = + builder.getNamedAttr("ir", builder.getStringAttr(kMLIRText)); + NamedAttribute num_stages = + builder.getNamedAttr("num_stages", builder.getI32IntegerAttr(3)); + NamedAttribute num_warps = + builder.getNamedAttr("num_warps", builder.getI32IntegerAttr(4)); + NamedAttribute grid_x = + builder.getNamedAttr("grid_x", builder.getI32IntegerAttr(1)); + NamedAttribute grid_y = + builder.getNamedAttr("grid_y", builder.getI32IntegerAttr(1)); + NamedAttribute grid_z = + builder.getNamedAttr("grid_z", builder.getI32IntegerAttr(1)); + NamedAttribute debug = + builder.getNamedAttr("debug", builder.getBoolAttr(false)); + + std::vector attributes = { + name, ir, num_stages, num_warps, grid_x, grid_y, grid_z, debug}; + ArrayRef attributesRef(attributes); + mlir::DictionaryAttr backend_config = + mlir::DictionaryAttr::get(&context_, attributesRef); + + // Parse the backend_config into a string. + std::string backend_config_str; + llvm::raw_string_ostream(backend_config_str) << backend_config; + + return backend_config_str; +} + +} // namespace + +TEST_F(GpuAotCompilationTest, ExportAndLoadExecutableWithTriton) { + auto triton_support = + EnsureTritonSupportsComputeCapability(backend() + .default_stream_executor() + ->GetDeviceDescription() + .gpu_compute_capability()); + if (!triton_support.ok()) { + GTEST_SKIP() << triton_support; + } + + const absl::string_view hlo_string_template = R"( + HloModule Test + + ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT c = (f32[],f32[]) custom-call(a, b), custom_call_target="__gpu$xla.gpu.triton", backend_config="%s" + } + )"; + + std::string hlo_string = + absl::StrFormat(hlo_string_template, + absl::CEscape(CreateTritonCustomCallBackendConfig())); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + auto compiler = backend().compiler(); + auto platform_name = + absl::AsciiStrToUpper(PlatformUtil::CanonicalPlatformName("gpu").value()); + TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform, + se::PlatformManager::PlatformWithName(platform_name)); + TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * stream_exec, + platform->ExecutorForDevice(0)); + + // Compile AOT. + auto module_group = std::make_unique(std::move(module)); + AotCompilationOptions aot_options(compiler->PlatformId()); + aot_options.set_executor(stream_exec); + + TF_ASSERT_OK_AND_ASSIGN( + std::vector> aot_results, + compiler->CompileAheadOfTime(std::move(module_group), aot_options)); + + // Serialize-deserialize AOT compilation result. + TF_ASSERT_OK_AND_ASSIGN(std::string serialized_aot_result, + aot_results[0]->SerializeAsString()); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr aot_result, + compiler->LoadAotCompilationResult(serialized_aot_result)); + + // Load Executable from AOT compilation result. + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr executable, + aot_result->LoadExecutable(compiler, stream_exec)); + + const xla::Literal literal_1 = xla::LiteralUtil::CreateR0(1.0f); + const xla::Literal literal_2 = xla::LiteralUtil::CreateR0(2.0f); + const xla::Literal literal_3 = xla::LiteralUtil::CreateR0(3.0f); + + TF_ASSERT_OK_AND_ASSIGN(Literal result, + GetHloRunner().value()->ExecuteWithExecutable( + executable.get(), {&literal_1, &literal_3})); + + EXPECT_TRUE(LiteralTestUtil::Equal( + LiteralUtil::MakeTuple({&literal_2, &literal_3}), result)); +} + } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/ir_emitter_unnested.cc b/xla/service/gpu/ir_emitter_unnested.cc index c094e69a601b67..c102feaf76cdf4 100644 --- a/xla/service/gpu/ir_emitter_unnested.cc +++ b/xla/service/gpu/ir_emitter_unnested.cc @@ -1603,8 +1603,15 @@ absl::Status IrEmitterUnnested::EmitTritonCustomCall( auto triton_fn = triton_module->lookupSymbol(call.name); triton_fn.setName(kernel_name); + size_t arg_size = triton_fn.getNumArguments(); HloModule* hlo_module = instr->GetModule(); + // If emit_kernels if false (i.e., when deserializing an already compiled + // executable), we do not emit code, but we still need to run part of the + // compiler to figure out the size of the shared memory and the cluster + // dimensions for the thunk. We also must call the name uniqifier as if + // emitting code so that the future generated names remain in sync. + bool emit_kernels = ir_emitter_context_->emit_kernels(); BlockLevelParameters block_level_parameters; block_level_parameters.num_stages = call.num_stages; @@ -1617,13 +1624,8 @@ absl::Status IrEmitterUnnested::EmitTritonCustomCall( ir_emitter_context_->gpu_compute_capability(), ir_emitter_context_->gpu_device_info(), block_level_parameters, triton_module.get(), - ir_emitter_context_->llvm_module(), mlir_context)); - - llvm::Function* impl_fn = - ir_emitter_context_->llvm_module()->getFunction(kernel_name); - TF_RET_CHECK(impl_fn); - impl_fn->setName(ir_emitter_context_->name_uniquer()->GetUniqueName( - kernel_name + "_impl")); + ir_emitter_context_->llvm_module(), mlir_context, + emit_kernels)); TF_ASSIGN_OR_RETURN( auto kernel_arguments, @@ -1634,33 +1636,44 @@ absl::Status IrEmitterUnnested::EmitTritonCustomCall( LaunchDimensions(se::BlockDim(call.grid_x, call.grid_y, call.grid_z), se::ThreadDim(call.num_warps * 32)); - llvm::IRBuilder builder(ir_emitter_context_->llvm_module()->getContext()); - - llvm::Function* kernel; - std::vector inputs; - std::vector outputs; - TF_ASSIGN_OR_RETURN( - std::tie(kernel, inputs, outputs), - BuildKernelPrototype(*ir_emitter_context_, kernel_name, - kernel_arguments.args(), impl_fn->arg_size(), - launch_dimensions, &builder)); - - // Move function body into kernel prototype. - llvm::Function* prototype_func = builder.GetInsertBlock()->getParent(); - prototype_func->splice(prototype_func->begin(), impl_fn); - for (const auto& [arg, input] : llvm::zip(impl_fn->args(), inputs)) { - arg.replaceAllUsesWith(input.GetBasePointer()); - } - impl_fn->eraseFromParent(); + std::string sanitized_kernel_name = + GetSanitizedUniqueName(*ir_emitter_context_, kernel_name); + + if (emit_kernels) { + llvm::Function* impl_fn = + ir_emitter_context_->llvm_module()->getFunction(kernel_name); + TF_RET_CHECK(impl_fn); + impl_fn->setName(ir_emitter_context_->name_uniquer()->GetUniqueName( + kernel_name + "_impl")); + + llvm::IRBuilder builder(ir_emitter_context_->llvm_module()->getContext()); + + llvm::Function* kernel; + std::vector inputs; + std::vector outputs; + TF_ASSIGN_OR_RETURN( + std::tie(kernel, inputs, outputs), + BuildKernelPrototypeFromUniqueName( + *ir_emitter_context_, sanitized_kernel_name, + kernel_arguments.args(), arg_size, launch_dimensions, &builder)); + + // Move function body into kernel prototype. + llvm::Function* prototype_func = builder.GetInsertBlock()->getParent(); + prototype_func->splice(prototype_func->begin(), impl_fn); + for (const auto& [arg, input] : llvm::zip(impl_fn->args(), inputs)) { + arg.replaceAllUsesWith(input.GetBasePointer()); + } + impl_fn->eraseFromParent(); - for (auto& arg : prototype_func->args()) { - // Remove the alignment and aliasing attributes to avoid recompiling the - // kernel for each alignment/aliasing combination. - arg.removeAttr(llvm::Attribute::Alignment); - arg.removeAttr(llvm::Attribute::NoAlias); + for (auto& arg : prototype_func->args()) { + // Remove the alignment and aliasing attributes to avoid recompiling the + // kernel for each alignment/aliasing combination. + arg.removeAttr(llvm::Attribute::Alignment); + arg.removeAttr(llvm::Attribute::NoAlias); + } } - return {{kernel->getName().str(), launch_dimensions, result.cluster_dim, + return {{sanitized_kernel_name, launch_dimensions, result.cluster_dim, result.shmem_bytes}}; }; From de05f29c2736878296f15c7f0dab84b6ffd4a613 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 22 Jul 2024 04:16:00 -0700 Subject: [PATCH 053/376] Automated Code Change PiperOrigin-RevId: 654690361 --- xla/service/gpu/launch_dimensions.cc | 10 ---------- xla/service/gpu/launch_dimensions.h | 18 ------------------ 2 files changed, 28 deletions(-) diff --git a/xla/service/gpu/launch_dimensions.cc b/xla/service/gpu/launch_dimensions.cc index b31b0e532c1999..89b322f6708556 100644 --- a/xla/service/gpu/launch_dimensions.cc +++ b/xla/service/gpu/launch_dimensions.cc @@ -32,16 +32,6 @@ limitations under the License. namespace xla { namespace gpu { -std::ostream& operator<<(std::ostream& out, - const LaunchDimensions& launch_dims) { - se::BlockDim block_counts = launch_dims.block_counts(); - se::ThreadDim thread_counts = launch_dims.thread_counts_per_block(); - out << absl::StrFormat("[block: {%d, %d, %d}, thread: {%d, %d, %d}]", - block_counts.x, block_counts.y, block_counts.z, - thread_counts.x, thread_counts.y, thread_counts.z); - return out; -} - static int64_t ThreadsPerBlockLimit( const se::DeviceDescription& gpu_device_info) { int64_t threads_per_block = gpu_device_info.threads_per_block_limit(); diff --git a/xla/service/gpu/launch_dimensions.h b/xla/service/gpu/launch_dimensions.h index 0d38013657ef5e..e0c53f9b266f4c 100644 --- a/xla/service/gpu/launch_dimensions.h +++ b/xla/service/gpu/launch_dimensions.h @@ -76,23 +76,11 @@ class LaunchDimensions { thread_counts_per_block_.z, "}"); } - bool operator==(const LaunchDimensions& other) const { - return block_counts_ == other.block_counts_ && - thread_counts_per_block_ == other.thread_counts_per_block_; - } - - bool operator!=(const LaunchDimensions& other) const { - return !(*this == other); - } - private: se::BlockDim block_counts_; se::ThreadDim thread_counts_per_block_; }; -std::ostream& operator<<(std::ostream& out, - const LaunchDimensions& launch_dims); - struct LaunchDimensionsConfig { // The kernel implementation will be unrolled if `unroll_factor` is // greater than one. @@ -108,12 +96,6 @@ struct LaunchDimensionsConfig { // `hlo.shape().dimensions().back()/unroll_factor`. // Currently few_waves and row_vectorized do not work together. bool row_vectorized = false; - - std::string ToString() { - return absl::StrCat("unroll_factor=", unroll_factor, - ", few_waves=", few_waves, - ", row_vectorized=", row_vectorized); - } }; // Returns -1 if the shape doesn't allow the row vectorization code path. From f169779a8c80625e16029d1456953f46f52c547a Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Mon, 22 Jul 2024 04:56:02 -0700 Subject: [PATCH 054/376] Fix vectorization of tiny multi-row reductions. For these we can attempt to use a vectorization factor greater than the row length, which is not something we currently support in codegen. PiperOrigin-RevId: 654698797 --- xla/service/gpu/fusions/reduction_mlir.cc | 7 +++++- .../gpu/fusions/reduction_mlir_test.cc | 23 +++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/xla/service/gpu/fusions/reduction_mlir.cc b/xla/service/gpu/fusions/reduction_mlir.cc index 8250c408edf619..d58f809b1c3c41 100644 --- a/xla/service/gpu/fusions/reduction_mlir.cc +++ b/xla/service/gpu/fusions/reduction_mlir.cc @@ -742,7 +742,12 @@ MlirMultiRowReductionFusion::MlirMultiRowReductionFusion( // This vector size is always valid: we know that the reduced dimension is a // power of 2, since otherwise RowReductionGetRowsPerWarp would have // returned 1. - int vector_size = 32 / smallest_input_or_output_bits; + // Our codegen can't currently deal with vectorization across rows, so we + // limit the vector size to the size of the row. Note that this emitter + // essentially reverts to the loop emitter in this case, except for side + // outputs. + int vector_size = std::min(static_cast(input_shape_[kRowMinorReduced]), + 32 / smallest_input_or_output_bits); // We target 8 warps per block, which means there could be up to 8 blocks per // SM, but we have no good way of knowing. In practice, enabling vectorization diff --git a/xla/service/gpu/fusions/reduction_mlir_test.cc b/xla/service/gpu/fusions/reduction_mlir_test.cc index 8fc0ba6af05bb7..214e9b582cb123 100644 --- a/xla/service/gpu/fusions/reduction_mlir_test.cc +++ b/xla/service/gpu/fusions/reduction_mlir_test.cc @@ -161,6 +161,24 @@ constexpr auto kMultiRowReductionX2VectorX4 = R"( ROOT fusion = (pred[76800]{0}, pred[76800]{0}) fusion(p0, p1), kind=kInput, calls=fusion })"; +constexpr auto kMultiRowReductionX16VectorX2 = R"( + or { + tmp_0 = pred[] parameter(0) + tmp_1 = pred[] parameter(1) + ROOT tmp_2 = pred[] or(tmp_0, tmp_1) + } + + fusion { + p0 = pred[76800,2] parameter(0) + c0 = pred[] constant(false) + ROOT reduce = pred[76800] reduce(p0, c0), dimensions={1}, to_apply=or + } + + ENTRY main { + p0 = pred[76800,2] parameter(0) + ROOT fusion = pred[76800] fusion(p0), kind=kInput, calls=fusion + })"; + constexpr std::string_view kRowReductionSideOutput = R"( Add { lhs = f32[] parameter(0) @@ -855,6 +873,11 @@ TEST_F(MlirMultiRowReductionTest, VectorizedX4Indexing) { ElementsAre(1 /* major reduced */, 4 /* vector size */)); } +TEST_F(MlirMultiRowReductionTest, LimitedVectorizationCorrectness) { + EXPECT_TRUE( + RunAndCompareNoHloPasses(kMultiRowReductionX16VectorX2, ErrorSpec{1e-3})); +} + TEST_F(MlirMultiRowReductionTest, VectorizedX4Correctness) { EXPECT_TRUE( RunAndCompareNoHloPasses(kMultiRowReductionX2VectorX4, ErrorSpec{1e-3})); From 7a3c9dfde5516f2b1095d818095e83267498d88c Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Mon, 22 Jul 2024 04:57:48 -0700 Subject: [PATCH 055/376] Don't skip materialization of indices for some selects. If the select is not really elementwise, we just materialize the indices. This is very rare, so keeping the code reasonably simple is more important than saving all possible materializations. PiperOrigin-RevId: 654699120 --- .../gpu/fusions/mlir/elemental_hlo_to_mlir.cc | 15 +++++++++++++-- .../mlir/elemental_hlo_to_mlir_test.cc | 19 +++++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc b/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc index fa3308b28714e3..aef5d7139bdd89 100644 --- a/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc +++ b/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc @@ -788,8 +788,19 @@ absl::StatusOr> GetOperands( const HloInstruction* instr, ValueRange indices, const OperandProvider& operand_provider, ImplicitLocOpBuilder& builder) { SmallVector operands; - if (HloInstruction::IsOpElementwise(instr->opcode()) || - instr->opcode() == HloOpcode::kMap) { + bool is_elementwise = HloInstruction::IsOpElementwise(instr->opcode()) || + instr->opcode() == HloOpcode::kMap; + if (is_elementwise && instr->shape().IsArray()) { + // Check if the instruction is really elementwise. There may be some + // broadcasting. + int64_t rank = instr->shape().rank(); + is_elementwise &= + absl::c_all_of(instr->operands(), [&](const HloInstruction* operand) { + return operand->shape().rank() == rank; + }); + } + + if (is_elementwise) { // Avoid materializing the input indices for elementwise ops. for (int64_t operand_number = 0; operand_number < instr->operand_count(); ++operand_number) { diff --git a/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc b/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc index de0413881c7d79..3b03393788538a 100644 --- a/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc +++ b/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc @@ -1645,6 +1645,25 @@ TEST_F(ElementalHloToMlirTest, Map) { )")); } +TEST_F(ElementalHloToMlirTest, BroadcastSelect) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + p0 = pred[] parameter(0) + p1 = f32[5,7] parameter(1) + p2 = f32[5,7] parameter(2) + ROOT r = f32[5,7] select(p0, p1, p2) + })", + R"( + // CHECK: @main + // CHECK-SAME: %[[P0:.*]]: tensor + // CHECK-SAME: %[[P1:.*]]: tensor<5x7xf32>, %[[P2:.*]]: tensor<5x7xf32> + // CHECK-SAME: %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index {{{.*}}} + // CHECK-DAG: tensor.extract %[[P0]][] + // CHECK-DAG: tensor.extract %[[P1]][%[[X]], %[[Y]]] + // CHECK-DAG: tensor.extract %[[P2]][%[[X]], %[[Y]]] + )")); +} + } // namespace } // namespace mlir_converter } // namespace gpu From fc86ad4d4c488f73075b385714cf77e1c26bfb43 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Mon, 22 Jul 2024 05:39:00 -0700 Subject: [PATCH 056/376] PR #15176: [GPU] Fix use of DfsHloRewriteVisitor by cuDNN fusion compiler. Imported from GitHub PR https://github.com/openxla/xla/pull/15176 DfsHloVisitor visits replaced instructions again. This used to happen here on addition of workspace, HandleFusion() was called again. SetVisited() prevents that. Copybara import of the project: -- 25b0a8bd982bd47bb87c9baab79382f09127f2fa by Ilia Sergachev : [GPU] Fix use of DfsHloRewriteVisitor by cuDNN fusion compiler. DfsHloVisitor visits replaced instructions again. This used to happen here on addition of workspace, HandleFusion() was called again. SetVisited() prevents that. Merging this change closes #15176 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/15176 from openxla:fix_cudnn_compiler 25b0a8bd982bd47bb87c9baab79382f09127f2fa PiperOrigin-RevId: 654708499 --- xla/service/gpu/cudnn_fusion_compiler.cc | 14 +++++---- xla/service/gpu/fusions/cudnn_test.cc | 37 ++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 5 deletions(-) diff --git a/xla/service/gpu/cudnn_fusion_compiler.cc b/xla/service/gpu/cudnn_fusion_compiler.cc index 80a3079d640493..01fd722438f333 100644 --- a/xla/service/gpu/cudnn_fusion_compiler.cc +++ b/xla/service/gpu/cudnn_fusion_compiler.cc @@ -609,9 +609,6 @@ absl::StatusOr PrepareGraph( absl::StatusOr AddWorkspace(HloInstruction& fusion, const int64_t workspace_size) { - if (workspace_size == 0 || fusion.shape().IsTuple()) { - return &fusion; - } HloComputation* computation = fusion.fused_instructions_computation(); HloInstruction* custom_call = computation->AddInstruction(HloInstruction::CreateCustomCall( @@ -650,6 +647,13 @@ class CuDnnFusionVisitor : public DfsHloRewriteVisitor { VLOG(4) << "Processing " << hlo->ToString(); VLOG(4) << "Plan ID: " << plan_id; + auto add_workspace = [&](const int64_t workspace_size) { + if (workspace_size > 0) { + TF_ASSIGN_OR_RETURN(hlo, AddWorkspace(*hlo, workspace_size)); + SetVisited(*hlo); + } + return absl::OkStatus(); + }; const std::string fingerprint_without_workspace = GetComputationFingerprint(hlo->fused_instructions_computation(), {}); auto workspace_size_it = @@ -683,7 +687,7 @@ class CuDnnFusionVisitor : public DfsHloRewriteVisitor { const int64_t workspace_size = graph.Graph().get_workspace_size(); workspace_sizes_.insert(workspace_size_it, {fingerprint_without_workspace, workspace_size}); - TF_ASSIGN_OR_RETURN(hlo, AddWorkspace(*hlo, workspace_size)); + TF_RETURN_IF_ERROR(add_workspace(workspace_size)); std::vector serialized_graph; RETURN_IF_CUDNN_FRONTEND_ERROR(graph.Graph().serialize(serialized_graph)); @@ -695,7 +699,7 @@ class CuDnnFusionVisitor : public DfsHloRewriteVisitor { serialized_graph.size()); } else { VLOG(4) << "Cache hit."; - TF_ASSIGN_OR_RETURN(hlo, AddWorkspace(*hlo, workspace_size_it->second)); + TF_RETURN_IF_ERROR(add_workspace(workspace_size_it->second)); } auto cudnn_config = gpu_config.mutable_fusion_backend_config() ->mutable_cudnn_fusion_config(); diff --git a/xla/service/gpu/fusions/cudnn_test.cc b/xla/service/gpu/fusions/cudnn_test.cc index 1682d57df52032..a8d5471784c178 100644 --- a/xla/service/gpu/fusions/cudnn_test.cc +++ b/xla/service/gpu/fusions/cudnn_test.cc @@ -201,6 +201,43 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } +TEST_F(CuDnnFusionExecutionTest, + CuDnnFusionCompilerDoesNotFailOnDependentFusions) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +c1 { + p0 = f32[32,96] parameter(0) + p1 = f32[96,64] parameter(1) + ROOT r = f32[32,64] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +c2 { + p0 = f32[32,96] parameter(0) + p1 = f32[32,64] parameter(1) + ROOT r = f32[96,64] dot(p0, p1), + lhs_contracting_dims={0}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = f32[32,96] parameter(0) + p1 = f32[96,64] parameter(1) + f0 = f32[32,64] fusion(p0, p1), kind=kCustom, calls=c1, + backend_config={"fusion_backend_config": {kind: "__cudnn$fusion","cudnn_fusion_config":{"plan_id":"0"}}} + f1 = f32[96,64] fusion(p0, f0), kind=kCustom, calls=c2, + backend_config={"fusion_backend_config": {kind: "__cudnn$fusion","cudnn_fusion_config":{"plan_id":"0"}}} + ROOT r = tuple(f0, f1) +})")); + BinaryMap dnn_compiled_graphs; + CuDnnFusionCompiler cudnn_compiler(*backend().default_stream_executor(), + dnn_compiled_graphs); + TF_ASSERT_OK_AND_ASSIGN(bool changed, cudnn_compiler.Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Tuple(m::GetTupleElement(m::Fusion()), + m::GetTupleElement(m::Fusion())))); +} + TEST_F(CuDnnFusionExecutionTest, NoTritonConfigIsAssignedAtZeroAutotuningLevel) { EXPECT_EQ(GetDebugOptionsForTest().xla_gpu_autotune_level(), 0); From 9a29e841539e4be651685ef8bcc987acd790f125 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Mon, 22 Jul 2024 05:57:28 -0700 Subject: [PATCH 057/376] Prepare tests for MLIR reduction emitter launch. - hlo tests just hard code the emitter level to 0. We can adjust these when we remove the flag (or remove them / move them to cc tests) - tests that depend on MOF were adjusted to the new IR, except for one that tests something that does not occur in real pipelines. I'm not really sure what that test is trying to test - probably side outputs, which are covered in the unit test. - one test that verified a failure to vectorize was disabled PiperOrigin-RevId: 654712419 --- .../gpu/horizontal_input_fusion_test.cc | 17 +++-- xla/service/gpu/tests/BUILD | 1 + .../gpu/tests/parallel_reduction_test.cc | 68 +++++++++++++------ .../tests/reduction_vectorization_sm_all.hlo | 8 +-- .../gpu/tests/reduction_vectorization_test.cc | 3 + xla/tools/hlo_opt/gpu_hlo_llvm.hlo | 2 +- 6 files changed, 68 insertions(+), 31 deletions(-) diff --git a/xla/service/gpu/horizontal_input_fusion_test.cc b/xla/service/gpu/horizontal_input_fusion_test.cc index 2839e2ff40c0e7..2d458f9db452d1 100644 --- a/xla/service/gpu/horizontal_input_fusion_test.cc +++ b/xla/service/gpu/horizontal_input_fusion_test.cc @@ -144,11 +144,18 @@ TEST_F(HorizontalInputFusionTest, ManyInputFusions) { module->AddEntryComputation(builder.Build()); // Verify that horizontal fusion is kicked in. Check that there are multiple - // `reduce` instructions fused into the same fusion. 6 is just a randomly - // picked number as we don't exactly know how large the fusion will be - // created due to the `FusionFitsInBudget` constraint. - CompileAndVerifyIr(module->Clone(), R"(CHECK: reduce-group-6)", - /*match_optimized_ir=*/false); + // `reduce` instructions fused into the same fusion. + if (GetDebugOptionsForTest().xla_gpu_mlir_emitter_level() < 4) { + // 6 is just a randomly picked number as we don't exactly know how large the + // fusion will be created due to the `FusionFitsInBudget` constraint. + CompileAndVerifyIr(module->Clone(), R"(CHECK: reduce-group-6)", + /*match_optimized_ir=*/false); + } else { + // Verify that we produced a multi-output reduction with independent groups. + CompileAndVerifyIr(module->Clone(), R"(CHECK: switch {{.*}} label {{.*}} [ + CHECK-NEXT: label)", + /*match_optimized_ir=*/false); + } // Testing with the entire gpu optimization pipeline. EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1e-5, 1e-5})); diff --git a/xla/service/gpu/tests/BUILD b/xla/service/gpu/tests/BUILD index df264c618d888b..83daaf0eb540a6 100644 --- a/xla/service/gpu/tests/BUILD +++ b/xla/service/gpu/tests/BUILD @@ -403,6 +403,7 @@ xla_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", + "@com_google_googletest//:gtest_main", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test_main", ], diff --git a/xla/service/gpu/tests/parallel_reduction_test.cc b/xla/service/gpu/tests/parallel_reduction_test.cc index 30e37759c0519f..2e0a975bbfca9d 100644 --- a/xla/service/gpu/tests/parallel_reduction_test.cc +++ b/xla/service/gpu/tests/parallel_reduction_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include "xla/error_spec.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -36,6 +37,7 @@ namespace gpu { namespace { class ParallelReductionTest : public GpuCodegenTest { + protected: DebugOptions GetDebugOptionsForTest() override { DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); // The test contains a MOF fusion and the XLA optimizer passes @@ -74,13 +76,20 @@ ENTRY %cluster { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, ParseAndReturnVerifiedModule(hlo_text)); - CompileAndVerifyIr(std::move(hlo_module), - R"( -CHECK: reduce-group-0 -CHECK: reduce-group-1 -CHECK-NOT: reduce-group-2 -)", - /*match_optimized_ir=*/false); + + if (GetDebugOptionsForTest().xla_gpu_mlir_emitter_level() >= 4) { + CompileAndVerifyIr(std::move(hlo_module), + R"(CHECK: switch {{.*}} label {{.*}} [ + CHECK-NEXT: label + CHECK-NEXT: ])", + /*match_optimized_ir=*/false); + } else { + CompileAndVerifyIr(std::move(hlo_module), + R"(CHECK: reduce-group-0 + CHECK: reduce-group-1 + CHECK-NOT: reduce-group-2)", + /*match_optimized_ir=*/false); + } EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); } @@ -115,18 +124,28 @@ ENTRY %cluster { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, ParseAndReturnVerifiedModule(hlo_text)); - CompileAndVerifyIr(std::move(hlo_module), - R"( -CHECK: reduce-group-0 -CHECK: reduce-group-1 -CHECK-NOT: reduce-group-2 -)", - /*match_optimized_ir=*/false); + if (GetDebugOptionsForTest().xla_gpu_mlir_emitter_level() >= 4) { + CompileAndVerifyIr(std::move(hlo_module), + R"(CHECK: switch {{.*}} label {{.*}} [ + CHECK-NEXT: label + CHECK-NEXT: ])", + /*match_optimized_ir=*/false); + } else { + CompileAndVerifyIr(std::move(hlo_module), + R"(CHECK: reduce-group-0 + CHECK: reduce-group-1 + CHECK-NOT: reduce-group-2)", + /*match_optimized_ir=*/false); + } EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); } TEST_F(ParallelReductionTest, UnnestedReductionWithLoopReductionDifferentShape) { + if (GetDebugOptionsForTest().xla_gpu_mlir_emitter_level() >= 4) { + GTEST_SKIP() + << "reduction does not occur in real pipelines and is not supported"; + } const char* hlo = R"( HloModule module @@ -346,13 +365,20 @@ ENTRY %cluster { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, ParseAndReturnVerifiedModule(hlo_text)); - CompileAndVerifyIr(std::move(hlo_module), - R"( -CHECK: reduce-group-0 -CHECK: reduce-group-1 -CHECK-NOT: reduce-group-2 -)", - /*match_optimized_ir=*/false); + + if (GetDebugOptionsForTest().xla_gpu_mlir_emitter_level() >= 4) { + CompileAndVerifyIr(std::move(hlo_module), + R"(CHECK: switch {{.*}} label {{.*}} [ + CHECK-NEXT: label + CHECK-NEXT: ])", + /*match_optimized_ir=*/false); + } else { + CompileAndVerifyIr(std::move(hlo_module), + R"(CHECK: reduce-group-0 + CHECK: reduce-group-1 + CHECK-NOT: reduce-group-2)", + /*match_optimized_ir=*/false); + } EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); } diff --git a/xla/service/gpu/tests/reduction_vectorization_sm_all.hlo b/xla/service/gpu/tests/reduction_vectorization_sm_all.hlo index 6a25580a4bcff9..baeb614b18d6e1 100644 --- a/xla/service/gpu/tests/reduction_vectorization_sm_all.hlo +++ b/xla/service/gpu/tests/reduction_vectorization_sm_all.hlo @@ -1,7 +1,7 @@ -// RUN: hlo-opt %s --platform=gpu --stage=ptx --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck %s -// RUN: hlo-opt %s --platform=gpu --stage=ptx --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/p100.txtpb --split-input-file | FileCheck %s --check-prefixes=CHECK-SM60 -// RUN: hlo-opt %s --platform=gpu --stage=ptx --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/v100.txtpb --split-input-file | FileCheck %s --check-prefixes=CHECK-SM70 -// RUN: hlo-opt %s --platform=gpu --stage=ptx --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/a6000.txtpb --split-input-file | FileCheck %s --check-prefixes=CHECK-SM86 +// RUN: hlo-opt %s --platform=gpu --xla_gpu_mlir_emitter_level=0 --stage=ptx --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck %s +// RUN: hlo-opt %s --platform=gpu --xla_gpu_mlir_emitter_level=0 --stage=ptx --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/p100.txtpb --split-input-file | FileCheck %s --check-prefixes=CHECK-SM60 +// RUN: hlo-opt %s --platform=gpu --xla_gpu_mlir_emitter_level=0 --stage=ptx --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/v100.txtpb --split-input-file | FileCheck %s --check-prefixes=CHECK-SM70 +// RUN: hlo-opt %s --platform=gpu --xla_gpu_mlir_emitter_level=0 --stage=ptx --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/a6000.txtpb --split-input-file | FileCheck %s --check-prefixes=CHECK-SM86 // CHECK-LABEL: .entry wrapped_reduce_odd_row // CHECK-NOT: ld.global.nc.v2.f32 diff --git a/xla/service/gpu/tests/reduction_vectorization_test.cc b/xla/service/gpu/tests/reduction_vectorization_test.cc index e4e5845e018d6e..680391c2fa7db6 100644 --- a/xla/service/gpu/tests/reduction_vectorization_test.cc +++ b/xla/service/gpu/tests/reduction_vectorization_test.cc @@ -107,6 +107,9 @@ CHECK: st.global.v2.f32 } TEST_F(ReductionVectorizationTest, NoVectorizationForBlockSmallerThanWarpSize) { + if (GetDebugOptionsForTest().xla_gpu_mlir_emitter_level() >= 4) { + GTEST_SKIP() << "MLIR emitters can vectorize this"; + } const char* hlo_text = R"( HloModule SlowModule diff --git a/xla/tools/hlo_opt/gpu_hlo_llvm.hlo b/xla/tools/hlo_opt/gpu_hlo_llvm.hlo index 59800a9d170560..e323b0c1930dfa 100644 --- a/xla/tools/hlo_opt/gpu_hlo_llvm.hlo +++ b/xla/tools/hlo_opt/gpu_hlo_llvm.hlo @@ -1,4 +1,4 @@ -// RUN: hlo-opt %s --platform=gpu --stage=llvm --xla_gpu_target_config_filename=%S/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s +// RUN: hlo-opt %s --platform=gpu --xla_gpu_mlir_emitter_level=0 --stage=llvm --xla_gpu_target_config_filename=%S/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s HloModule m From 737a7da3c5405583dc95773ac0bb11b1349fc9ea Mon Sep 17 00:00:00 2001 From: Greg Olechwierowicz Date: Mon, 22 Jul 2024 07:30:02 -0700 Subject: [PATCH 058/376] [XLA:GPU] Add accuracy checker for PGLE on GPU. PiperOrigin-RevId: 654740137 --- xla/service/BUILD | 8 + xla/service/gpu/BUILD | 18 +- xla/service/gpu/gpu_hlo_schedule.cc | 8 +- .../gpu/gpu_latency_hiding_scheduler.cc | 60 ++++- .../gpu/gpu_latency_hiding_scheduler.h | 19 ++ .../gpu/gpu_latency_hiding_scheduler_test.cc | 254 ++++++++++++++++++ .../profile_guided_latency_estimator.cc | 108 +++++++- .../profile_guided_latency_estimator.h | 51 +++- .../profile_guided_latency_estimator_test.cc | 37 ++- 9 files changed, 545 insertions(+), 18 deletions(-) diff --git a/xla/service/BUILD b/xla/service/BUILD index 8857ab34bb12c9..c2d9afa0914f5d 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -1524,7 +1524,11 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc_impl", ], @@ -1540,7 +1544,11 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:protobuf", "@tsl//tsl/platform:statusor", "@tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc", ], diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 0d8dd38d751cbe..4fd126613b4791 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -6015,6 +6015,7 @@ cc_library( "//xla/hlo/utils:hlo_query", "//xla/service:collective_ops_utils", "//xla/service:latency_hiding_scheduler", + "//xla/service:profile_guided_latency_estimator", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings:string_view", @@ -6024,5 +6025,20 @@ cc_library( xla_cc_test( name = "gpu_latency_hiding_scheduler_test", srcs = ["gpu_latency_hiding_scheduler_test.cc"], - deps = ["//xla/tests:xla_internal_test_main"], + deps = [ + ":gpu_hlo_schedule", + ":gpu_latency_hiding_scheduler", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_module_config", + "//xla/service:profile_guided_latency_estimator", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], ) diff --git a/xla/service/gpu/gpu_hlo_schedule.cc b/xla/service/gpu/gpu_hlo_schedule.cc index 0463875b4f98a4..c856cdf468c3f6 100644 --- a/xla/service/gpu/gpu_hlo_schedule.cc +++ b/xla/service/gpu/gpu_hlo_schedule.cc @@ -507,8 +507,10 @@ absl::StatusOr ScheduleGpuModule( .debug_options() .xla_gpu_enable_analytical_latency_estimator(); if (profile.has_value()) { - latency_estimator = std::make_unique( - config, std::move(gpu_latency_estimator), profile.value()); + auto aggregator = std::make_unique(); + auto pg_latency_estimator = std::make_unique( + config, std::move(gpu_latency_estimator), profile.value(), + std::move(aggregator)); LOG(INFO) << "Found profile, using profile guided latency estimator. Profile:\n" << profile->DebugString(); @@ -518,6 +520,8 @@ absl::StatusOr ScheduleGpuModule( "still be used : " << s.message(); } + TF_RETURN_IF_ERROR(pg_latency_estimator->CheckAccuracy(*module)); + latency_estimator = std::move(pg_latency_estimator); } else if (enable_analytical_latency_estimator) { latency_estimator = std::make_unique( config, std::move(gpu_latency_estimator), gpu_device_info, diff --git a/xla/service/gpu/gpu_latency_hiding_scheduler.cc b/xla/service/gpu/gpu_latency_hiding_scheduler.cc index 4ecb194121022b..fcd92f6799c9bd 100644 --- a/xla/service/gpu/gpu_latency_hiding_scheduler.cc +++ b/xla/service/gpu/gpu_latency_hiding_scheduler.cc @@ -48,7 +48,7 @@ bool IsNopInstruction(const HloInstruction& hlo) { HloOpcode op = hlo.opcode(); return op == HloOpcode::kGetTupleElement || op == HloOpcode::kBitcast || op == HloOpcode::kConstant || op == HloOpcode::kParameter || - hlo.IsEffectiveBitcast(); + op == HloOpcode::kTuple || hlo.IsEffectiveBitcast(); } bool IsAsyncComputeOp(const HloInstruction& hlo) { @@ -89,6 +89,24 @@ std::pair GetP2PResourceAndUsage( return {resource, usage}; } +bool IsGpuAsyncStart(const HloInstruction& hlo) { + return (hlo_query::IsAsyncCollectiveStartOp(&hlo, + /*include_send_recv=*/true) && + !IsSyncCollective(&hlo)) || + IsAsyncComputeOp(hlo); +} + +bool IsGpuAsyncDone(const HloInstruction& hlo) { + return (hlo_query::IsAsyncCollectiveDoneOp(&hlo, + /*include_send_recv=*/true) && + !IsSyncCollective(hlo.operand(0))) || + IsAsyncComputeOp(hlo); +} + +bool IsAsyncPair(const HloInstruction& from, const HloInstruction& target) { + return IsGpuAsyncStart(from) && IsGpuAsyncDone(target); +} + } // namespace int64_t GetSizeOfShape(const Shape& shape, int pointer_size) { @@ -125,18 +143,12 @@ GpuAsyncTrackerBase::GpuAsyncTrackerBase(const SchedulerConfig& config, bool GpuAsyncTrackerBase::IsSupportedAsyncDone( const HloInstruction& hlo) const { - return (hlo_query::IsAsyncCollectiveDoneOp(&hlo, - /*include_send_recv=*/true) && - !IsSyncCollective(hlo.operand(0))) || - IsAsyncComputeOp(hlo); + return IsGpuAsyncDone(hlo); } bool GpuAsyncTrackerBase::IsSupportedAsyncStart( const HloInstruction& hlo) const { - return (hlo_query::IsAsyncCollectiveStartOp(&hlo, - /*include_send_recv=*/true) && - !IsSyncCollective(&hlo)) || - IsAsyncComputeOp(hlo); + return IsGpuAsyncStart(hlo); } void GpuAsyncTrackerBase::PostProcessScheduleGraph( @@ -371,7 +383,35 @@ ApproximateLatencyEstimator::TimeCost GpuLatencyEstimator::GetLatencyBetween( // latency between each of them is always one unit. return ApproximateLatencyEstimator::kLowLatency; } -// GpuLatencyEstimator implementations end + +//===--------------------------------------------------------------------===// +// GPUProfileStatisticsAggregator +//===--------------------------------------------------------------------===// + +void GPUProfileStatisticsAggregator::HandleMissingInstructionCost( + const HloInstruction& instruction) { + if (!IsNopInstruction(instruction) && + instruction.opcode() != HloOpcode::kWhile) { + missing_instructions_.insert(&instruction); + } +} + +void GPUProfileStatisticsAggregator::HandleFoundInstructionCost( + const HloInstruction& instruction) { + found_instructions_count_++; +} + +void GPUProfileStatisticsAggregator::HandleMissingInstructionLatency( + const HloInstruction& from, const HloInstruction& to) { + if (IsAsyncPair(from, to)) { + missing_instructions_.insert(&from); + } +} + +void GPUProfileStatisticsAggregator::HandleFoundInstructionLatency( + const HloInstruction& from, const HloInstruction& to) { + found_instructions_count_++; +} } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/gpu_latency_hiding_scheduler.h b/xla/service/gpu/gpu_latency_hiding_scheduler.h index fae9debc8fc291..b0db29c812cb37 100644 --- a/xla/service/gpu/gpu_latency_hiding_scheduler.h +++ b/xla/service/gpu/gpu_latency_hiding_scheduler.h @@ -21,6 +21,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/latency_hiding_scheduler.h" +#include "xla/service/profile_guided_latency_estimator.h" #include "xla/shape.h" namespace xla { @@ -118,6 +119,24 @@ class GpuLatencyEstimator : public ApproximateLatencyEstimator { int64_t pointer_size_; }; +// GPU PGLE statistics tracker. +class GPUProfileStatisticsAggregator : public ProfileStatisticsAggregator { + public: + // Counts `instruction` as missing if is not a NOP. + void HandleMissingInstructionCost(const HloInstruction& instruction) override; + + // Counts `instruction` as found. + void HandleFoundInstructionCost(const HloInstruction& instruction) override; + + // Counts `from` -> `to` pair as missing if it is an async pair. + void HandleMissingInstructionLatency(const HloInstruction& from, + const HloInstruction& to) override; + + // Counts `from` -> `to` pair as found. + void HandleFoundInstructionLatency(const HloInstruction& from, + const HloInstruction& to) override; +}; + } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc b/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc index 55f4307a0c37e0..5e7f8a754e1aed 100644 --- a/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc +++ b/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc @@ -13,12 +13,266 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "xla/service/gpu/gpu_latency_hiding_scheduler.h" + +#include + +#include +#include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/gpu_hlo_schedule.h" +#include "xla/service/hlo_module_config.h" +#include "xla/service/profile_guided_latency_estimator.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + namespace xla::gpu { namespace { +using ::testing::Property; +using ::testing::UnorderedElementsAre; +using ::tsl::testing::StatusIs; + // TODO(b/346918304): Separate relevant tests from gpu_hlo_schedule_test.cc // into broader GPU scheduling related tests vs. tests related to components of // GPU LHS. +class GpuLatencyHidingSchedulerBaseTest : public HloTestBase { + protected: + absl::StatusOr ScheduleModule(HloModule* module) { + auto& test_backend = backend(); + const auto& gpu_device_info = + test_backend.default_stream_executor()->GetDeviceDescription(); + TF_RETURN_IF_ERROR( + ScheduleGpuModule(module, /*pointer_size=*/8, gpu_device_info) + .status()); + return module; + } + + HloModuleConfig GetModuleConfig(absl::string_view fdo_profile) { + HloModuleConfig config; + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_latency_hiding_scheduler(true); + debug_options.set_xla_gpu_lhs_enable_gpu_async_tracker(true); + config.set_debug_options(debug_options); + *config.mutable_fdo_profile() = fdo_profile; + return config; + } +}; + +TEST_F(GpuLatencyHidingSchedulerBaseTest, + GPUProfileStatisticsAggregatorDoesNotCountMissingNoops) { + GPUProfileStatisticsAggregator aggregator; + ProfileStatisticsAggregator::Statistics before_stats = aggregator.GetStats(); + + ASSERT_EQ(before_stats.missing_instructions.size(), 0); + ASSERT_EQ(before_stats.found_instructions_count, 0); + + absl::string_view kFdoProfile = ""; + absl::string_view kHloModule = R"( + HloModule m + + ENTRY main { + parameter0 = f32[] parameter(0) + parameter1 = f32[32] parameter(1) + const0 = f32[] constant(42) + bitcast0 = f32[2,16] bitcast(parameter1) + tuple0 = (f32[], f32[2,16]) tuple(parameter0, bitcast0) + ROOT _ = get-tuple-element(tuple0), index=0 + } + )"; + + auto config = GetModuleConfig(kFdoProfile); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kHloModule, config)); + + for (const HloInstruction* instr : + module->entry_computation()->instructions()) { + aggregator.HandleMissingInstructionCost(*instr); + + ProfileStatisticsAggregator::Statistics after_stats = aggregator.GetStats(); + EXPECT_EQ(after_stats.missing_instructions.size(), 0); + EXPECT_EQ(after_stats.found_instructions_count, 0); + } +} + +TEST_F(GpuLatencyHidingSchedulerBaseTest, + GPUProfileStatisticsAggregatorCountsMissingInstruction) { + GPUProfileStatisticsAggregator aggregator; + ProfileStatisticsAggregator::Statistics before_stats = aggregator.GetStats(); + + ASSERT_EQ(before_stats.missing_instructions.size(), 0); + ASSERT_EQ(before_stats.found_instructions_count, 0); + + absl::string_view kFdoProfile = R"pb( + costs { name: "dot0" cost_us: 100.0 } + )pb"; + absl::string_view kHloModule = R"( + HloModule m + + ENTRY main { + parameter0 = f32[] parameter(0) + parameter1 = f32[32] parameter(1) + const0 = f32[] constant(42) + add0 = f32[] add(parameter0, const0) + bitcast0 = f32[2,16] bitcast(parameter1) + tuple0 = (f32[], f32[2,16]) tuple(add0, bitcast0) + ROOT _ = get-tuple-element(tuple0), index=0 + } + )"; + + auto config = GetModuleConfig(kFdoProfile); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kHloModule, config)); + + for (const HloInstruction* instr : + module->entry_computation()->instructions()) { + aggregator.HandleMissingInstructionCost(*instr); + } + ProfileStatisticsAggregator::Statistics after_stats = aggregator.GetStats(); + EXPECT_EQ(after_stats.missing_instructions.size(), 1); + EXPECT_EQ((*after_stats.missing_instructions.begin())->opcode(), + HloOpcode::kAdd); + EXPECT_EQ(after_stats.found_instructions_count, 0); +} + +TEST_F(GpuLatencyHidingSchedulerBaseTest, + GPUProfileStatisticsAggregatorCountsMissingAsyncPairs) { + GPUProfileStatisticsAggregator aggregator; + ProfileStatisticsAggregator::Statistics before_stats = aggregator.GetStats(); + + ASSERT_EQ(before_stats.missing_instructions.size(), 0); + ASSERT_EQ(before_stats.found_instructions_count, 0); + + absl::string_view kFdoProfile = ""; + absl::string_view kHloModule = R"( + HloModule m + + reduce { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT _ = f32[] add(x, y) + } + + ENTRY main { + p0 = f32[] parameter(0) + p1 = f32[2] parameter(1) + ar_0 = f32[] all-reduce-start(p0), to_apply=reduce + ar_1 = f32[] all-reduce-done(ar_0) + rs_0 = ((f32[2]), f32[1]) reduce-scatter-start(p1), to_apply=reduce, dimensions={0} + rs_1 = f32[1] reduce-scatter-done(rs_0) + ag_0 = (f32[2], f32[4]) all-gather-start(p1), replica_groups={{0,1}}, dimensions={0} + ag_1 = f32[4] all-gather-done(ag_0) + ROOT _ = (f32[], f32[1], f32[4]) tuple(ar_1, rs_1, ag_1) + } + )"; + + auto config = GetModuleConfig(kFdoProfile); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kHloModule, config)); + + for (const HloInstruction* instr : + module->entry_computation()->instructions()) { + for (const HloInstruction* user : instr->users()) { + aggregator.HandleMissingInstructionLatency(*instr, *user); + } + } + ProfileStatisticsAggregator::Statistics after_stats = aggregator.GetStats(); + EXPECT_EQ(after_stats.found_instructions_count, 0); + EXPECT_EQ(after_stats.missing_instructions.size(), 3); + EXPECT_THAT( + after_stats.missing_instructions, + UnorderedElementsAre( + Property(&HloInstruction::opcode, HloOpcode::kAllReduceStart), + Property(&HloInstruction::opcode, HloOpcode::kAsyncStart), + Property(&HloInstruction::opcode, HloOpcode::kAllGatherStart))); +} + +TEST_F(GpuLatencyHidingSchedulerBaseTest, + ScheduleGpuModuleErrorsOutOnMissingInstrucitonsForAWhileLoopBody) { + absl::string_view kFdoProfile = R"pb( + costs { name: "dot0" cost_us: 100.0 } + )pb"; + absl::string_view kHloModule = R"( + HloModule m + + loop_body { + p = (u32[], f32[1]) parameter(0) + t0 = u32[] get-tuple-element(p), index=0 + t1 = f32[1] get-tuple-element(p), index=1 + add0 = f32[1] add(t1, t1) + ROOT _ = (u32[],f32[1]) tuple(t0,t1) + } + + loop_cond { + p1 = (u32[], f32[1]) parameter(0) + count = u32[] get-tuple-element(p1), index=0 + ub = u32[] constant(2) + ROOT _ = pred[] compare(count, ub), direction=LT + } + + ENTRY main { + p2 = f32[1] parameter(0) + ind = u32[] constant(1) + t = (u32[],f32[1]) tuple(ind,p2) + w = (u32[],f32[1]) while(t), body=loop_body, condition=loop_cond + ROOT _ = f32[1] get-tuple-element(w), index=1 + } + )"; + auto config = GetModuleConfig(kFdoProfile); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloModule, config)); + + EXPECT_THAT(ScheduleModule(module.get()), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_F(GpuLatencyHidingSchedulerBaseTest, + ScheduleGpuModuleErrorsOutOnMissingInstrucitonsForAnEntryComputation) { + absl::string_view kFdoProfile = R"pb( + costs { name: "dot0" cost_us: 100.0 } + )pb"; + absl::string_view kHloModule = R"( + HloModule m + + ENTRY main { + p0 = f32[1] parameter(0) + ROOT add0 = f32[1] add(p0,p0) + } + )"; + auto config = GetModuleConfig(kFdoProfile); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloModule, config)); + + EXPECT_THAT(ScheduleModule(module.get()), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_F(GpuLatencyHidingSchedulerBaseTest, + ScheduleGpuModulePassesOnFullFDOProfile) { + absl::string_view kFdoProfile = R"pb( + costs { name: "add0" cost_us: 100.0 } + )pb"; + absl::string_view kHloModule = R"( + HloModule m + + ENTRY main { + p0 = f32[1] parameter(0) + ROOT add0 = f32[1] add(p0,p0) + } + )"; + auto config = GetModuleConfig(kFdoProfile); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloModule, config)); + + TF_EXPECT_OK(ScheduleModule(module.get())); +} + } // namespace } // namespace xla::gpu diff --git a/xla/service/profile_guided_latency_estimator.cc b/xla/service/profile_guided_latency_estimator.cc index 6e3810aa24c516..b17d762b01fae2 100644 --- a/xla/service/profile_guided_latency_estimator.cc +++ b/xla/service/profile_guided_latency_estimator.cc @@ -15,11 +15,14 @@ limitations under the License. #include "xla/service/profile_guided_latency_estimator.h" +#include #include #include #include "absl/container/flat_hash_map.h" #include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -29,6 +32,48 @@ limitations under the License. namespace xla { +namespace { + +// Small wrapper ensuring aggregator is provided and if it is, then it performs +// forwarding the instruction to an appropriate handler. +void HandleMissingInstructionCost(ProfileStatisticsAggregator* aggregator, + const HloInstruction* instruction) { + if (aggregator != nullptr) { + aggregator->HandleMissingInstructionCost(*instruction); + } +} + +// Small wrapper ensuring aggregator is provided and if it is, then it performs +// forwarding the instruction to an appropriate handler. +void HandleFoundInstructionCost(ProfileStatisticsAggregator* aggregator, + const HloInstruction* instruction) { + if (aggregator != nullptr) { + aggregator->HandleFoundInstructionCost(*instruction); + } +} + +// Small wrapper ensuring aggregator is provided and if it is, then it performs +// forwarding the from/to instruction pair to an appropriate handler. +void HandleMissingInstructionLatency(ProfileStatisticsAggregator* aggregator, + const HloGraphNode& from, + const HloGraphNode& to) { + if (aggregator != nullptr) { + aggregator->HandleMissingInstructionLatency(from.GetInstr(), to.GetInstr()); + } +} + +// Small wrapper ensuring aggregator is provided and if it is, then it performs +// forwarding the from/to instruction pair to an appropriate handler. +void HandleFoundInstructionLatency(ProfileStatisticsAggregator* aggregator, + const HloGraphNode& from, + const HloGraphNode& to) { + if (aggregator != nullptr) { + aggregator->HandleFoundInstructionLatency(from.GetInstr(), to.GetInstr()); + } +} + +} // namespace + LatencyEstimator::TimeCost ProfileGuidedLatencyEstimator::GetLatencyBetween( const HloGraphNode& from, const HloGraphNode& target) const { static constexpr HloGraphNode::TimeCost kLowLatency = 1.0; @@ -53,6 +98,7 @@ LatencyEstimator::TimeCost ProfileGuidedLatencyEstimator::GetLatencyBetween( VLOG(1) << "PGLE did NOT find wrapped instruction name or async start. From: " << from.GetInstr().name(); + HandleMissingInstructionLatency(aggregator_.get(), from, target); return latency_estimator_->GetLatencyBetween(from, target); } @@ -66,6 +112,7 @@ LatencyEstimator::TimeCost ProfileGuidedLatencyEstimator::GetLatencyBetween( if (it2 != it->second.latencies.end()) { VLOG(2) << "PGLE found latency between " << from.GetInstr().name() << " and " << target.GetInstr().name() << " in latency info"; + HandleFoundInstructionLatency(aggregator_.get(), from, target); return it2->second * CyclesPerMicrosecond(); } @@ -76,12 +123,14 @@ LatencyEstimator::TimeCost ProfileGuidedLatencyEstimator::GetLatencyBetween( VLOG(2) << "PGLE found latency for async op " << from.GetInstr().name() << " and (assumed)" << target.GetInstr().name() << " in instruction costs"; + HandleFoundInstructionLatency(aggregator_.get(), from, target); return *it->second.cost * CyclesPerMicrosecond(); } VLOG(1) << "PGLE did not find relevant profiling info for '" << from.GetInstr().name() << "', and '" << target.GetInstr().name() << "'."; + HandleMissingInstructionLatency(aggregator_.get(), from, target); return latency_estimator_->GetLatencyBetween(from, target); } @@ -95,17 +144,72 @@ LatencyEstimator::TimeCost ProfileGuidedLatencyEstimator::NodeCost( if (auto it = instr_map_.find(instr->name()); it != instr_map_.end() && it->second.cost.has_value()) { VLOG(2) << "PGLE found cost for: " << instr->name(); + HandleFoundInstructionCost(aggregator_.get(), instr); return *it->second.cost; } VLOG(1) << "PGLE missed cost for: " << instr->name(); + HandleMissingInstructionCost(aggregator_.get(), instr); return latency_estimator_->NodeCost(instr); } +ProfileStatisticsAggregator::Statistics +ProfileStatisticsAggregator::GetStats() { + return { + /*found_instructions_count=*/found_instructions_count_, + /*missing_instructions=*/missing_instructions_, + }; +} + +absl::Status ProfileGuidedLatencyEstimator::CheckAccuracy( + const HloModule& module) { + if (aggregator_ == nullptr) { + return absl::FailedPreconditionError( + "Failing because `aggregator_` was not provided when constructing " + "PGLE."); + } + + for (const auto& comp : module.computations()) { + // We only check profile application for while bodies and entry computation + // to avoid fine-grained exclusion of fusion computations, wrapped async + // computations, trivial to_apply computations (present in e.g. reductions) + // etc. + if (!comp->IsEntryComputation() && !comp->IsWhileBodyComputation()) { + continue; + } + for (const HloInstruction* instr : comp->MakeInstructionPostOrder()) { + NodeCost(instr); + HloGraphNode from(instr, /*original_position=*/-1); + for (const HloInstruction* user : instr->users()) { + HloGraphNode to(user, /*original_position=*/-1); + GetLatencyBetween(from, to); + } + } + } + ProfileStatisticsAggregator::Statistics stats = aggregator_->GetStats(); + size_t missing_instructions_count = stats.missing_instructions.size(); + if (missing_instructions_count > 0) { + LOG(ERROR) << "Found " << missing_instructions_count + << " instructions from the profile."; + LOG(ERROR) << "Missing " << missing_instructions_count + << " instructions from the profile."; + for (const HloInstruction* instr : stats.missing_instructions) { + LOG(ERROR) << " " << instr->name(); + } + return absl::InvalidArgumentError( + absl::StrCat("Found ", missing_instructions_count, + " missing instructions. Discarding the profile.")); + } + return absl::OkStatus(); +} + ProfileGuidedLatencyEstimator::ProfileGuidedLatencyEstimator( const SchedulerConfig& config, std::unique_ptr latency_estimator, - const tensorflow::profiler::ProfiledInstructionsProto& proto) - : config_(config), latency_estimator_(std::move(latency_estimator)) { + const tensorflow::profiler::ProfiledInstructionsProto& proto, + std::unique_ptr aggregator) + : config_(config), + latency_estimator_(std::move(latency_estimator)), + aggregator_(std::move(aggregator)) { const int cycles_per_microsecond = latency_estimator_->CyclesPerMicrosecond(); for (const auto& instr_cost : proto.costs()) { instr_map_[instr_cost.name()] = diff --git a/xla/service/profile_guided_latency_estimator.h b/xla/service/profile_guided_latency_estimator.h index a3b939ce77ec8d..40a4df59afd8f2 100644 --- a/xla/service/profile_guided_latency_estimator.h +++ b/xla/service/profile_guided_latency_estimator.h @@ -21,11 +21,49 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/latency_hiding_scheduler.h" #include "tsl/profiler/protobuf/profiled_instructions.pb.h" namespace xla { +// Helper class enabling gathering of statistics (such as missing instruction +// from the profile) for PGLE. +class ProfileStatisticsAggregator { + public: + struct Statistics { + int found_instructions_count; + absl::flat_hash_set& missing_instructions; + }; + + virtual ~ProfileStatisticsAggregator() = default; + + // Handler for the missing instruction cost. + virtual void HandleMissingInstructionCost( + const HloInstruction& instruction) = 0; + + // Handler for found instruction cost. + virtual void HandleFoundInstructionCost( + const HloInstruction& instruction) = 0; + + // Handler for the missing latency info between `from` and `to`. + virtual void HandleMissingInstructionLatency(const HloInstruction& from, + const HloInstruction& to) = 0; + + // Handler for found latency info between `from` and `to`. + virtual void HandleFoundInstructionLatency(const HloInstruction& from, + const HloInstruction& to) = 0; + + // Returns gathered statistics summary. + Statistics GetStats(); + + protected: + absl::flat_hash_set missing_instructions_; + int found_instructions_count_ = 0; +}; + // Implementation of LatencyEstimator using a profile to estimate HLO cost and // latencies between instructions. If a cost is not known, it will forward to // an underlying estimator. @@ -34,7 +72,8 @@ class ProfileGuidedLatencyEstimator : public LatencyEstimator { ProfileGuidedLatencyEstimator( const SchedulerConfig& config, std::unique_ptr latency_estimator, - const tensorflow::profiler::ProfiledInstructionsProto& proto); + const tensorflow::profiler::ProfiledInstructionsProto& proto, + std::unique_ptr aggregator = nullptr); TimeCost GetLatencyBetween(const HloGraphNode& from, const HloGraphNode& target) const override; @@ -43,6 +82,14 @@ class ProfileGuidedLatencyEstimator : public LatencyEstimator { return latency_estimator_->CyclesPerMicrosecond(); } + // Checks whether `module` has all the respective instructions present in the + // profile grabbed by this object. + // + // Returns absl::OkStatus if accuracy check passes, + // `absl::InvalidArgumentError` does not pass and + // `absl::FailedPreconditionError` if `aggregator_` is not provided. + absl::Status CheckAccuracy(const HloModule& module); + private: const SchedulerConfig config_; std::unique_ptr latency_estimator_; @@ -54,6 +101,8 @@ class ProfileGuidedLatencyEstimator : public LatencyEstimator { absl::flat_hash_map latencies; }; absl::flat_hash_map instr_map_; + // Aggregator gathering data about missed/found instructions. + std::unique_ptr aggregator_; }; } // namespace xla diff --git a/xla/service/profile_guided_latency_estimator_test.cc b/xla/service/profile_guided_latency_estimator_test.cc index fe0a012d91f8c8..ff2d766b8b07c2 100644 --- a/xla/service/profile_guided_latency_estimator_test.cc +++ b/xla/service/profile_guided_latency_estimator_test.cc @@ -22,12 +22,17 @@ limitations under the License. #include #include +#include +#include #include "absl/algorithm/container.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/service/latency_hiding_scheduler.h" #include "xla/tests/hlo_test_base.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/protobuf.h" #include "tsl/platform/statusor.h" #include "tsl/profiler/protobuf/profiled_instructions.pb.h" @@ -35,6 +40,8 @@ namespace xla { namespace { +using ::tsl::testing::StatusIs; + int GetIndex(absl::Span instruction_sequence, absl::string_view hlo_name) { return absl::c_find_if(instruction_sequence, @@ -149,8 +156,8 @@ ENTRY entry { } } - // cp2s should come first since the latency between cp2s->cp2d is double that - // of cp1s->cp1d + // cp2s should come first since the latency between cp2s->cp2d is double + // that of cp1s->cp1d EXPECT_LT(GetIndex(new_instruction_sequence, "cp2s"), GetIndex(new_instruction_sequence, "cp1s")); } @@ -264,4 +271,30 @@ ENTRY entry { EXPECT_EQ(recv_latency, 100.0); } +TEST_F(ProfileGuidedLatencyEstimatorTest, + ProfileGuidedLatencyEstimatorCheckAccuracyFailsIfMissingAggregator) { + std::string kFdoProfile = ""; + absl::string_view kHloModule = R"( + HloModule module + + ENTRY main { + p0 = f32[1] parameter(0) + ROOT add0 = f32[1] add(p0,p0) + } +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, + ParseAndReturnVerifiedModule(kHloModule)); + tensorflow::profiler::ProfiledInstructionsProto fdo_profile; + ASSERT_TRUE( + tsl::protobuf::TextFormat::ParseFromString(kFdoProfile, &fdo_profile)); + + auto sched_config = GetDefaultSchedConfig(); + auto latency_estimator = std::make_unique( + sched_config, std::make_unique(), + fdo_profile); + EXPECT_THAT(latency_estimator->CheckAccuracy(*hlo_module), + StatusIs(absl::StatusCode::kFailedPrecondition)); +} + } // namespace xla From c914e7dd4f425b38b2ce2bb32004679d417121d6 Mon Sep 17 00:00:00 2001 From: Greg Olechwierowicz Date: Mon, 22 Jul 2024 08:09:53 -0700 Subject: [PATCH 059/376] [XLA:GPU] Remove old IsProfileApplicable method. PiperOrigin-RevId: 654754841 --- xla/service/BUILD | 1 - xla/service/gpu/gpu_hlo_schedule.cc | 42 ------------------- xla/service/gpu/gpu_hlo_schedule.h | 9 ----- xla/service/gpu/gpu_hlo_schedule_test.cc | 51 ------------------------ 4 files changed, 103 deletions(-) diff --git a/xla/service/BUILD b/xla/service/BUILD index c2d9afa0914f5d..28e94ff7b0e022 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -1526,7 +1526,6 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", diff --git a/xla/service/gpu/gpu_hlo_schedule.cc b/xla/service/gpu/gpu_hlo_schedule.cc index c856cdf468c3f6..4f0d3fce842223 100644 --- a/xla/service/gpu/gpu_hlo_schedule.cc +++ b/xla/service/gpu/gpu_hlo_schedule.cc @@ -417,42 +417,6 @@ std::optional ReadPGLEProfile( } } // end namespace -absl::Status IsProfileApplicable( - const HloModule* module, - const tensorflow::profiler::ProfiledInstructionsProto& profile) { - absl::flat_hash_set all_instruction_names; - for (HloComputation* comp : module->MakeNonfusionComputations()) { - for (HloInstruction* instr : comp->instructions()) { - all_instruction_names.insert(instr->name()); - } - } - - std::vector missing_costs_names; - for (const auto& cost : profile.costs()) { - if (!all_instruction_names.contains(cost.name())) { - missing_costs_names.push_back(cost.name()); - } - } - std::vector missing_latency_names; - for (const auto& latency : profile.latencies()) { - if (!all_instruction_names.contains(latency.source())) { - missing_latency_names.push_back(latency.source()); - } - - if (!all_instruction_names.contains(latency.target())) { - missing_latency_names.push_back(latency.target()); - } - } - if (!(missing_costs_names.empty() && missing_latency_names.empty())) { - return absl::InvalidArgumentError( - absl::StrFormat("\nMissing costs: %s;\nMissing latencies: %s", - absl::StrJoin(missing_costs_names, ", "), - absl::StrJoin(missing_latency_names, ", "))); - } - - return absl::OkStatus(); -} - static int64_t GetSchedulerMemoryLimit( const HloModule* module, const se::DeviceDescription& gpu_device_info, int pointer_size); @@ -514,12 +478,6 @@ absl::StatusOr ScheduleGpuModule( LOG(INFO) << "Found profile, using profile guided latency estimator. Profile:\n" << profile->DebugString(); - absl::Status s = IsProfileApplicable(module, profile.value()); - if (!s.ok()) { - LOG(INFO) << "PGLE profile may not applicable to the module, but will " - "still be used : " - << s.message(); - } TF_RETURN_IF_ERROR(pg_latency_estimator->CheckAccuracy(*module)); latency_estimator = std::move(pg_latency_estimator); } else if (enable_analytical_latency_estimator) { diff --git a/xla/service/gpu/gpu_hlo_schedule.h b/xla/service/gpu/gpu_hlo_schedule.h index 7263eff68eaa13..b71226c20710a9 100644 --- a/xla/service/gpu/gpu_hlo_schedule.h +++ b/xla/service/gpu/gpu_hlo_schedule.h @@ -18,25 +18,16 @@ limitations under the License. #include -#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_schedule.h" -#include "xla/shape.h" #include "xla/stream_executor/device_description.h" #include "tsl/profiler/protobuf/profiled_instructions.pb.h" namespace xla { namespace gpu { -// Returns `absl::OkStatus` if every instruction in the profile is present in -// the module. `absl::InvalidArgumentError` with missing culprit costs/latencies -// otherwise. -absl::Status IsProfileApplicable( - const HloModule* module, - const tensorflow::profiler::ProfiledInstructionsProto& profile); - struct ScheduleMetadata { int64_t scheduler_mem_limit; }; diff --git a/xla/service/gpu/gpu_hlo_schedule_test.cc b/xla/service/gpu/gpu_hlo_schedule_test.cc index 16f0332e07e3ba..0304f358d4b132 100644 --- a/xla/service/gpu/gpu_hlo_schedule_test.cc +++ b/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -492,57 +492,6 @@ TEST_F(GpuHloScheduleTest, ProfileGuidedCostModel) { } } -TEST_F(GpuHloScheduleTest, - ProfileGuidedCostModelApplicabilityListsMissingCostsAndLatencies) { - const char* hlo_text = R"( - HloModule AsyncAR - apply_op { - x = f32[] parameter(0) - y = f32[] parameter(1) - ROOT apply_op = f32[] add(x, y) - } - - ENTRY ar { - p0 = f32[32] parameter(0) - p1 = f32[32, 32] parameter(1) - p2 = f32[32, 32] parameter(2) - p3 = f32[32] parameter(3) - - dot0 = f32[32,32]{1,0} custom-call(p1, p2), custom_call_target="__cublas$gemm" - ar-start = f32[32] all-reduce-start(p0), to_apply=apply_op - ar-done = f32[32] all-reduce-done(ar-start) - ar-start1 = f32[32] all-reduce-start(p3), to_apply=apply_op - ar-done1 = f32[32] all-reduce-done(ar-start1) - - ROOT t = (f32[32], f32[32], f32[32,32]) tuple(ar-done, ar-done1, dot0) - })"; - - const std::string ar_long_latency_proto_text = R"pb( - costs { name: "dot0" cost_us: 100.0 } - costs { name: "dot1" cost_us: 100.0 } - costs { name: "add0" cost_us: 10.0 } - costs { name: "ar-start" cost_us: 10.0 } - costs { name: "ar-start-2" cost_us: 10.0 } - )pb"; - - tensorflow::profiler::ProfiledInstructionsProto profile; - ASSERT_TRUE(tsl::protobuf::TextFormat::ParseFromString( - ar_long_latency_proto_text, &profile)); - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, - ParseAndReturnVerifiedModule( - hlo_text, - GetModuleConfig(/*enable_latency_hiding_scheduler=*/true, - /*enable_gpu_async_tracker=*/true, - /*fdo_profile=*/ar_long_latency_proto_text))); - - absl::Status result = IsProfileApplicable(module.get(), profile); - EXPECT_THAT(result, StatusIs(absl::StatusCode::kInvalidArgument)); - EXPECT_THAT(result.message(), HasSubstr("add0")); - EXPECT_THAT(result.message(), HasSubstr("dot1")); - EXPECT_THAT(result.message(), HasSubstr("ar-start-2")); -} - TEST_F(GpuHloScheduleTest, ProfileGuidedCostModelWithRematData) { const char* hlo_text = R"( HloModule AsyncAR From a8c6d33ef8ffb9115f1aae985678d1249d3c1e36 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 22 Jul 2024 08:55:07 -0700 Subject: [PATCH 060/376] [XLA:GPU] Move ConvertLayout method inside lowerSharedToSparseMeta PiperOrigin-RevId: 654771403 --- .../gpu/fusions/triton/sparse_extensions.cc | 145 +++++++++--------- 1 file changed, 69 insertions(+), 76 deletions(-) diff --git a/xla/service/gpu/fusions/triton/sparse_extensions.cc b/xla/service/gpu/fusions/triton/sparse_extensions.cc index 9337a33f446a3b..037631975e3486 100644 --- a/xla/service/gpu/fusions/triton/sparse_extensions.cc +++ b/xla/service/gpu/fusions/triton/sparse_extensions.cc @@ -353,79 +353,7 @@ class SparseBlockedToMMAPass MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SparseBlockedToMMAPass) }; -namespace SharedToSparseDotOperand { - -Value convertLayout(ConversionPatternRewriter &rewriter, Location loc, - Value tensor, - triton::gpu::SparseDotMetaEncodingAttr sparseEncoding, - const SharedMemoryObject &smemObj, - const LLVMTypeConverter *typeConverter, Value thread) { - constexpr int kThreadsPerWarp = 32; - // Each 16x16 original sparse matrix tile requires 16 metadata values of - // 16-bit size, where the first thread (T0) in each 4-thread group holds two - // such values in a register (32-bit). - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#sparse-matrix-storage - constexpr int kTileSize = 16; - constexpr int kThreadsInGroup = 4; - constexpr int kMetadataElementsPerPackedValue = 8; // 8 x 2-bit = 16-bit - constexpr int kMetadataLineOffset = kThreadsPerWarp / kThreadsInGroup; - - // Calculate tile size as number of mask elements (4xi4). - NvidiaMmaEncodingAttr mmaLayout = - cast(sparseEncoding.getParent()); - SmallVector warpsPerCTA = mmaLayout.getWarpsPerCTA(); - SmallVector shapePerCTATile = { - kTileSize * warpsPerCTA[0], kTileSize / kMetadataElementsPerPackedValue}; - Value strideM = smemObj.strides[0]; - Value strideK = smemObj.strides[1]; - - // Calculate offset in the tile for the current thread. - Value threadsPerWarp = i32_val(kThreadsPerWarp); - Value warpId = udiv(thread, threadsPerWarp); - Value warpGroupId; - if (mmaLayout.isHopper()) { - warpGroupId = urem(warpId, i32_val(warpsPerCTA[0])); - } else { - assert(mmaLayout.isAmpere()); - warpGroupId = udiv(warpId, i32_val(warpsPerCTA[1])); - } - Value laneId = urem(thread, threadsPerWarp); - Value laneGroupId = udiv(laneId, i32_val(kThreadsInGroup)); - Value columnId = urem(laneId, i32_val(shapePerCTATile[1])); - Value rowId = add(mul(warpGroupId, i32_val(kTileSize)), laneGroupId); - - // Calculate number of tile repetitions. - auto shape = cast(tensor.getType()).getShape(); - int repM = shape[0] / shapePerCTATile[0]; - int repK = shape[1] / shapePerCTATile[1]; - assert(repM > 0 && repK > 0); - - // Load sparse metadata from shared memory. - MLIRContext *ctx = tensor.getContext(); - Type ptrTy = ptr_ty(ctx, 3); - Value base = gep(ptrTy, i16_ty, smemObj.base, i32_val(0)); - SmallVector values; - - for (int k = 0; k < repK; ++k) { - for (int m = 0; m < repM; ++m) { - Value row = add(rowId, i32_val(m * shapePerCTATile[0])); - Value column = add(columnId, i32_val(k * shapePerCTATile[1])); - Value offset1 = add(mul(row, strideM), mul(column, strideK)); - Value offset2 = add(offset1, mul(i32_val(kMetadataLineOffset), strideM)); - Value lower = load(i16_ty, gep(ptrTy, i16_ty, base, offset1)); - Value upper = load(i16_ty, gep(ptrTy, i16_ty, base, offset2)); - values.push_back(lower); - values.push_back(upper); - } - } - - // Pack resulting values as LLVM struct. - Type structTy = struct_ty(SmallVector(values.size(), i16_ty)); - return packLLElements(loc, typeConverter, values, rewriter, structTy); -} -} // namespace SharedToSparseDotOperand - -struct SparseLocalLoadToLLVM +class SparseLocalLoadToLLVM : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern< @@ -450,16 +378,81 @@ struct SparseLocalLoadToLLVM LogicalResult lowerSharedToSparseMeta( triton::gpu::LocalLoadOp op, triton::gpu::LocalLoadOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + constexpr int kThreadsPerWarp = 32; + // Each 16x16 original sparse matrix tile requires 16 metadata values of + // 16-bit size, where the first thread (T0) in each 4-thread group holds two + // such values in a register (32-bit). + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#sparse-matrix-storage + constexpr int kTileSize = 16; + constexpr int kThreadsInGroup = 4; + constexpr int kMetadataElementsPerPackedValue = 8; // 8 x 2-bit = 16-bit + constexpr int kMetadataLineOffset = kThreadsPerWarp / kThreadsInGroup; + auto loc = op.getLoc(); + Value tensor = op.getSrc(); auto sparseEncoding = cast( cast(op.getResult().getType()).getEncoding()); auto llvmElemTy = getTypeConverter()->convertType( cast(op.getSrc().getType()).getElementType()); auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), llvmElemTy, rewriter); - Value res = SharedToSparseDotOperand::convertLayout( - rewriter, loc, op.getSrc(), sparseEncoding, smemObj, getTypeConverter(), - getThreadId(rewriter, loc)); + + // Calculate tile size as number of mask elements (4xi4). + NvidiaMmaEncodingAttr mmaLayout = + cast(sparseEncoding.getParent()); + SmallVector warpsPerCTA = mmaLayout.getWarpsPerCTA(); + SmallVector shapePerCTATile = { + kTileSize * warpsPerCTA[0], + kTileSize / kMetadataElementsPerPackedValue}; + Value strideM = smemObj.strides[0]; + Value strideK = smemObj.strides[1]; + + // Calculate offset in the tile for the current thread. + Value threadsPerWarp = i32_val(kThreadsPerWarp); + Value thread = getThreadId(rewriter, loc); + Value warpId = udiv(thread, threadsPerWarp); + Value warpGroupId; + if (mmaLayout.isHopper()) { + warpGroupId = urem(warpId, i32_val(warpsPerCTA[0])); + } else { + assert(mmaLayout.isAmpere()); + warpGroupId = udiv(warpId, i32_val(warpsPerCTA[1])); + } + Value laneId = urem(thread, threadsPerWarp); + Value laneGroupId = udiv(laneId, i32_val(kThreadsInGroup)); + Value columnId = urem(laneId, i32_val(shapePerCTATile[1])); + Value rowId = add(mul(warpGroupId, i32_val(kTileSize)), laneGroupId); + + // Calculate number of tile repetitions. + auto shape = cast(tensor.getType()).getShape(); + int repM = shape[0] / shapePerCTATile[0]; + int repK = shape[1] / shapePerCTATile[1]; + assert(repM > 0 && repK > 0); + + // Load sparse metadata from shared memory. + MLIRContext *ctx = tensor.getContext(); + Type ptrTy = ptr_ty(ctx, 3); + Value base = gep(ptrTy, i16_ty, smemObj.base, i32_val(0)); + SmallVector values; + + for (int k = 0; k < repK; ++k) { + for (int m = 0; m < repM; ++m) { + Value row = add(rowId, i32_val(m * shapePerCTATile[0])); + Value column = add(columnId, i32_val(k * shapePerCTATile[1])); + Value offset1 = add(mul(row, strideM), mul(column, strideK)); + Value offset2 = + add(offset1, mul(i32_val(kMetadataLineOffset), strideM)); + Value lower = load(i16_ty, gep(ptrTy, i16_ty, base, offset1)); + Value upper = load(i16_ty, gep(ptrTy, i16_ty, base, offset2)); + values.push_back(lower); + values.push_back(upper); + } + } + + // Pack resulting values as LLVM struct. + Type structTy = struct_ty(SmallVector(values.size(), i16_ty)); + Value res = + packLLElements(loc, getTypeConverter(), values, rewriter, structTy); rewriter.replaceOp(op, res); return success(); From b80c19d8b5488a93c0b72e223be968e519b61aca Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 22 Jul 2024 09:41:05 -0700 Subject: [PATCH 061/376] [XLA] Cleaning up algebraic simplifier headers as suggested by clang-tidy PiperOrigin-RevId: 654789219 --- xla/service/BUILD | 12 +++++++++--- xla/service/algebraic_simplifier.cc | 1 - xla/service/algebraic_simplifier.h | 13 ++++++++++--- xla/service/algebraic_simplifier_test.cc | 15 +++++++++++++-- 4 files changed, 32 insertions(+), 9 deletions(-) diff --git a/xla/service/BUILD b/xla/service/BUILD index 28e94ff7b0e022..8775d81790d1a0 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -2712,7 +2712,6 @@ cc_library( ":shape_inference", "//xla:comparison_util", "//xla:literal", - "//xla:literal_comparison", "//xla:literal_util", "//xla:permutation_util", "//xla:shape_util", @@ -2728,6 +2727,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/numeric:bits", "@com_google_absl//absl/status", @@ -2768,22 +2768,28 @@ xla_cc_test( ":hlo_creation_utils", ":hlo_parser", ":hlo_pass", - ":hlo_pass_pipeline", ":host_memory_offload_annotations_hdr", ":layout_assignment", ":pattern_matcher", ":pattern_matcher_gmock", ":shape_inference", + "//xla:comparison_util", "//xla:literal", + "//xla:literal_util", "//xla:shape_util", "//xla:test", - "//xla:types", + "//xla:util", "//xla:window_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", diff --git a/xla/service/algebraic_simplifier.cc b/xla/service/algebraic_simplifier.cc index 7e7a830b1cfa51..c0c4fe8ba5181b 100644 --- a/xla/service/algebraic_simplifier.cc +++ b/xla/service/algebraic_simplifier.cc @@ -52,7 +52,6 @@ limitations under the License. #include "xla/layout.h" #include "xla/layout_util.h" #include "xla/literal.h" -#include "xla/literal_comparison.h" #include "xla/literal_util.h" #include "xla/overflow_util.h" #include "xla/permutation_util.h" diff --git a/xla/service/algebraic_simplifier.h b/xla/service/algebraic_simplifier.h index 37ac87b621a54d..1d1134998c709a 100644 --- a/xla/service/algebraic_simplifier.h +++ b/xla/service/algebraic_simplifier.h @@ -16,22 +16,29 @@ limitations under the License. #ifndef XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_ #define XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_ -#include #include #include #include #include #include -#include #include #include -#include "absl/container/inlined_vector.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal.h" #include "xla/service/hlo_pass_interface.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/util.h" namespace xla { diff --git a/xla/service/algebraic_simplifier_test.cc b/xla/service/algebraic_simplifier_test.cc index ccc8415c9dd53c..1ba5bde63420fc 100644 --- a/xla/service/algebraic_simplifier_test.cc +++ b/xla/service/algebraic_simplifier_test.cc @@ -15,6 +15,10 @@ limitations under the License. #include "xla/service/algebraic_simplifier.h" +#include +#include +#include +#include #include #include #include @@ -23,10 +27,16 @@ limitations under the License. #include #include +#include +#include "absl/algorithm/container.h" +#include "absl/log/log.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -34,20 +44,21 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/layout_util.h" #include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/primitive_util.h" #include "xla/service/hlo_creation_utils.h" #include "xla/service/hlo_parser.h" #include "xla/service/hlo_pass_fix.h" -#include "xla/service/hlo_pass_pipeline.h" #include "xla/service/host_memory_offload_annotations.h" #include "xla/service/layout_assignment.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/service/shape_inference.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" -#include "xla/types.h" +#include "xla/util.h" #include "xla/window_util.h" #include "xla/xla_data.pb.h" #include "tsl/lib/core/status_test_util.h" From 09ac06a2980b85088d28c9a10f9353d68b6695ed Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Mon, 22 Jul 2024 11:48:58 -0700 Subject: [PATCH 062/376] Remove now unused GpuTimer::Create functions. PiperOrigin-RevId: 654841530 --- xla/stream_executor/gpu/gpu_timer.cc | 35 ---------------------------- xla/stream_executor/gpu/gpu_timer.h | 3 --- 2 files changed, 38 deletions(-) diff --git a/xla/stream_executor/gpu/gpu_timer.cc b/xla/stream_executor/gpu/gpu_timer.cc index 1fa2f2c4f4bc04..325df3ec083243 100644 --- a/xla/stream_executor/gpu/gpu_timer.cc +++ b/xla/stream_executor/gpu/gpu_timer.cc @@ -99,41 +99,6 @@ absl::Status CreateGpuTimerParts(Stream* real_stream, bool use_delay_kernel, } } // namespace -/*deprecated*/ /*static*/ absl::StatusOr GpuTimer::Create( - GpuStream* stream) { - // This deprecated factory does not launch the delay kernel and may lead to - // reduced measurement accuracy. - GpuExecutor* parent = stream->parent(); - GpuContext* context = parent->gpu_context(); - GpuEventHandle start_event; - TF_RETURN_IF_ERROR(GpuDriver::InitEvent(context, &start_event, - GpuDriver::EventFlags::kDefault)); - GpuEventHandle stop_event; - TF_RETURN_IF_ERROR(GpuDriver::InitEvent(context, &stop_event, - GpuDriver::EventFlags::kDefault)); - CHECK(start_event != nullptr && stop_event != nullptr); - TF_RETURN_IF_ERROR(GpuDriver::RecordEvent(parent->gpu_context(), start_event, - stream->gpu_stream())); - return absl::StatusOr{absl::in_place, parent, start_event, - stop_event, stream}; -} - -/*static*/ absl::StatusOr GpuTimer::Create(Stream* real_stream, - bool use_delay_kernel) { - GpuExecutor* parent = nullptr; - GpuEventHandle start_event = nullptr; - GpuEventHandle stop_event = nullptr; - GpuSemaphore semaphore{}; - TF_RETURN_IF_ERROR(CreateGpuTimerParts(real_stream, use_delay_kernel, parent, - start_event, stop_event, semaphore)); - return absl::StatusOr{absl::in_place, - parent, - start_event, - stop_event, - AsGpuStream(real_stream), - std::move(semaphore)}; -} - absl::StatusOr> GpuTimer::CreateEventBasedTimer(Stream* stream, bool use_delay_kernel) { GpuExecutor* parent = nullptr; diff --git a/xla/stream_executor/gpu/gpu_timer.h b/xla/stream_executor/gpu/gpu_timer.h index ea8eaa852960f8..3cfd3cfc34efb0 100644 --- a/xla/stream_executor/gpu/gpu_timer.h +++ b/xla/stream_executor/gpu/gpu_timer.h @@ -48,11 +48,8 @@ class GpuStream; // to be measured more accurately. class GpuTimer : public EventBasedTimer { public: - static absl::StatusOr Create(Stream* stream, bool use_delay_kernel); static absl::StatusOr> CreateEventBasedTimer( Stream* stream, bool use_delay_kernel); - [[deprecated("Pass Stream* not GpuStream*")]] static absl::StatusOr - Create(GpuStream* stream); explicit GpuTimer(GpuExecutor* parent, GpuEventHandle start_event, GpuEventHandle stop_event, GpuStream* stream, From 52b818467c70962b89f93b94f255a2be6c590852 Mon Sep 17 00:00:00 2001 From: Reed Wanderman-Milne Date: Mon, 22 Jul 2024 12:42:13 -0700 Subject: [PATCH 063/376] Fix crash in AllReduceBlueConnect when multiple partitions are used. Also, the pass now only runs when the all-reduce op has specific values for CollectiveOpGroupMode: kCrossReplica and kFlattenedID. Previously, the pass crashed with any mode other than kCrossReplica. I'm not sure when the two other modes, kCrossPartition and kCrossReplicaAndPartition, are used in JAX programs, and am unsure how to create HLO which uses these modes, so I decided not to support them for now. PiperOrigin-RevId: 654858771 --- xla/service/gpu/BUILD | 3 + xla/service/gpu/all_reduce_blueconnect.cc | 62 ++++++++++-- .../gpu/all_reduce_blueconnect_test.cc | 99 +++++++++++++++++-- 3 files changed, 147 insertions(+), 17 deletions(-) diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 4fd126613b4791..2d6fb98e359e97 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -3762,7 +3762,9 @@ cc_library( "//xla:status_macros", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", + "//xla/service:collective_ops_utils", "//xla/service:computation_placer_hdr", + "//xla/service:global_device_id", "//xla/service:hlo_creation_utils", "//xla/service:hlo_pass", "@com_google_absl//absl/algorithm:container", @@ -3783,6 +3785,7 @@ xla_cc_test( deps = [ ":all_reduce_blueconnect", "//xla:shape_util", + "//xla:util", "//xla/hlo/ir:hlo", "//xla/service:computation_placer_hdr", "//xla/service:pattern_matcher", diff --git a/xla/service/gpu/all_reduce_blueconnect.cc b/xla/service/gpu/all_reduce_blueconnect.cc index 9d259f5f64a861..2e75ffaf55b12a 100644 --- a/xla/service/gpu/all_reduce_blueconnect.cc +++ b/xla/service/gpu/all_reduce_blueconnect.cc @@ -33,7 +33,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_query.h" +#include "xla/service/collective_ops_utils.h" #include "xla/service/computation_placer.h" +#include "xla/service/global_device_id.h" #include "xla/service/hlo_creation_utils.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -66,18 +68,50 @@ struct DecomposedReplicaGroups { std::vector new_all_reduce_groups; }; +// Returns the global device id for the given replica id. Returns nullopt if +// if the replica id can refer to multiple devices, or if the pass does not +// support the CollectiveOpGroupMode. +std::optional TryConvertingReplicaIdToDeviceId( + int64_t replica_id, const DeviceAssignment& device_assignment, + CollectiveOpGroupMode collective_group_mode) { + if (collective_group_mode == CollectiveOpGroupMode::kCrossReplica) { + if (device_assignment.computation_count() != 1) { + // If there are multiple partitions, the replica_id may refer to multiple + // devices on different partitions. + return std::nullopt; + } + return GlobalDeviceId{device_assignment(replica_id, /*computation_id=*/0)}; + } else if (collective_group_mode == CollectiveOpGroupMode::kFlattenedID) { + int partition_count = device_assignment.computation_count(); + int64_t actual_replica_id = replica_id / partition_count; + int64_t partition_id = replica_id % partition_count; + return GlobalDeviceId{device_assignment(actual_replica_id, partition_id)}; + } + + // kCrossPartition and kCrossReplicaAndPartition are unsupported. + VLOG(1) << "Skip AllReduceBlueConnect because of unsupported " + "CollectiveOpGroupMode " + << CollectiveOpGroupModeToString(collective_group_mode); + return std::nullopt; +} + absl::StatusOr> TryDecomposeReplicaGroup( const ReplicaGroup& replica_group, - const DeviceAssignment& device_assignment, size_t num_devices_per_host) { + const DeviceAssignment& device_assignment, size_t num_devices_per_host, + CollectiveOpGroupMode collective_group_mode) { int group_size = replica_group.replica_ids_size(); TF_RET_CHECK(group_size > 0); absl::btree_map> replica_ids_by_host; for (int64_t replica_id : replica_group.replica_ids()) { - int device_id = device_assignment(replica_id, /*computation_id=*/0); - TF_RET_CHECK(device_id >= 0); + std::optional device_id = TryConvertingReplicaIdToDeviceId( + replica_id, device_assignment, collective_group_mode); + if (!device_id.has_value()) { + return {std::nullopt}; + } + TF_RET_CHECK(*device_id >= 0); // We assume that devices are ordered by host. - int host_id = device_id / num_devices_per_host; + int host_id = device_id->value() / num_devices_per_host; replica_ids_by_host[host_id].push_back(replica_id); } @@ -133,6 +167,11 @@ TryDecomposeReplicaGroups(const HloAllReduceInstruction& all_reduce, replica_groups = absl::MakeSpan(&all_replicas, 1); } + TF_ASSIGN_OR_RETURN( + CollectiveOpGroupMode collective_op_group_mode, + GetCollectiveOpGroupMode(all_reduce.channel_id().has_value(), + all_reduce.use_global_device_ids())); + std::vector scatter_gather_groups; std::vector new_all_reduce_groups; @@ -141,7 +180,8 @@ TryDecomposeReplicaGroups(const HloAllReduceInstruction& all_reduce, TF_ASSIGN_OR_RETURN( std::optional decomposed_groups, TryDecomposeReplicaGroup(replica_group, device_assignment, - num_devices_per_host)); + num_devices_per_host, + collective_op_group_mode)); if (!decomposed_groups) return {std::nullopt}; @@ -233,11 +273,19 @@ absl::StatusOr TryDecomposeAllReduce(HloAllReduceInstruction* all_reduce, Shape reduce_scatter_shape = ShapeUtil::MakeMaybeTupleShape(scattered_shapes); + int64_t next_channel_id = hlo_query::NextChannelId(*computation.parent()); + auto get_channel_id = [&]() -> std::optional { + if (all_reduce->channel_id().has_value()) { + return next_channel_id++; + } + return std::nullopt; + }; + HloInstruction* reduce_scatter = computation.AddInstruction(HloInstruction::CreateReduceScatter( reduce_scatter_shape, flat_operands, all_reduce->to_apply(), CollectiveDeviceList(decomposed_groups->scatter_gather_groups), - /*constrain_layout=*/false, all_reduce->channel_id(), + /*constrain_layout=*/false, get_channel_id(), all_reduce->use_global_device_ids(), /*scatter_dimension=*/0)); @@ -255,7 +303,7 @@ absl::StatusOr TryDecomposeAllReduce(HloAllReduceInstruction* all_reduce, GetOutputs(*new_all_reduce), /*all_gather_dimension=*/0, CollectiveDeviceList(decomposed_groups->scatter_gather_groups), - /*constrain_layout=*/false, all_reduce->channel_id(), + /*constrain_layout=*/false, get_channel_id(), all_reduce->use_global_device_ids())); // Bitcast back to the original shapes and replace all-reduce with decomposed diff --git a/xla/service/gpu/all_reduce_blueconnect_test.cc b/xla/service/gpu/all_reduce_blueconnect_test.cc index b04fa92733747b..a6a66c5189af42 100644 --- a/xla/service/gpu/all_reduce_blueconnect_test.cc +++ b/xla/service/gpu/all_reduce_blueconnect_test.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include #include +#include +#include #include #include @@ -33,6 +35,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" +#include "xla/util.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" @@ -44,12 +47,25 @@ namespace m = ::xla::match; using AllReduceBlueConnectTest = HloTestBase; -void SetModuleConfig(HloModule& module, size_t replica_count) { - DeviceAssignment device_assignment(replica_count, /*computation_count=*/1); +HloPredicate MatchChannelId(std::optional channel_id) { + return [channel_id](const HloInstruction* instruction) { + return instruction->channel_id() == channel_id; + }; +} + +void SetModuleConfig(HloModuleConfig* module_config, size_t replica_count, + size_t partition_count = 1) { + DeviceAssignment device_assignment(replica_count, + /*computation_count=*/partition_count); device_assignment.FillIota(0); - auto& module_config = module.mutable_config(); - module_config.set_replica_count(replica_count); - module_config.set_static_device_assignment(device_assignment); + module_config->set_replica_count(replica_count); + module_config->set_num_partitions(partition_count); + module_config->set_static_device_assignment(device_assignment); +} + +void SetModuleConfig(HloModule& module, size_t replica_count, + size_t partition_count = 1) { + SetModuleConfig(&module.mutable_config(), replica_count, partition_count); } TEST_F(AllReduceBlueConnectTest, OneStage) { @@ -81,15 +97,18 @@ ENTRY %comp { // clang-format on auto bitcast = m::Bitcast(m::Parameter(0)).WithShape(F32, {16}); - auto reduce_scatter = - m::ReduceScatter(bitcast).WithShape(F32, {4}).WithReplicaGroups( - scatter_gather_groups); + auto reduce_scatter = m::ReduceScatter(bitcast) + .WithShape(F32, {4}) + .WithReplicaGroups(scatter_gather_groups) + .WithPredicate(MatchChannelId(std::nullopt)); auto all_reduce = m::AllReduce(reduce_scatter) .WithShape(F32, {4}) - .WithReplicaGroups(new_all_reduce_groups); + .WithReplicaGroups(new_all_reduce_groups) + .WithPredicate(MatchChannelId(std::nullopt)); auto all_gather = m::AllGather(all_reduce) .WithShape(F32, {16}) - .WithReplicaGroups(scatter_gather_groups); + .WithReplicaGroups(scatter_gather_groups) + .WithPredicate(MatchChannelId(std::nullopt)); EXPECT_THAT(module->entry_computation()->root_instruction(), GmockMatch(m::Bitcast(all_gather).WithShape(F32, {4, 4}))); } @@ -199,6 +218,41 @@ ENTRY %comp { GmockMatch(m::Tuple(bitcast2, bitcast3))); } +TEST_F(AllReduceBlueConnectTest, MultiplePartitionsFilecheck) { + constexpr absl::string_view hlo_string = R"( +HloModule module + +%add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY %comp { + p0 = f32[8,8] parameter(0) + ROOT crs = f32[8,8] all-reduce(p0), channel_id=1, + replica_groups={{0,1,2,3,4,5,6,7}}, use_global_device_ids=true, to_apply=add +})"; + HloModuleConfig module_config; + SetModuleConfig(&module_config, /*replica_count=*/1, /*partition_count=*/8); + + AllReduceBlueConnect pass(/*num_devices_per_host=*/4); + // Note: When matching strings like "replica_groups={{0,1,2,3}}", FileCheck + // interprets the string inside the double braces as regex. So to match such + // strings, we use "replica_groups={{..0,1,2,3..}}", where the dots match the + // opening and closing braces. + RunAndFilecheckHloRewrite(hlo_string, std::move(pass), R"( + CHECK: %p0 = f32[8,8]{1,0} parameter(0) + CHECK-NEXT: [[bitcast:%[^ ]+]] = f32[64]{0} bitcast(%p0) + CHECK-NEXT: [[reduce_scatter:%[^ ]+]] = f32[16]{0} reduce-scatter([[bitcast]]), channel_id=2, replica_groups={{..0,1,2,3.,.4,5,6,7..}}, use_global_device_ids=true, dimensions={0}, to_apply=%add + CHECK-NEXT: [[all_reduce:%[^ ]+]] = f32[16]{0} all-reduce([[reduce_scatter]]), channel_id=1, replica_groups={{..0,4.,.1,5.,.2,6.,.3,7..}}, use_global_device_ids=true, to_apply=%add + CHECK-NEXT: [[all_gather:%[^ ]+]] = f32[64]{0} all-gather([[all_reduce]]), channel_id=3, replica_groups={{..0,1,2,3.,.4,5,6,7..}}, dimensions={0}, use_global_device_ids=true + CHECK-NEXT: ROOT [[output:%[^ ]+]] = f32[8,8]{1,0} bitcast([[all_gather]]) +} +)", + /*after_pass_checks=*/nullptr, &module_config); +} + TEST_F(AllReduceBlueConnectTest, DifferentNumLocalDevicesWithinReplicaGroup) { constexpr absl::string_view hlo_string = R"( HloModule module @@ -331,5 +385,30 @@ ENTRY %comp { {}, absl::MakeSpan(expected_succs)))); } +TEST_F(AllReduceBlueConnectTest, ReduceScatterUnchanged) { + // Tests that this pass does not affect reduce-scatter. In principle, the + // BlueConnect algorithm could be applied to reduce-scatter, but for now it + // doesn't. + constexpr absl::string_view hlo_string = R"( +HloModule module + +%add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY %comp { + p0 = f32[8,4] parameter(0) + ROOT crs = f32[1,4] reduce-scatter(p0), dimensions={0}, to_apply=add +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + SetModuleConfig(*module, /*replica_count=*/8); + + AllReduceBlueConnect pass(/*num_devices_per_host=*/4); + EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(false)); +} + } // namespace } // namespace xla From 5439475334c1eb641530eb0092f58ae92e7f8228 Mon Sep 17 00:00:00 2001 From: Tom Natan Date: Mon, 22 Jul 2024 12:44:26 -0700 Subject: [PATCH 064/376] Update Shardy commit hash PiperOrigin-RevId: 654859365 --- third_party/shardy/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index 0c55801a584a3c..55e16fed709356 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "cd675a1ba02c7a380d8d89ebb9de743405f3c3e6" - SHARDY_SHA256 = "602cdc8f3ed86f45174ef7a1baa1f0560f58612d0160a1438dd36e88db539c8d" + SHARDY_COMMIT = "58fc775e94b0e7b0f127848e151f5c0dc4c64435" + SHARDY_SHA256 = "446714f551b9df42b99c6afc913b7078b69c6985f300116f9bcf2279ab4cb623" tf_http_archive( name = "shardy", From 5e2fc1f94a879f8c455417c2c56e672d99c5ed65 Mon Sep 17 00:00:00 2001 From: Changhui Lin Date: Mon, 22 Jul 2024 12:55:43 -0700 Subject: [PATCH 065/376] Add physical device ordinal to the run options. PiperOrigin-RevId: 654863283 --- xla/client/local_client.cc | 36 +++++++++++++++++++------ xla/executable_run_options.cc | 10 +++++++ xla/executable_run_options.h | 20 +++++++++++--- xla/pjrt/pjrt_stream_executor_client.cc | 3 +++ 4 files changed, 57 insertions(+), 12 deletions(-) diff --git a/xla/client/local_client.cc b/xla/client/local_client.cc index e00f39143bb6ee..c388dc478d7fde 100644 --- a/xla/client/local_client.cc +++ b/xla/client/local_client.cc @@ -68,18 +68,35 @@ absl::Status LocalExecutable::ValidateExecutionOptions( stream_platform->Name(), backend_->platform()->Name()); } - // Cannot specify device_ordinal with a stream. The stream determines these - // values. - if (run_options.device_ordinal() != -1) { + // The device ordinal (if provided) should match the ordinal of the device + // the stream belongs to. + int physical_device_ordinal = -1; + if (run_options.physical_device_ordinal() != -1) { + physical_device_ordinal = run_options.physical_device_ordinal(); + if (run_options.device_ordinal() == -1) { + return InvalidArgument( + "The logical device ordinal is required if the physical device " + "ordinal is specified."); + } + } else if (run_options.device_ordinal() != -1) { + // If the physical device ordinal is not specified, it is the same as the + // logical device ordinal if it is given. + physical_device_ordinal = run_options.device_ordinal(); + } + if (physical_device_ordinal != -1 && + physical_device_ordinal != + run_options.stream()->parent()->device_ordinal()) { return InvalidArgument( - "cannot set both device ordinal and stream options in " - "ExecutableRunOptions; the stream determines the device ordinal"); + "The physical device ordinal does not match the ordinal of the " + "device the stream belongs to."); } } // Verify that the device the executable was built for is equivalent // to the device it will run on. - int run_device_ordinal = run_options.device_ordinal(); + int run_device_ordinal = run_options.physical_device_ordinal() != -1 + ? run_options.physical_device_ordinal() + : run_options.device_ordinal(); if (run_device_ordinal == -1) { run_device_ordinal = run_options.stream() != nullptr ? run_options.stream()->parent()->device_ordinal() @@ -154,8 +171,11 @@ LocalExecutable::RunHelper(const absl::Span argument_shapes, // `service_options` (otherwise we will end up using a returned stream in // ExecuteOnStreamWrapper), which is why it isn't declared in the inner "if" // scope. - TF_ASSIGN_OR_RETURN( - stream, BorrowStreamForDevice(run_options.device_ordinal(), backend_)); + TF_ASSIGN_OR_RETURN(stream, BorrowStreamForDevice( + run_options.physical_device_ordinal() != -1 + ? run_options.physical_device_ordinal() + : run_options.device_ordinal(), + backend_)); run_options.set_stream(stream.get()); } if (run_options.allocator() == nullptr) { diff --git a/xla/executable_run_options.cc b/xla/executable_run_options.cc index cc53bed2df5358..0a6af339982e48 100644 --- a/xla/executable_run_options.cc +++ b/xla/executable_run_options.cc @@ -41,6 +41,16 @@ ExecutableRunOptions& ExecutableRunOptions::set_device_ordinal( int ExecutableRunOptions::device_ordinal() const { return device_ordinal_; } +ExecutableRunOptions& ExecutableRunOptions::set_physical_device_ordinal( + int physical_device_ordinal) { + physical_device_ordinal_ = physical_device_ordinal; + return *this; +} + +int ExecutableRunOptions::physical_device_ordinal() const { + return physical_device_ordinal_; +} + ExecutableRunOptions& ExecutableRunOptions::set_allocator( stream_executor::DeviceMemoryAllocator* allocator) { allocator_ = allocator; diff --git a/xla/executable_run_options.h b/xla/executable_run_options.h index c6a4897c2067ec..26d1a09668015c 100644 --- a/xla/executable_run_options.h +++ b/xla/executable_run_options.h @@ -128,13 +128,24 @@ class ExecutableRunOptions { stream_executor::DeviceMemoryAllocator* allocator() const; // If set, this is the device to run the computation on. Valid device_ordinal - // values are: 0 to # of devices - 1. These values are identical to the device - // ordinal values used by StreamExecutor. The device must be of the same type - // as the executable was compiled for. A value of -1 indicates this option has - // not been set. + // values are: 0 to # of devices - 1. These are the logical device ordinals, + // since multiple logical devices could reside on the same physical device, + // e.g., virtual GPUs. If there is only one logical device on a physical + // device, then these values are identical to the device ordinal values used + // by StreamExecutor. The device must be of the same type as the executable + // was compiled for. A value of -1 indicates this option has not been set. ExecutableRunOptions& set_device_ordinal(int device_ordinal); int device_ordinal() const; + // If set, this is the physical device to run the computation on. These values + // are identical to the device ordinal values used by StreamExecutor. The + // device must be of the same type as the executable was compiled for. A value + // of -1 indicates this option has not been set, in which case the physical + // device ordinal is the same as the logical device ordinal. + ExecutableRunOptions& set_physical_device_ordinal( + int physical_device_ordinal); + int physical_device_ordinal() const; + // If set, this is the stream to run the computation on. The platform of the // stream must match the platform the executable was built for. A value of // nullptr indicates the option has not been set. @@ -240,6 +251,7 @@ class ExecutableRunOptions { private: stream_executor::DeviceMemoryAllocator* allocator_ = nullptr; int device_ordinal_ = -1; + int physical_device_ordinal_ = -1; const DeviceAssignment* device_assignment_ = nullptr; stream_executor::Stream* stream_ = nullptr; const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr; diff --git a/xla/pjrt/pjrt_stream_executor_client.cc b/xla/pjrt/pjrt_stream_executor_client.cc index b79066a6061f38..464ba2ef10359c 100644 --- a/xla/pjrt/pjrt_stream_executor_client.cc +++ b/xla/pjrt/pjrt_stream_executor_client.cc @@ -2860,6 +2860,9 @@ PjRtStreamExecutorLoadedExecutable::EnqueueExecution( ExecutableRunOptions run_options; run_options.set_stream(device_state->compute_stream()); + run_options.set_device_ordinal(device_state->local_device_id().value()); + run_options.set_physical_device_ordinal( + device_state->local_hardware_id().value()); run_options.set_host_to_device_stream(device_state->host_to_device_stream()); run_options.set_device_to_host_stream(device_state->GetDeviceToHostStream()); run_options.set_allocator(client_->allocator()); From 03c7366b9e65044a3084348479c5feaf09e2c004 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 22 Jul 2024 13:46:31 -0700 Subject: [PATCH 066/376] [xla:cpu] Support run time pointer sizes for sorted elements PiperOrigin-RevId: 654881781 --- xla/service/cpu/runtime/sort_thunk.cc | 275 ++++++++++---------------- 1 file changed, 109 insertions(+), 166 deletions(-) diff --git a/xla/service/cpu/runtime/sort_thunk.cc b/xla/service/cpu/runtime/sort_thunk.cc index d07e734b46f164..959b096b8ba18c 100644 --- a/xla/service/cpu/runtime/sort_thunk.cc +++ b/xla/service/cpu/runtime/sort_thunk.cc @@ -19,11 +19,11 @@ limitations under the License. #include #include #include +#include #include #include #include #include -#include #include #include "absl/algorithm/container.h" @@ -125,129 +125,112 @@ SortThunk::SortThunk(Info info, absl::Span inputs, namespace { // We use a lot of template metaprogramming below to be able to construct -// iterators with statically known element sizes. We support a limited set of -// template instantiations that we need in practice. +// iterators with statically known number of compared elements. We support a +// limited set of template instantiations that we need in practice. + +// The size of the largest element we support (std::complex). +static constexpr size_t kMaxElementSize = 16; // Forward declare reference type defined below. -template +template struct Ref; // Value type to store values loaded from the input buffers. -template +template struct Value { - Value(const Ref& ref); // NOLINT + Value(const Ref& ref); // NOLINT - template - const void* compared_value() const { - return &std::get(value); - } + const void* compared_value(size_t i) const { return value[i].data(); } - std::tuple value; + // Use properly aligned byte array to store primitive values. + using ValueStorage = std::array; + alignas(alignof(std::max_align_t)) std::array value; + std::array value_sizes; }; // Reference to values stored in the input buffers. -template +template struct Ref { - explicit Ref(std::tuple ptr) : ptr(ptr) {} + Ref(std::array ptr, std::array ptr_sizes) + : ptr(ptr), ptr_sizes(ptr_sizes) {} - Ref& operator=(const Value& value); - Ref& operator=(const Ref& other); + Ref& operator=(const Value& value); + Ref& operator=(const Ref& other); - template - const void* compared_value() const { - return std::get(ptr); - } + const void* compared_value(size_t i) const { return ptr[i]; } - std::tuple ptr; + std::array ptr; + std::array ptr_sizes; }; -// Value to reference assignment. -template -static void Assign(Ref& ref, const Value& value, - std::index_sequence) { - ((*std::get(ref.ptr) = std::get(value.value)), ...); -} - -// Reference to reference assignment. -template -static void Assign(Ref& ref, const Ref& other, - std::index_sequence) { - ((*std::get(ref.ptr) = *std::get(other.ptr)), ...); +template +Value::Value(const Ref& ref) : value_sizes(ref.ptr_sizes) { + for (size_t i = 0; i < n; ++i) { + std::memcpy(value[i].data(), ref.ptr[i], ref.ptr_sizes[i]); + } } -template -Value::Value(const Ref& ref) - : value(std::apply([](auto*... p) { return std::make_tuple(*p...); }, - ref.ptr)) {} - -template -Ref& Ref::operator=(const Value& value) { - Assign(*this, value, std::make_index_sequence{}); +template +Ref& Ref::operator=(const Value& value) { + DCHECK(ptr_sizes == value.value_sizes); + for (size_t i = 0; i < n; ++i) { + std::memcpy(ptr[i], value.value[i].data(), value.value_sizes[i]); + } return *this; } -template -Ref& Ref::operator=(const Ref& other) { - Assign(*this, other, std::make_index_sequence{}); +template +Ref& Ref::operator=(const Ref& other) { + DCHECK(ptr_sizes == other.ptr_sizes); + for (size_t i = 0; i < n; ++i) { + std::memcpy(ptr[i], other.ptr[i], other.ptr_sizes[i]); + } return *this; } // Swap function required by `std::sort` and `std::stable_sort` implementations. -template -void swap(const Ref& lhs, const Ref& rhs) { - std::swap(*std::get<0>(lhs.ptr), *std::get<0>(rhs.ptr)); -} - -template -void swap(const Ref& lhs, const Ref& rhs) { - std::swap(*std::get<0>(lhs.ptr), *std::get<0>(rhs.ptr)); - std::swap(*std::get<1>(lhs.ptr), *std::get<1>(rhs.ptr)); -} - -// Extracts pointers to compared elements and packs them in the layout expected -// by the comparator function. -template -std::array ComparatorData1(const Lhs& lhs, const Rhs& rhs) { - return {lhs.template compared_value<0>(), rhs.template compared_value<0>()}; -} - -template -std::array ComparatorData2(const Lhs& lhs, const Rhs& rhs) { - return {lhs.template compared_value<0>(), rhs.template compared_value<0>(), - lhs.template compared_value<1>(), rhs.template compared_value<1>()}; +template +void swap(const Ref& lhs, const Ref& rhs) { + for (size_t i = 0; i < n; ++i) { + std::array tmp; + std::memcpy(tmp.data(), lhs.ptr[i], lhs.ptr_sizes[i]); + std::memcpy(lhs.ptr[i], rhs.ptr[i], rhs.ptr_sizes[i]); + std::memcpy(rhs.ptr[i], tmp.data(), lhs.ptr_sizes[i]); + } } -// A pointer (tuple of pointers) to the input data. -template +// An array of pointers to the input data. +template struct Ptr { using difference_type = std::ptrdiff_t; Ptr() = default; - explicit Ptr(Ts*... ptrs) : ptrs(ptrs...) {} - explicit Ptr(std::tuple ptrs) : ptrs(ptrs) {} - Ref operator*() const { return Ref{ptrs}; } + Ptr(std::array ptr, std::array ptr_sizes) + : ptr(ptr), ptr_sizes(ptr_sizes) {} + + Ref operator*() const { return Ref{ptr, ptr_sizes}; } - Ptr& operator+=(difference_type n) { - ptrs = std::apply( - [&](auto*... p) { return std::make_tuple(p + n...); }, ptrs); + Ptr& operator+=(difference_type diff) { + for (size_t i = 0; i < n; ++i) ptr[i] += diff * ptr_sizes[i]; return *this; } - Ptr& operator-=(difference_type n) { - ptrs = std::apply( - [&](auto*... p) { return std::make_tuple(p - n...); }, ptrs); + Ptr& operator-=(difference_type diff) { + for (size_t i = 0; i < n; ++i) ptr[i] -= diff * ptr_sizes[i]; return *this; } - Ptr operator+(difference_type n) const { - return Ptr{std::apply( - [&](auto*... p) { return std::make_tuple(p + n...); }, ptrs)}; + Ptr operator+(difference_type diff) const { + std::array upd; + for (size_t i = 0; i < n; ++i) upd[i] = ptr[i] + diff * ptr_sizes[i]; + return Ptr{upd, ptr_sizes}; } - Ptr operator-(difference_type n) const { - return Ptr{std::apply( - [&](auto*... p) { return std::make_tuple(p - n...); }, ptrs)}; + Ptr operator-(difference_type diff) const { + std::array upd; + for (size_t i = 0; i < n; ++i) upd[i] = ptr[i] - diff * ptr_sizes[i]; + return Ptr{upd, ptr_sizes}; } // In all comparison operators defined below we use only the ptr at index 0, @@ -255,44 +238,34 @@ struct Ptr { // implementation detail of sort iterator. difference_type operator-(const Ptr& rhs) const { - return std::get<0>(ptrs) - std::get<0>(rhs.ptrs); + DCHECK(ptr_sizes == rhs.ptr_sizes); + return (ptr[0] - rhs.ptr[0]) / ptr_sizes[0]; } - bool operator==(const Ptr& rhs) const { - return std::get<0>(ptrs) == std::get<0>(rhs.ptrs); - } - bool operator!=(const Ptr& rhs) const { - return std::get<0>(ptrs) != std::get<0>(rhs.ptrs); - } - bool operator>(const Ptr& rhs) const { - return std::get<0>(ptrs) > std::get<0>(rhs.ptrs); - } - bool operator<(const Ptr& rhs) const { - return std::get<0>(ptrs) < std::get<0>(rhs.ptrs); - } - bool operator>=(const Ptr& rhs) const { - return std::get<0>(ptrs) >= std::get<0>(rhs.ptrs); - } - bool operator<=(const Ptr& rhs) const { - return std::get<0>(ptrs) <= std::get<0>(rhs.ptrs); - } + bool operator==(const Ptr& rhs) const { return ptr[0] == rhs.ptr[0]; } + bool operator!=(const Ptr& rhs) const { return ptr[0] != rhs.ptr[0]; } + bool operator>(const Ptr& rhs) const { return ptr[0] > rhs.ptr[0]; } + bool operator<(const Ptr& rhs) const { return ptr[0] < rhs.ptr[0]; } + bool operator>=(const Ptr& rhs) const { return ptr[0] >= rhs.ptr[0]; } + bool operator<=(const Ptr& rhs) const { return ptr[0] <= rhs.ptr[0]; } - std::tuple ptrs; + std::array ptr; // pointers into the input buffers + std::array ptr_sizes; // pointers sizes in bytes }; // We rely on `std::sort` and `std::stable_sort` to sort the raw data. We sort // multiple input buffers together using the same comparator function, so we // need to provide a custom iterator that can access the data of all input // buffers at the same time and swap elements in them. -template +template class SortIterator { public: using iterator_category = std::random_access_iterator_tag; using difference_type = std::ptrdiff_t; - using value_type = Value; - using reference = Ref; - using pointer = Ptr; + using value_type = Value; + using reference = Ref; + using pointer = Ptr; SortIterator() = default; SortIterator(pointer ptr, difference_type stride) @@ -309,13 +282,13 @@ class SortIterator { return (ptr_ - rhs.ptr_) / stride_; } - SortIterator& operator+=(difference_type n) { - ptr_ += n * stride_; + SortIterator& operator+=(difference_type diff) { + ptr_ += diff * stride_; return *this; } - SortIterator& operator-=(difference_type n) { - ptr_ -= n * stride_; + SortIterator& operator-=(difference_type diff) { + ptr_ -= diff * stride_; return *this; } @@ -329,12 +302,12 @@ class SortIterator { return *this; } - SortIterator operator+(difference_type n) const { - return SortIterator(ptr_ + n * stride_, stride_); + SortIterator operator+(difference_type diff) const { + return SortIterator(ptr_ + diff * stride_, stride_); } - SortIterator operator-(difference_type n) const { - return SortIterator(ptr_ - n * stride_, stride_); + SortIterator operator-(difference_type diff) const { + return SortIterator(ptr_ - diff * stride_, stride_); } bool operator==(const SortIterator& rhs) const { return ptr_ == rhs.ptr_; } @@ -383,42 +356,32 @@ static SortDims GetSortDims(absl::Span dimensions, num_iterations}; } -// Sorts one input buffer of type `T0` inplace. -template +// Sorts `n` buffers in place. +template static void SortInplace(const SortDims& sort_dims, int64_t offset, - absl::Span data, bool is_stable, + absl::Span data, + absl::Span shapes, bool is_stable, SortThunk::LessThan* less_than) { - T0* base0 = reinterpret_cast(data[0].opaque()); - - auto compare = [&](const auto& a, const auto& b) { - auto data = ComparatorData1(a, b); - return (*less_than)(data.data()); - }; + std::array ptr; + std::array ptr_sizes; - SortIterator begin(Ptr(base0 + offset), - /*stride=*/sort_dims.inner_dim_size); - if (is_stable) { - std::stable_sort(begin, begin + sort_dims.sort_dim_size, compare); - } else { - std::sort(begin, begin + sort_dims.sort_dim_size, compare); + for (size_t i = 0; i < n; ++i) { + std::byte* base = reinterpret_cast(data[i].opaque()); + ptr_sizes[i] = primitive_util::ByteWidth(shapes[i].element_type()); + ptr[i] = base + offset * ptr_sizes[i]; } -} - -// Sorts two input buffers of type `T0` and `T1` inplace. -template -static void SortInplace(const SortDims& sort_dims, int64_t offset, - absl::Span data, bool is_stable, - SortThunk::LessThan* less_than) { - T0* base0 = reinterpret_cast(data[0].opaque()); - T1* base1 = reinterpret_cast(data[1].opaque()); auto compare = [&](const auto& a, const auto& b) { - auto data = ComparatorData2(a, b); + std::array data; + for (size_t i = 0, j = 0; i < n; i += 1, j += 2) { + data[j] = a.compared_value(i); + data[j + 1] = b.compared_value(i); + } return (*less_than)(data.data()); }; - SortIterator begin(Ptr(base0 + offset, base1 + offset), - /*stride=*/sort_dims.inner_dim_size); + SortIterator begin(Ptr(ptr, ptr_sizes), + /*stride=*/sort_dims.inner_dim_size); if (is_stable) { std::stable_sort(begin, begin + sort_dims.sort_dim_size, compare); } else { @@ -435,37 +398,17 @@ static absl::Status SortInplace(absl::Span data, // shape to get the sort dimensions. SortDims sort_dims = GetSortDims(shapes[0].dimensions(), dimension); - // Type tags for specializing the `sort` functor. Instead of specializing for - // each individual primitive type, we use a byte array of correct size to - // avoid the code bloat, as we use external comparator function anyway and - // don't compare the values directly. - using _4_bytes = std::array; - - // Collect byte sizes of element types of all inputs. - absl::InlinedVector byte_sizes; - byte_sizes.reserve(data.size()); - for (const Shape& shape : shapes) { - byte_sizes.push_back(primitive_util::ByteWidth(shape.element_type())); - } - - auto is_byte_sizes = [&](auto... sizes) { - return absl::c_equal(byte_sizes, absl::InlinedVector{ - static_cast(sizes)...}); - }; - // Iterate over all the 1-dimensional slices of the buffers and sort them. for (int64_t i = 0; i < sort_dims.num_iterations; ++i) { int64_t inner_idx = i % sort_dims.inner_dim_size; int64_t offset = inner_idx + (i - inner_idx) * sort_dims.sort_dim_size; - if (is_byte_sizes(4)) { - SortInplace<_4_bytes>(sort_dims, offset, data, is_stable, less_than); - } else if (is_byte_sizes(4, 4)) { - SortInplace<_4_bytes, _4_bytes>(sort_dims, offset, data, is_stable, - less_than); + if (data.size() == 1) { + SortInplace<1>(sort_dims, offset, data, shapes, is_stable, less_than); + } else if (data.size() == 2) { + SortInplace<2>(sort_dims, offset, data, shapes, is_stable, less_than); } else { - return Internal("Unsupported sort element byte widths [%s]", - absl::StrJoin(byte_sizes, ",")); + return Internal("Unsupported number of sorted inputs: %d", data.size()); } } From e9f85670989b7c25f7810bc4db23b630a79ea7eb Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Mon, 22 Jul 2024 13:47:14 -0700 Subject: [PATCH 067/376] Pass GpuContext to GpuTimer creation instead of GpuExecutor, as that's the only thing from GpuExecutor that's needed. PiperOrigin-RevId: 654882047 --- xla/stream_executor/cuda/cuda_executor.cc | 3 +- xla/stream_executor/gpu/gpu_timer.cc | 39 ++++++++++------------- xla/stream_executor/gpu/gpu_timer.h | 14 ++++---- xla/stream_executor/rocm/rocm_executor.cc | 3 +- 4 files changed, 28 insertions(+), 31 deletions(-) diff --git a/xla/stream_executor/cuda/cuda_executor.cc b/xla/stream_executor/cuda/cuda_executor.cc index 1aa01f5469aa2f..14fa0ed509aa4b 100644 --- a/xla/stream_executor/cuda/cuda_executor.cc +++ b/xla/stream_executor/cuda/cuda_executor.cc @@ -269,7 +269,8 @@ absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, absl::StatusOr> GpuExecutor::CreateEventBasedTimer(GpuStream* stream, bool use_delay_kernel) { // TODO(b/301020144) Move this all to the appropriate Executor class. - return GpuTimer::CreateEventBasedTimer(stream, use_delay_kernel); + return GpuTimer::CreateEventBasedTimer(stream, gpu_context(), + use_delay_kernel); } bool GpuExecutor::UnloadGpuBinary(const void* gpu_binary) { diff --git a/xla/stream_executor/gpu/gpu_timer.cc b/xla/stream_executor/gpu/gpu_timer.cc index 325df3ec083243..af649aea935e9d 100644 --- a/xla/stream_executor/gpu/gpu_timer.cc +++ b/xla/stream_executor/gpu/gpu_timer.cc @@ -66,14 +66,11 @@ bool ShouldLaunchDelayKernel() { return value; } -absl::Status CreateGpuTimerParts(Stream* real_stream, bool use_delay_kernel, - GpuExecutor*& parent, +absl::Status CreateGpuTimerParts(GpuStream* stream, bool use_delay_kernel, + GpuContext* context, GpuEventHandle& start_event, GpuEventHandle& stop_event, GpuSemaphore& semaphore) { - GpuStream* stream = AsGpuStream(real_stream); - parent = stream->parent(); - GpuContext* context = parent->gpu_context(); TF_RETURN_IF_ERROR(GpuDriver::InitEvent(context, &start_event, GpuDriver::EventFlags::kDefault)); TF_RETURN_IF_ERROR(GpuDriver::InitEvent(context, &stop_event, @@ -88,27 +85,27 @@ absl::Status CreateGpuTimerParts(Stream* real_stream, bool use_delay_kernel, TF_ASSIGN_OR_RETURN(bool is_supported, DelayKernelIsSupported(stream)); if (is_supported) { - TF_ASSIGN_OR_RETURN(semaphore, LaunchDelayKernel(real_stream)); + TF_ASSIGN_OR_RETURN(semaphore, LaunchDelayKernel(stream)); } } // The start event goes after the delay kernel in the stream - TF_RETURN_IF_ERROR(GpuDriver::RecordEvent(parent->gpu_context(), start_event, - stream->gpu_stream())); + TF_RETURN_IF_ERROR( + GpuDriver::RecordEvent(context, start_event, stream->gpu_stream())); return absl::OkStatus(); } } // namespace absl::StatusOr> -GpuTimer::CreateEventBasedTimer(Stream* stream, bool use_delay_kernel) { - GpuExecutor* parent = nullptr; +GpuTimer::CreateEventBasedTimer(GpuStream* stream, GpuContext* context, + bool use_delay_kernel) { GpuEventHandle start_event = nullptr; GpuEventHandle stop_event = nullptr; GpuSemaphore semaphore{}; - TF_RETURN_IF_ERROR(CreateGpuTimerParts(stream, use_delay_kernel, parent, + TF_RETURN_IF_ERROR(CreateGpuTimerParts(stream, use_delay_kernel, context, start_event, stop_event, semaphore)); - return std::make_unique(parent, start_event, stop_event, - AsGpuStream(stream), std::move(semaphore)); + return std::make_unique(context, start_event, stop_event, stream, + std::move(semaphore)); } /*static*/ void GpuTimer::ReturnRandomDurationsForTesting() { @@ -116,26 +113,25 @@ GpuTimer::CreateEventBasedTimer(Stream* stream, bool use_delay_kernel) { } GpuTimer::~GpuTimer() { - GpuContext* context = parent_->gpu_context(); if (semaphore_ && !is_stopped_) { // Signal the delay kernel that it can exit *semaphore_ = GpuSemaphoreState::kRelease; // Wait for the delay kernel to exit before destroying the value that it is // watching. absl::Status status = - GpuDriver::SynchronizeStream(context, stream_->gpu_stream()); + GpuDriver::SynchronizeStream(context_, stream_->gpu_stream()); if (!status.ok()) { LOG(ERROR) << status; } } if (start_event_ != nullptr) { - absl::Status status = GpuDriver::DestroyEvent(context, &start_event_); + absl::Status status = GpuDriver::DestroyEvent(context_, &start_event_); if (!status.ok()) { LOG(ERROR) << status; } } if (stop_event_ != nullptr) { - absl::Status status = GpuDriver::DestroyEvent(context, &stop_event_); + absl::Status status = GpuDriver::DestroyEvent(context_, &stop_event_); if (!status.ok()) { LOG(ERROR) << status; } @@ -146,8 +142,8 @@ absl::StatusOr GpuTimer::GetElapsedDuration() { if (is_stopped_) { return absl::InternalError("Measuring inactive timer"); } - TF_RETURN_IF_ERROR(GpuDriver::RecordEvent(parent_->gpu_context(), stop_event_, - stream_->gpu_stream())); + TF_RETURN_IF_ERROR( + GpuDriver::RecordEvent(context_, stop_event_, stream_->gpu_stream())); // If we launched the delay kernel then check if it already timed out. if (semaphore_) { if (*semaphore_ == GpuSemaphoreState::kTimedOut) { @@ -161,9 +157,8 @@ absl::StatusOr GpuTimer::GetElapsedDuration() { } } float elapsed_milliseconds = NAN; - if (!GpuDriver::GetEventElapsedTime(parent_->gpu_context(), - &elapsed_milliseconds, start_event_, - stop_event_)) { + if (!GpuDriver::GetEventElapsedTime(context_, &elapsed_milliseconds, + start_event_, stop_event_)) { return absl::InternalError("Error stopping the timer"); } is_stopped_ = true; diff --git a/xla/stream_executor/gpu/gpu_timer.h b/xla/stream_executor/gpu/gpu_timer.h index 3cfd3cfc34efb0..efde218c8965ed 100644 --- a/xla/stream_executor/gpu/gpu_timer.h +++ b/xla/stream_executor/gpu/gpu_timer.h @@ -37,7 +37,7 @@ class DeterminismTest; namespace stream_executor { namespace gpu { -class GpuExecutor; +class GpuContext; class GpuStream; // When a timer is created it launches a delay kernel into the given stream and @@ -49,19 +49,19 @@ class GpuStream; class GpuTimer : public EventBasedTimer { public: static absl::StatusOr> CreateEventBasedTimer( - Stream* stream, bool use_delay_kernel); + GpuStream* stream, GpuContext* context, bool use_delay_kernel); - explicit GpuTimer(GpuExecutor* parent, GpuEventHandle start_event, + explicit GpuTimer(GpuContext* context, GpuEventHandle start_event, GpuEventHandle stop_event, GpuStream* stream, GpuSemaphore semaphore = {}) - : parent_(parent), + : context_(context), start_event_(start_event), stop_event_(stop_event), stream_(stream), semaphore_(std::move(semaphore)) {} GpuTimer(GpuTimer&& other) - : parent_(other.parent_), + : context_(other.context_), start_event_(std::exchange(other.start_event_, nullptr)), stop_event_(std::exchange(other.stop_event_, nullptr)), stream_(other.stream_), @@ -69,7 +69,7 @@ class GpuTimer : public EventBasedTimer { GpuTimer& operator=(GpuTimer&& other) { if (this != &other) { - parent_ = other.parent_; + context_ = other.context_; start_event_ = std::exchange(other.start_event_, nullptr); stop_event_ = std::exchange(other.stop_event_, nullptr); stream_ = other.stream_; @@ -83,7 +83,7 @@ class GpuTimer : public EventBasedTimer { absl::StatusOr GetElapsedDuration() override; private: - GpuExecutor* parent_; + GpuContext* context_; GpuEventHandle start_event_ = nullptr; GpuEventHandle stop_event_ = nullptr; GpuStream* stream_; diff --git a/xla/stream_executor/rocm/rocm_executor.cc b/xla/stream_executor/rocm/rocm_executor.cc index 8d5aa981419e59..c1da843d7a00cf 100644 --- a/xla/stream_executor/rocm/rocm_executor.cc +++ b/xla/stream_executor/rocm/rocm_executor.cc @@ -187,7 +187,8 @@ GpuExecutor::CreateOrShareConstant(Stream* stream, absl::StatusOr> GpuExecutor::CreateEventBasedTimer(GpuStream* stream, bool use_delay_kernel) { // TODO(b/301020144) Move this all to the appropriate Executor class. - return GpuTimer::CreateEventBasedTimer(stream, use_delay_kernel); + return GpuTimer::CreateEventBasedTimer(stream, gpu_context(), + use_delay_kernel); } bool GpuExecutor::UnloadGpuBinary(const void* gpu_binary) { From b39cbe20b49f2401ca678b95f009402120764849 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Mon, 22 Jul 2024 14:02:12 -0700 Subject: [PATCH 068/376] Use GpuEvent class instead of reimplementing portions of it in GpuTimer. PiperOrigin-RevId: 654886984 --- xla/stream_executor/cuda/cuda_executor.cc | 16 +++++-- xla/stream_executor/gpu/BUILD | 2 +- xla/stream_executor/gpu/gpu_event.cc | 6 ++- xla/stream_executor/gpu/gpu_event.h | 2 +- xla/stream_executor/gpu/gpu_executor.h | 4 ++ xla/stream_executor/gpu/gpu_timer.cc | 58 +++++++---------------- xla/stream_executor/gpu/gpu_timer.h | 18 +++---- xla/stream_executor/rocm/rocm_executor.cc | 16 +++++-- 8 files changed, 61 insertions(+), 61 deletions(-) diff --git a/xla/stream_executor/cuda/cuda_executor.cc b/xla/stream_executor/cuda/cuda_executor.cc index 14fa0ed509aa4b..c05a9189b21d60 100644 --- a/xla/stream_executor/cuda/cuda_executor.cc +++ b/xla/stream_executor/cuda/cuda_executor.cc @@ -269,8 +269,11 @@ absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, absl::StatusOr> GpuExecutor::CreateEventBasedTimer(GpuStream* stream, bool use_delay_kernel) { // TODO(b/301020144) Move this all to the appropriate Executor class. - return GpuTimer::CreateEventBasedTimer(stream, gpu_context(), - use_delay_kernel); + TF_ASSIGN_OR_RETURN(auto start_event, CreateGpuEvent(/*allow_timing=*/true)); + TF_ASSIGN_OR_RETURN(auto stop_event, CreateGpuEvent(/*allow_timing=*/true)); + return GpuTimer::CreateEventBasedTimer( + stream, gpu_context(), use_delay_kernel, std::move(start_event), + std::move(stop_event)); } bool GpuExecutor::UnloadGpuBinary(const void* gpu_binary) { @@ -717,12 +720,17 @@ absl::Status FillBlockDimLimit(GpuDeviceHandle device, return absl::OkStatus(); } -absl::StatusOr> GpuExecutor::CreateEvent() { +absl::StatusOr> GpuExecutor::CreateGpuEvent( + bool allow_timing) { auto gpu_event = std::make_unique(this); - TF_RETURN_IF_ERROR(gpu_event->Init()); + TF_RETURN_IF_ERROR(gpu_event->Init(allow_timing)); return std::move(gpu_event); } +absl::StatusOr> GpuExecutor::CreateEvent() { + return CreateGpuEvent(/*allow_timing=*/false); +} + absl::StatusOr> GpuExecutor::CreateStream( std::optional> priority) { auto gpu_stream = std::make_unique(this); diff --git a/xla/stream_executor/gpu/BUILD b/xla/stream_executor/gpu/BUILD index 8175356bada4be..c7ff3caf62904a 100644 --- a/xla/stream_executor/gpu/BUILD +++ b/xla/stream_executor/gpu/BUILD @@ -409,7 +409,7 @@ gpu_only_cc_library( ], deps = [ ":gpu_driver_header", - ":gpu_executor_header", + ":gpu_event", ":gpu_semaphore", ":gpu_stream", ":gpu_types_header", diff --git a/xla/stream_executor/gpu/gpu_event.cc b/xla/stream_executor/gpu/gpu_event.cc index 83c8c16e298d6e..e2aec087ef49e0 100644 --- a/xla/stream_executor/gpu/gpu_event.cc +++ b/xla/stream_executor/gpu/gpu_event.cc @@ -31,9 +31,11 @@ GpuEvent::GpuEvent(GpuExecutor* parent) GpuEvent::~GpuEvent() { Destroy().IgnoreError(); } -absl::Status GpuEvent::Init() { +absl::Status GpuEvent::Init(bool allow_timing) { return GpuDriver::InitEvent(parent_->gpu_context(), &gpu_event_, - GpuDriver::EventFlags::kDisableTiming); + allow_timing + ? GpuDriver::EventFlags::kDefault + : GpuDriver::EventFlags::kDisableTiming); } absl::Status GpuEvent::Destroy() { diff --git a/xla/stream_executor/gpu/gpu_event.h b/xla/stream_executor/gpu/gpu_event.h index 08a5d1c4f76ed5..66b9bbafd7abd4 100644 --- a/xla/stream_executor/gpu/gpu_event.h +++ b/xla/stream_executor/gpu/gpu_event.h @@ -34,7 +34,7 @@ class GpuEvent : public Event { ~GpuEvent() override; // Populates the CUDA-platform-specific elements of this object. - absl::Status Init(); + absl::Status Init(bool allow_timing); // Deallocates any platform-specific elements of this object. This is broken // out (not part of the destructor) to allow for error reporting. diff --git a/xla/stream_executor/gpu/gpu_executor.h b/xla/stream_executor/gpu/gpu_executor.h index 275f0c99e87ea0..ad4a7b7b3103c9 100644 --- a/xla/stream_executor/gpu/gpu_executor.h +++ b/xla/stream_executor/gpu/gpu_executor.h @@ -66,6 +66,7 @@ class StreamExecutor; namespace gpu { +class GpuEvent; class GpuKernel; class GpuCommandBuffer; class GpuStream; @@ -331,6 +332,9 @@ class GpuExecutor : public StreamExecutorCommon { bool UnloadGpuBinary(const void* gpu_binary) TF_EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_); + // Creates a GpuEvent for the given stream. + absl::StatusOr> CreateGpuEvent(bool allow_timing); + // Guards the on-disk-module mapping. absl::Mutex disk_modules_mu_; diff --git a/xla/stream_executor/gpu/gpu_timer.cc b/xla/stream_executor/gpu/gpu_timer.cc index af649aea935e9d..9bc8f73ee38dca 100644 --- a/xla/stream_executor/gpu/gpu_timer.cc +++ b/xla/stream_executor/gpu/gpu_timer.cc @@ -30,10 +30,9 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" -#include "absl/utility/utility.h" #include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/gpu/gpu_driver.h" -#include "xla/stream_executor/gpu/gpu_executor.h" +#include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_semaphore.h" #include "xla/stream_executor/gpu/gpu_stream.h" #include "xla/stream_executor/gpu/gpu_timer_kernel.h" @@ -66,16 +65,14 @@ bool ShouldLaunchDelayKernel() { return value; } -absl::Status CreateGpuTimerParts(GpuStream* stream, bool use_delay_kernel, - GpuContext* context, - GpuEventHandle& start_event, - GpuEventHandle& stop_event, - GpuSemaphore& semaphore) { - TF_RETURN_IF_ERROR(GpuDriver::InitEvent(context, &start_event, - GpuDriver::EventFlags::kDefault)); - TF_RETURN_IF_ERROR(GpuDriver::InitEvent(context, &stop_event, - GpuDriver::EventFlags::kDefault)); - CHECK(start_event != nullptr && stop_event != nullptr); +} // namespace + +absl::StatusOr> +GpuTimer::CreateEventBasedTimer(GpuStream* stream, GpuContext* context, + bool use_delay_kernel, + std::unique_ptr start_event, + std::unique_ptr stop_event) { + GpuSemaphore semaphore{}; if (!use_delay_kernel) { LOG(WARNING) << "Skipping the delay kernel, measurement accuracy will be reduced"; @@ -90,21 +87,10 @@ absl::Status CreateGpuTimerParts(GpuStream* stream, bool use_delay_kernel, } // The start event goes after the delay kernel in the stream - TF_RETURN_IF_ERROR( - GpuDriver::RecordEvent(context, start_event, stream->gpu_stream())); - return absl::OkStatus(); -} -} // namespace + TF_RETURN_IF_ERROR(start_event->Record(stream->gpu_stream())); -absl::StatusOr> -GpuTimer::CreateEventBasedTimer(GpuStream* stream, GpuContext* context, - bool use_delay_kernel) { - GpuEventHandle start_event = nullptr; - GpuEventHandle stop_event = nullptr; - GpuSemaphore semaphore{}; - TF_RETURN_IF_ERROR(CreateGpuTimerParts(stream, use_delay_kernel, context, - start_event, stop_event, semaphore)); - return std::make_unique(context, start_event, stop_event, stream, + return std::make_unique(context, std::move(start_event), + std::move(stop_event), stream, std::move(semaphore)); } @@ -124,26 +110,15 @@ GpuTimer::~GpuTimer() { LOG(ERROR) << status; } } - if (start_event_ != nullptr) { - absl::Status status = GpuDriver::DestroyEvent(context_, &start_event_); - if (!status.ok()) { - LOG(ERROR) << status; - } - } - if (stop_event_ != nullptr) { - absl::Status status = GpuDriver::DestroyEvent(context_, &stop_event_); - if (!status.ok()) { - LOG(ERROR) << status; - } - } + start_event_.reset(); + stop_event_.reset(); } absl::StatusOr GpuTimer::GetElapsedDuration() { if (is_stopped_) { return absl::InternalError("Measuring inactive timer"); } - TF_RETURN_IF_ERROR( - GpuDriver::RecordEvent(context_, stop_event_, stream_->gpu_stream())); + TF_RETURN_IF_ERROR(stop_event_->Record(stream_->gpu_stream())); // If we launched the delay kernel then check if it already timed out. if (semaphore_) { if (*semaphore_ == GpuSemaphoreState::kTimedOut) { @@ -158,7 +133,8 @@ absl::StatusOr GpuTimer::GetElapsedDuration() { } float elapsed_milliseconds = NAN; if (!GpuDriver::GetEventElapsedTime(context_, &elapsed_milliseconds, - start_event_, stop_event_)) { + start_event_->gpu_event(), + stop_event_->gpu_event())) { return absl::InternalError("Error stopping the timer"); } is_stopped_ = true; diff --git a/xla/stream_executor/gpu/gpu_timer.h b/xla/stream_executor/gpu/gpu_timer.h index efde218c8965ed..f8ce6587e39afd 100644 --- a/xla/stream_executor/gpu/gpu_timer.h +++ b/xla/stream_executor/gpu/gpu_timer.h @@ -23,7 +23,7 @@ limitations under the License. #include "absl/time/time.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/event_based_timer.h" -#include "xla/stream_executor/gpu/gpu_executor.h" +#include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_semaphore.h" #include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/stream.h" @@ -49,14 +49,16 @@ class GpuStream; class GpuTimer : public EventBasedTimer { public: static absl::StatusOr> CreateEventBasedTimer( - GpuStream* stream, GpuContext* context, bool use_delay_kernel); + GpuStream* stream, GpuContext* context, bool use_delay_kernel, + std::unique_ptr start_event, + std::unique_ptr stop_event); - explicit GpuTimer(GpuContext* context, GpuEventHandle start_event, - GpuEventHandle stop_event, GpuStream* stream, + explicit GpuTimer(GpuContext* context, std::unique_ptr start_event, + std::unique_ptr stop_event, GpuStream* stream, GpuSemaphore semaphore = {}) : context_(context), - start_event_(start_event), - stop_event_(stop_event), + start_event_(std::move(start_event)), + stop_event_(std::move(stop_event)), stream_(stream), semaphore_(std::move(semaphore)) {} @@ -84,8 +86,8 @@ class GpuTimer : public EventBasedTimer { private: GpuContext* context_; - GpuEventHandle start_event_ = nullptr; - GpuEventHandle stop_event_ = nullptr; + std::unique_ptr start_event_; + std::unique_ptr stop_event_; GpuStream* stream_; GpuSemaphore semaphore_; bool is_stopped_ = false; diff --git a/xla/stream_executor/rocm/rocm_executor.cc b/xla/stream_executor/rocm/rocm_executor.cc index c1da843d7a00cf..07dc7b69f8f3f7 100644 --- a/xla/stream_executor/rocm/rocm_executor.cc +++ b/xla/stream_executor/rocm/rocm_executor.cc @@ -187,8 +187,11 @@ GpuExecutor::CreateOrShareConstant(Stream* stream, absl::StatusOr> GpuExecutor::CreateEventBasedTimer(GpuStream* stream, bool use_delay_kernel) { // TODO(b/301020144) Move this all to the appropriate Executor class. - return GpuTimer::CreateEventBasedTimer(stream, gpu_context(), - use_delay_kernel); + TF_ASSIGN_OR_RETURN(auto start_event, CreateGpuEvent(/*allow_timing=*/true)); + TF_ASSIGN_OR_RETURN(auto stop_event, CreateGpuEvent(/*allow_timing=*/true)); + return GpuTimer::CreateEventBasedTimer( + stream, gpu_context(), use_delay_kernel, std::move(start_event), + std::move(stop_event)); } bool GpuExecutor::UnloadGpuBinary(const void* gpu_binary) { @@ -630,12 +633,17 @@ absl::Status FillBlockDimLimit(GpuDeviceHandle device, return absl::OkStatus(); } -absl::StatusOr> GpuExecutor::CreateEvent() { +absl::StatusOr> GpuExecutor::CreateGpuEvent( + bool allow_timing) { auto gpu_event = std::make_unique(this); - TF_RETURN_IF_ERROR(gpu_event->Init()); + TF_RETURN_IF_ERROR(gpu_event->Init(allow_timing)); return std::move(gpu_event); } +absl::StatusOr> GpuExecutor::CreateEvent() { + return CreateGpuEvent(/*allow_timing=*/false); +} + absl::StatusOr> GpuExecutor::CreateStream( std::optional> priority) { auto gpu_stream = std::make_unique(this); From 5bc3864fce6b143df155d99255e47b7f1a02c90b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 22 Jul 2024 14:07:17 -0700 Subject: [PATCH 069/376] update collective_permute_cycle_decomposer_test to include matmul operation PiperOrigin-RevId: 654889259 --- ...ollective_permute_cycle_decomposer_test.cc | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/xla/service/gpu/collective_permute_cycle_decomposer_test.cc b/xla/service/gpu/collective_permute_cycle_decomposer_test.cc index 379ecbe90afe3e..7f297ad1e615f1 100644 --- a/xla/service/gpu/collective_permute_cycle_decomposer_test.cc +++ b/xla/service/gpu/collective_permute_cycle_decomposer_test.cc @@ -134,35 +134,36 @@ TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycle) { check_metadata(cp2); } -TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleWithWhileLoop) { +TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleWithMatmul) { const absl::string_view kModuleStr = R"( HloModule test while_cond { - param = (u32[], f32[]) parameter(0) + param = (u32[], f32[2,2], f32[2,2]) parameter(0) iter = u32[] get-tuple-element(param), index=0 - max_iter = u32[] constant(5) + max_iter = u32[] constant(3) ROOT cmp = pred[] compare(iter, max_iter), direction=LT } while_body { - param = (u32[], f32[]) parameter(0) + param = (u32[], f32[2,2], f32[2,2]) parameter(0) iter = u32[] get-tuple-element(param), index=0 - data = f32[] get-tuple-element(param), index=1 - ten = f32[] constant(10) - sum = f32[] add(data, ten) - cp = f32[] collective-permute(sum), channel_id=1, source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}} + data = f32[2,2] get-tuple-element(param), index=1 + weights = f32[2,2] get-tuple-element(param), index=2 + matmul = f32[2,2] dot(weights, data), lhs_contracting_dims={1}, rhs_contracting_dims={0} + cp = f32[2,2] collective-permute(matmul), channel_id=1, source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}} iter_increment = u32[] constant(1) next_iter = u32[] add(iter, iter_increment) - ROOT result = (u32[], f32[]) tuple(next_iter, cp) + ROOT result = (u32[], f32[2,2], f32[2,2]) tuple(next_iter, cp, weights) } ENTRY test_computation { iter = u32[] constant(0) - data = f32[] parameter(0) - input = (u32[], f32[]) tuple(iter, data) - while_res = (u32[], f32[]) while(input), condition=while_cond, body=while_body - ROOT data_out = f32[] get-tuple-element(while_res), index=1 + data = f32[2,2] parameter(0) + weights = f32[2,2] parameter(1) + input = (u32[], f32[2,2], f32[2,2]) tuple(iter, data, weights) + while_res = (u32[], f32[2,2], f32[2,2]) while(input), condition=while_cond, body=while_body + ROOT data_out = f32[2,2] get-tuple-element(while_res), index=1 } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, From 80f87213cd639f13e2c1d6882777b38aec0794f8 Mon Sep 17 00:00:00 2001 From: Greg Olechwierowicz Date: Mon, 22 Jul 2024 14:09:03 -0700 Subject: [PATCH 070/376] [XLA:GPU] Uniquify only command buffer created instructions. PiperOrigin-RevId: 654889890 --- xla/hlo/ir/hlo_computation.cc | 4 +++ xla/hlo/ir/hlo_computation.h | 4 +++ xla/hlo/ir/hlo_instruction.cc | 8 +++++ xla/hlo/ir/hlo_instruction.h | 8 +++++ xla/hlo/ir/hlo_module.h | 12 +++++++ xla/service/gpu/command_buffer_scheduling.cc | 31 +++++++++++------ xla/service/gpu/command_buffer_scheduling.h | 2 +- .../gpu/command_buffer_scheduling_test.cc | 33 ++++++++++--------- 8 files changed, 75 insertions(+), 27 deletions(-) diff --git a/xla/hlo/ir/hlo_computation.cc b/xla/hlo/ir/hlo_computation.cc index 8aa5786ce7b0cf..025b1ce4f4388e 100644 --- a/xla/hlo/ir/hlo_computation.cc +++ b/xla/hlo/ir/hlo_computation.cc @@ -1747,6 +1747,10 @@ void HloComputation::UniquifyName(NameUniquer* name_uniquer) { name_ = name_uniquer->GetUniqueName(name_); } +void HloComputation::UniquifyName(HloModule* module) { + UniquifyName(&module->computation_name_uniquer()); +} + HloInstruction* HloComputation::GetInstructionWithName(absl::string_view name) { auto instructions_in_computation = instructions(); auto it = absl::c_find_if( diff --git a/xla/hlo/ir/hlo_computation.h b/xla/hlo/ir/hlo_computation.h index 36e4b61dd3a843..956cf1abe1ede2 100644 --- a/xla/hlo/ir/hlo_computation.h +++ b/xla/hlo/ir/hlo_computation.h @@ -317,6 +317,10 @@ class HloComputation { // SetAndSanitizeName(). void UniquifyName(NameUniquer* name_uniquer); + // Use the given `module` to select a unique name for this computation based + // on computation's existing name. + void UniquifyName(HloModule* module); + // Prints a string representation of the computation. // // (We express the default options using an overload rather than a default diff --git a/xla/hlo/ir/hlo_instruction.cc b/xla/hlo/ir/hlo_instruction.cc index 97aea062027a09..7ea85bf6c836f2 100644 --- a/xla/hlo/ir/hlo_instruction.cc +++ b/xla/hlo/ir/hlo_instruction.cc @@ -4985,6 +4985,14 @@ void HloInstruction::UniquifyName(NameUniquer* name_uniquer) { name_ = name_uniquer->GetUniqueName(name_); } +void HloInstruction::UniquifyName(HloModule* module) { + UniquifyName(&module->instruction_name_uniquer()); +} + +void HloInstruction::UniquifyId(HloModule* module) { + SetUniqueId(module->NewUniqueInstructionId()); +} + void HloInstruction::SortInstructionUsersAndControlLists( const MappedPtrContainerSorter::MapPtrFn& map_fn, const HloInstruction& sorted_instruction) { diff --git a/xla/hlo/ir/hlo_instruction.h b/xla/hlo/ir/hlo_instruction.h index 0f375ab4302f4a..0846593156003f 100644 --- a/xla/hlo/ir/hlo_instruction.h +++ b/xla/hlo/ir/hlo_instruction.h @@ -2052,6 +2052,14 @@ class HloInstruction { // SetAndSanitizeName(). void UniquifyName(NameUniquer* name_uniquer); + // Use the `module`'s name uniquer to select a unique name for this + // instruction based on the instruction's existing name. + void UniquifyName(HloModule* module); + + // Use the `module`s `NewUniqueInstructionId` to set the id of this + // instruction. + void UniquifyId(HloModule* module); + // Clear the unique ID of the instruction so that it can be re-assigned, such // as for the purpose of compacting the instruction unique IDs. void ClearUniqueIdInternal() { unique_id_ = -1; } diff --git a/xla/hlo/ir/hlo_module.h b/xla/hlo/ir/hlo_module.h index 988bef8a0f63e2..6dea7d5234fc8b 100644 --- a/xla/hlo/ir/hlo_module.h +++ b/xla/hlo/ir/hlo_module.h @@ -28,6 +28,7 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -37,6 +38,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module_metadata.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/iterator_util.h" #include "xla/printer.h" @@ -490,6 +492,9 @@ class HloModule { // Returns the NameUniquer for uniquing instruction names in this module. NameUniquer& instruction_name_uniquer() { return instruction_name_uniquer_; } + // Returns the NameUniquer for uniquing computation names in this module. + NameUniquer& computation_name_uniquer() { return computation_name_uniquer_; } + // Assign a new unique dense id for an instruction int NewUniqueInstructionId() { int result = next_unique_id_; @@ -536,6 +541,13 @@ class HloModule { const HloSchedule& schedule() const { return *schedule_; } HloSchedule& schedule() { return *schedule_; } + HloComputation* AddComputation(std::unique_ptr computation, + bool is_entry) { + return AddComputationInternal(std::move(computation), is_entry, + /*uniquify_identifiers=*/false, + /*preserve_entry_layouts=*/true); + } + HloComputation* AddComputationAndUnifyNamesAndIds( std::unique_ptr computation, bool is_entry) { computation->ClearUniqueIdInternal(); diff --git a/xla/service/gpu/command_buffer_scheduling.cc b/xla/service/gpu/command_buffer_scheduling.cc index 68101e3b6a796e..793beea9317a84 100644 --- a/xla/service/gpu/command_buffer_scheduling.cc +++ b/xla/service/gpu/command_buffer_scheduling.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -465,7 +466,7 @@ absl::Status CommandBufferScheduling::MoveParametersAndConstantsToFront( //===----------------------------------------------------------------------===// absl::StatusOr CommandBufferScheduling::PrepareCommandBuffer( - const HloInstructionSequence& seq) { + const HloInstructionSequence& seq, HloModule* module) { auto builder = HloComputation::Builder("command_buffer"); absl::Span instructions = @@ -507,9 +508,12 @@ absl::StatusOr CommandBufferScheduling::PrepareCommandBuffer( // Create a new parameter for value defined outside of a command buffer. int64_t parameter_id = parameters.size(); - auto* parameter = Cast(builder.AddInstruction( - HloInstruction::CreateParameter(parameter_id, operand->shape(), - absl::StrCat("p", parameter_id)))); + auto* parameter = Cast( + builder.AddInstruction(HloInstruction::CreateParameter( + parameter_id, operand->shape(), "p"))); + + parameter->UniquifyName(module); + parameter->UniquifyId(module); inst_mapping[operand] = parameters[operand] = parameter; } } @@ -532,6 +536,7 @@ absl::StatusOr CommandBufferScheduling::PrepareCommandBuffer( inst_mapping[inst] = builder.AddInstruction( inst->CloneWithNewOperands(inst->shape(), mapped_operands(inst), &ctx)); + inst_mapping[inst]->UniquifyId(module); } // Convert parameters to command buffer arguments. @@ -560,11 +565,18 @@ absl::StatusOr CommandBufferScheduling::PrepareCommandBuffer( // If we return multiple results wrap them into tuple. if (returned.size() > 1) { - builder.AddInstruction(HloInstruction::CreateTuple(returned)); + HloInstruction* inst = + builder.AddInstruction(HloInstruction::CreateTuple(returned)); + inst->UniquifyName(module); + inst->UniquifyId(module); } + std::unique_ptr comp = builder.Build(); + comp->UniquifyName(module); + comp->SetUniqueId(comp->root_instruction()->unique_id()); + return CommandBuffer{std::move(arguments), std::move(results), - builder.Build(), std::move(inst_mapping)}; + std::move(comp), std::move(inst_mapping)}; } //===----------------------------------------------------------------------===// @@ -593,9 +605,8 @@ absl::StatusOr CommandBufferScheduling::RewriteCommandBuffer( } HloComputation* computation = - parent->parent()->AddComputationAndUnifyNamesAndIds( - std::move(command_buffer.computation), - /*is_entry=*/false); + parent->parent()->AddComputation(std::move(command_buffer.computation), + /*is_entry=*/false); HloInstruction* call = parent->AddInstruction(HloInstruction::CreateCall( cmd_buffer_result_shape, command_buffer.arguments, computation)); @@ -779,7 +790,7 @@ absl::StatusOr CommandBufferScheduling::Run( for (const HloInstructionSequence& seq : sequences) { TF_ASSIGN_OR_RETURN(CommandBuffer command_buffer, - PrepareCommandBuffer(seq)); + PrepareCommandBuffer(seq, comp->parent())); TF_ASSIGN_OR_RETURN( HloComputation * command_buffer_computation, RewriteCommandBuffer(comp, seq, std::move(command_buffer))); diff --git a/xla/service/gpu/command_buffer_scheduling.h b/xla/service/gpu/command_buffer_scheduling.h index e5c9ed2299648d..78590a80359e9e 100644 --- a/xla/service/gpu/command_buffer_scheduling.h +++ b/xla/service/gpu/command_buffer_scheduling.h @@ -120,7 +120,7 @@ class CommandBufferScheduling : public HloModulePass { // parameters. Results of instructions in the sequence are returned in a tuple // (if command buffer has a single result we don't wrap it into tuple). static absl::StatusOr PrepareCommandBuffer( - const HloInstructionSequence& seq); + const HloInstructionSequence& seq, HloModule* module); // Rewrites prepared command buffer computation into Hlo operations in the // parent computation (calls command buffer and replaced all users). diff --git a/xla/service/gpu/command_buffer_scheduling_test.cc b/xla/service/gpu/command_buffer_scheduling_test.cc index 056b4e39aa46e2..bda31a05980b19 100644 --- a/xla/service/gpu/command_buffer_scheduling_test.cc +++ b/xla/service/gpu/command_buffer_scheduling_test.cc @@ -88,9 +88,9 @@ TEST_F(CommandBufferSchedulingTest, SingleCommandBuffer) { // CHECK: %command_buffer ([[P0:.+]]: s32[], [[P1:.+]]: s32[]) -> (s32[], s32[]) { // CHECK: %[[P0]] = s32[] parameter(0) // CHECK: %[[P1]] = s32[] parameter(1) -// CHECK: %fusion.2 = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation -// CHECK: %fusion.3 = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation.1 -// CHECK: ROOT %tuple = (s32[], s32[]) tuple(%fusion.2, %fusion.3) +// CHECK: %fusion = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation +// CHECK: %fusion.1 = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation.1 +// CHECK: ROOT %tuple = (s32[], s32[]) tuple(%fusion, %fusion.1) // CHECK: } // // CHECK: ENTRY %main (a: s32[], b: s32[]) -> s32[] { @@ -162,7 +162,7 @@ TEST_F(CommandBufferSchedulingTest, MultipleCommandBuffers) { // CHECK: ROOT {{.*}} = s32[] fusion(%[[F0]], %[[V0]]), kind=kLoop, calls=%fused_computation.1 // CHECK: } -// CHECK: %command_buffer.1 ([[P0:.+]]: s32[], [[P1:.+]]: s32[]) -> s32[] { +// CHECK: %command_buffer.2 ([[P0:.+]]: s32[], [[P1:.+]]: s32[]) -> s32[] { // CHECK: %[[P0]] = s32[] parameter(0) // CHECK: %[[P1]] = s32[] parameter(1) // CHECK: %[[F2:.+]] = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation.2 @@ -176,7 +176,7 @@ TEST_F(CommandBufferSchedulingTest, MultipleCommandBuffers) { // CHECK: %[[CMD0:.+]] = s32[] call(%a, %b, %c), to_apply=%command_buffer // CHECK: %e = s32[] get-tuple-element(%c), index=1 // CHECK: %[[CALL:.+]] = s32[] custom-call(%[[CMD0]], %e), custom_call_target="some target" -// CHECK: %[[CMD1:.+]] = s32[] call(%[[CALL]], %a), to_apply=%command_buffer.1 +// CHECK: %[[CMD1:.+]] = s32[] call(%[[CALL]], %a), to_apply=%command_buffer.2 // CHECK: ROOT {{.*}} = s32[] custom-call(%[[CMD1]]), custom_call_target="some target" // CHECK: })"; @@ -431,8 +431,8 @@ TEST_F(CommandBufferSchedulingTest, DoNotCaptureUnmatchedAsyncDone) { CHECK: %command_buffer ([[P0:.+]]: s32[], [[P1:.+]]: s32[]) -> s32[] { CHECK: %[[P0]] = s32[] parameter(0) CHECK: %[[P1]] = s32[] parameter(1) - CHECK: %fusion.2 = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation - CHECK: ROOT %fusion.3 = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation.1 + CHECK: %fusion = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation + CHECK: ROOT %fusion.1 = s32[] fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation.1 CHECK: } CHECK: ENTRY %main (a: s32[4], b: s32[]) -> s32[] { @@ -577,7 +577,7 @@ TEST_F(CommandBufferSchedulingTest, PrepareCommandBuffer) { %fused_computation(param_0: s32[], param_1: s32[]) -> (s32[], s32[]) { %p0 = s32[] parameter(0) %p1 = s32[] parameter(1) - ROOT %tuple = (s32[], s32[]) tuple(s32[] %p0, s32[] %p1) + ROOT %tuple.1 = (s32[], s32[]) tuple(s32[] %p0, s32[] %p1) } %fused_computation.1(param_0: s32[], param_1: s32[]) -> s32[] { @@ -609,19 +609,20 @@ TEST_F(CommandBufferSchedulingTest, PrepareCommandBuffer) { instructions.push_back(inst); } - TF_ASSERT_OK_AND_ASSIGN(CommandBuffer command_buffer, - CommandBufferScheduling::PrepareCommandBuffer(seq)); - HloComputation* computation = module->AddComputationAndUnifyNamesAndIds( - std::move(command_buffer.computation), false); + TF_ASSERT_OK_AND_ASSIGN( + CommandBuffer command_buffer, + CommandBufferScheduling::PrepareCommandBuffer(seq, module.get())); + HloComputation* computation = module->AddComputation( + std::move(command_buffer.computation), /*is_entry=*/false); const char* expected = R"( // CHECK: %command_buffer ([[P0:.+]]: s32[], [[P1:.+]]: s32[]) -> (s32[], s32[]) { // CHECK: %[[P0]] = s32[] parameter(0) // CHECK: %[[P1]] = s32[] parameter(1) -// CHECK: %fusion.2 = (s32[], s32[]) fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation -// CHECK: %[[V0:.+]] = s32[] get-tuple-element(%fusion.2), index=0 -// CHECK: %fusion.3 = s32[] fusion(%[[P0]], %[[V0]]), kind=kLoop, calls=%fused_computation.1 -// CHECK: ROOT {{.*}} = (s32[], s32[]) tuple(%[[V0]], %fusion.3) +// CHECK: %fusion = (s32[], s32[]) fusion(%[[P0]], %[[P1]]), kind=kLoop, calls=%fused_computation +// CHECK: %[[V0:.+]] = s32[] get-tuple-element(%fusion), index=0 +// CHECK: %fusion.1 = s32[] fusion(%[[P0]], %[[V0]]), kind=kLoop, calls=%fused_computation.1 +// CHECK: ROOT {{.*}} = (s32[], s32[]) tuple(%[[V0]], %fusion.1) // CHECK:})"; TF_ASSERT_OK_AND_ASSIGN( From 073304cbbf72153e0496c2e4d91d6fcfc847b7d4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 22 Jul 2024 14:37:05 -0700 Subject: [PATCH 071/376] Removes 'e_val' from the AutoShardingSolverOutput class (these values can be determined from 's_val'). PiperOrigin-RevId: 654899742 --- .../auto_sharding/auto_sharding_solver.cc | 93 ++++------------- .../auto_sharding/auto_sharding_solver.h | 6 -- .../auto_sharding_solver_test.cc | 99 ++++--------------- 3 files changed, 38 insertions(+), 160 deletions(-) diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index e24f09299fd83d..96dc9c1bb78fd3 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -69,7 +69,7 @@ constexpr double kMaxCostEpsilon = 1.0001; bool AutoShardingSolverOutput::operator==( const AutoShardingSolverOutput& other) const { - return s_val == other.s_val && e_val == other.e_val && cost == other.cost && + return s_val == other.s_val && cost == other.cost && peak_times == other.peak_times; } @@ -844,23 +844,6 @@ std::vector GetChosenNodeStrategy( return chosen_node_strategy; } -std::vector GetChosenEdgeStrategy( - const AutoShardingSolverRequest& request, - const std::vector>& e) { - size_t num_edges = request.edges_size(); - std::vector chosen_edge_strategy(num_edges, -1); - for (EdgeIdx edge_idx = 0; edge_idx < num_edges; ++edge_idx) { - for (EdgeStrategyIdx j = 0; j < e[edge_idx].size(); ++j) { - // if lhs == 1 - if (e[edge_idx][j]->solution_value() > 0.5) { - chosen_edge_strategy[edge_idx] = j; - break; - } - } - } - return chosen_edge_strategy; -} - AutoShardingSolverResult SolveAndExtractSolution( const AutoShardingSolverRequest& request, const std::vector>& s, @@ -944,15 +927,18 @@ AutoShardingSolverResult SolveAndExtractSolution( double unsalted_objective = 0.0; const std::vector chosen_node_strategy = GetChosenNodeStrategy(request, s); - const std::vector chosen_edge_strategy = - GetChosenEdgeStrategy(request, e); for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { const NodeStrategyIdx j = chosen_node_strategy[node_idx]; unsalted_objective += request.computation_costs(node_idx).costs(j) + request.communication_costs(node_idx).costs(j); } + const auto chosen_edge_strategy = [&](EdgeIdx edge_idx) { + const auto& edge = request.edges(edge_idx); + return chosen_node_strategy[edge.first()] * request.s_len(edge.second()) + + chosen_node_strategy[edge.second()]; + }; for (EdgeIdx edge_idx = 0; edge_idx < num_edges; ++edge_idx) { - const EdgeStrategyIdx j = chosen_edge_strategy[edge_idx]; + const EdgeStrategyIdx j = chosen_edge_strategy(edge_idx); unsalted_objective += request.resharding_costs(edge_idx).costs(j); } if (overbudget_var) { @@ -975,7 +961,6 @@ AutoShardingSolverResult SolveAndExtractSolution( } PrintLargestInstructions(chosen_node_strategy, request); const AutoShardingSolverOutput output = {std::move(chosen_node_strategy), - std::move(chosen_edge_strategy), unsalted_objective}; return AutoShardingSolverResult(output, false); } @@ -1008,7 +993,11 @@ AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, const auto& v = request.value_costs(); const auto& p = request.departure_costs(); const std::vector& s_val = result.status->s_val; - const std::vector& e_val = result.status->e_val; + const auto e_val = [&](EdgeIdx edge_idx) { + const auto& edge = request.edges(edge_idx); + return s_val[edge.first()] * request.s_len(edge.second()) + + s_val[edge.second()]; + }; AutoShardingEvaluation evaluation; // Compute violations. for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { @@ -1033,7 +1022,7 @@ AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, } } for (EdgeIdx edge_idx = 0; edge_idx < request.edges_size(); ++edge_idx) { - if (r.at(edge_idx).costs(e_val[edge_idx]) >= kInfinityCost) { + if (r.at(edge_idx).costs(e_val(edge_idx)) >= kInfinityCost) { evaluation.violation_codes.insert(kInfiniteCostViolationCode); } } @@ -1063,7 +1052,7 @@ AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, request.enable_memory_edge_costs()) { for (EdgeIdx edge_idx : request.live_edges(time_idx).edges()) { const auto& m = request.memory_edge_costs(edge_idx).costs(); - total_memory_costs[time_idx] += m[e_val[edge_idx]]; + total_memory_costs[time_idx] += m[e_val(edge_idx)]; lower_bound_memory_costs[time_idx] += *std::min_element(m.begin(), m.end()); } @@ -1088,7 +1077,7 @@ AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, double lower_bound_group_cost = 0.0; for (const EdgeIdx edge_idx : group.prims()) { const auto& m = request.memory_edge_costs(edge_idx).costs(); - total_group_cost += m[e_val[edge_idx]]; + total_group_cost += m[e_val(edge_idx)]; lower_bound_group_cost += *std::min_element(m.begin(), m.end()); } total_edge_group_costs.push_back(total_group_cost); @@ -1132,7 +1121,7 @@ AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, double total_memory_cost = 0.0, lower_bound_memory_cost = 0.0; if (edge_idx < request.edges_size()) { const auto& m = request.memory_edge_costs(edge_idx).costs(); - total_memory_cost = m[e_val[edge_idx]]; + total_memory_cost = m[e_val(edge_idx)]; lower_bound_memory_cost = *std::min_element(m.begin(), m.end()); } else { int64_t group_idx = edge_idx - request.edges_size(); @@ -1178,7 +1167,7 @@ AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, c.at(node_idx).costs().begin(), c.at(node_idx).costs().end()); } for (EdgeIdx edge_idx = 0; edge_idx < request.edges_size(); ++edge_idx) { - evaluation.total.resharding_cost += r.at(edge_idx).costs(e_val[edge_idx]); + evaluation.total.resharding_cost += r.at(edge_idx).costs(e_val(edge_idx)); evaluation.lower_bound.resharding_cost += *std::min_element( r.at(edge_idx).costs().begin(), r.at(edge_idx).costs().end()); } @@ -1186,54 +1175,6 @@ AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, return evaluation; } -std::vector Rationalize(const AutoShardingSolverRequest& request, - const AutoShardingSolverResult& result, - const AutoShardingSolverResult& subopt) { - std::vector rationales; - const auto& names = request.instruction_names(); - - const std::vector& s_result = result.status->s_val; - const std::vector& s_subopt = subopt.status->s_val; - for (NodeIdx node_idx = 0; node_idx < request.num_nodes(); ++node_idx) { - const NodeStrategyIdx j = s_result[node_idx], k = s_subopt[node_idx]; - if (j != k) { - rationales.push_back(absl::StrCat( - "strategy changes for ", names[node_idx], " (", j, " -> ", k, ")")); - } - const double dj = request.communication_costs(node_idx).costs(j); - const double dk = request.communication_costs(node_idx).costs(k); - if (dj < dk) { - rationales.push_back(absl::StrCat("communication cost increases for ", - names[node_idx], " (", dj, " -> ", dk, - ")")); - } - const double cj = request.computation_costs(node_idx).costs(j); - const double ck = request.computation_costs(node_idx).costs(k); - if (cj < ck) { - rationales.push_back(absl::StrCat("computation cost increases for ", - names[node_idx], " (", cj, " -> ", ck, - ")")); - } - } - - const std::vector& e_result = result.status->e_val; - const std::vector& e_subopt = subopt.status->e_val; - for (EdgeIdx edge_idx = 0; edge_idx < request.edges_size(); ++edge_idx) { - const auto& edge = request.edges(edge_idx); - const EdgeStrategyIdx j = e_result[edge_idx], k = e_subopt[edge_idx]; - const double rj = request.resharding_costs(edge_idx).costs(j); - const double rk = request.resharding_costs(edge_idx).costs(k); - if (rj < rk) { - const std::string edge_name = - absl::StrCat(names[edge.first()], " and ", names[edge.second()]); - rationales.push_back(absl::StrCat("resharding cost increases for ", - edge_name, " (", rj, " -> ", rk, ")")); - } - } - - return rationales; -} - absl::Status ValidateRequest(const AutoShardingSolverRequest& request) { const int num_nodes = request.num_nodes(); const int num_edges = request.edges_size(); diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h b/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h index ff3f1bdcf57c98..cb051f7718fd44 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h @@ -31,7 +31,6 @@ namespace spmd { struct AutoShardingSolverOutput { std::vector s_val; - std::vector e_val; double cost = -1.0; absl::flat_hash_set peak_times; @@ -95,11 +94,6 @@ struct AutoShardingEvaluation { AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, const AutoShardingSolverResult& result); -// Produces a list of rationales for why an alternate result may be suboptimal. -std::vector Rationalize(const AutoShardingSolverRequest& request, - const AutoShardingSolverResult& result, - const AutoShardingSolverResult& subopt); - // Creates and returns a variable for makespan. operations_research::MPVariable* CreateMakespanVar( const AutoShardingSolverRequest& request, diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc index 3bb68a13895d01..16a62c0123c771 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc @@ -249,10 +249,8 @@ TEST(CallORToolsSolverTest, SolvesOptimally) { const AutoShardingSolverResult result = CallORToolsSolver(request); const std::vector s_val = {0, 0, 0, 0, 0}; - const std::vector e_val = {0, 0}; const double objective_value = 7650.0; - const AutoShardingSolverOutput expected_output = - {s_val, e_val, objective_value}; + const AutoShardingSolverOutput expected_output = {s_val, objective_value}; const AutoShardingSolverResult expected_result = {expected_output, false}; EXPECT_EQ(result, expected_result); } @@ -265,10 +263,8 @@ TEST(CallORToolsSolverTest, SolvesOverbudget) { const AutoShardingSolverResult result = CallORToolsSolver(request); const std::vector s_val = {0, 0, 0, 0, 0}; - const std::vector e_val = {0, 0}; const double objective_value = 9007650.0; - const AutoShardingSolverOutput expected_output = - {s_val, e_val, objective_value}; + const AutoShardingSolverOutput expected_output = {s_val, objective_value}; const AutoShardingSolverResult expected_result = {expected_output, false}; EXPECT_EQ(result, expected_result); } @@ -280,10 +276,8 @@ TEST(CallORToolsSolverTest, SolvesMaxDepartures) { const AutoShardingSolverResult result = CallORToolsSolver(request); const std::vector s_val = {0, 0, 1, 1, 0}; - const std::vector e_val = {1, 1}; const double objective_value = 7872.0; - const AutoShardingSolverOutput expected_output = - {s_val, e_val, objective_value}; + const AutoShardingSolverOutput expected_output = {s_val, objective_value}; const AutoShardingSolverResult expected_result = {expected_output, false}; EXPECT_EQ(result, expected_result); } @@ -297,10 +291,8 @@ TEST(CallORToolsSolverTest, AvoidsInfiniteNodeCosts) { const AutoShardingSolverResult result = CallORToolsSolver(request); const std::vector s_val = {3, 0, 0, 0, 0}; - const std::vector e_val = {12, 0}; const double objective_value = 10683.0; - const AutoShardingSolverOutput expected_output = - {s_val, e_val, objective_value}; + const AutoShardingSolverOutput expected_output = {s_val, objective_value}; const AutoShardingSolverResult expected_result = {expected_output, false}; EXPECT_EQ(result, expected_result); } @@ -312,10 +304,8 @@ TEST(CallORToolsSolverTest, AvoidsInfiniteEdgeCosts) { const AutoShardingSolverResult result = CallORToolsSolver(request); const std::vector s_val = {0, 0, 1, 1, 0}; - const std::vector e_val = {1, 1}; const double objective_value = 7872.0; - const AutoShardingSolverOutput expected_output = - {s_val, e_val, objective_value}; + const AutoShardingSolverOutput expected_output = {s_val, objective_value}; const AutoShardingSolverResult expected_result = {expected_output, false}; EXPECT_EQ(result, expected_result); } @@ -339,10 +329,8 @@ TEST(CallORToolsSolverTest, HandlesFollowedEdges) { const AutoShardingSolverResult result = CallORToolsSolver(request); const std::vector s_val = {0, 0, 0, 0, 0}; - const std::vector e_val = {0, 0, 0}; const double objective_value = 12650.0; - const AutoShardingSolverOutput expected_output = - {s_val, e_val, objective_value}; + const AutoShardingSolverOutput expected_output = {s_val, objective_value}; const AutoShardingSolverResult expected_result = {expected_output, false}; EXPECT_EQ(result, expected_result); } @@ -355,10 +343,8 @@ TEST(CallORToolsSolverTest, UsesHint) { const AutoShardingSolverResult result = CallORToolsSolver(request); const std::vector s_val = {0, 0, 0, 0, 0}; - const std::vector e_val = {0, 0}; const double objective_value = 7650.0; - const AutoShardingSolverOutput expected_output = - {s_val, e_val, objective_value}; + const AutoShardingSolverOutput expected_output = {s_val, objective_value}; const AutoShardingSolverResult expected_result = {expected_output, false}; EXPECT_EQ(result, expected_result); } @@ -389,10 +375,8 @@ TEST(CallORToolsSolverTest, HandlesMemoryEdgeCosts) { const AutoShardingSolverResult result = CallORToolsSolver(request); const std::vector s_val = {0, 0, 1, 1, 0}; - const std::vector e_val = {1, 1}; const double objective_value = 7872.0; - const AutoShardingSolverOutput expected_output = - {s_val, e_val, objective_value}; + const AutoShardingSolverOutput expected_output = {s_val, objective_value}; const AutoShardingSolverResult expected_result = {expected_output, false}; EXPECT_EQ(result, expected_result); } @@ -419,10 +403,8 @@ TEST(CallORToolsSolverTest, HandlesIntervals) { const AutoShardingSolverResult result = CallORToolsSolver(request); const std::vector s_val = {0, 0, 1, 1, 0}; - const std::vector e_val = {1, 1}; const double objective_value = 7872.0; - const AutoShardingSolverOutput expected_output = - {s_val, e_val, objective_value}; + const AutoShardingSolverOutput expected_output = {s_val, objective_value}; const AutoShardingSolverResult expected_result = {expected_output, false}; EXPECT_EQ(result, expected_result); } @@ -453,10 +435,8 @@ TEST(CallORToolsSolverTest, HandlesReducedIntervalsAndGroups) { const AutoShardingSolverResult result = CallORToolsSolver(request); const std::vector s_val = {0, 0, 1, 1, 0}; - const std::vector e_val = {1, 1}; const double objective_value = 7872.0; - const AutoShardingSolverOutput expected_output = - {s_val, e_val, objective_value}; + const AutoShardingSolverOutput expected_output = {s_val, objective_value}; const AutoShardingSolverResult expected_result = {expected_output, false}; EXPECT_EQ(result, expected_result); } @@ -468,10 +448,8 @@ TEST(CallORToolsSolverTest, SolvesWithEquivalences) { const AutoShardingSolverResult result = CallORToolsSolver(request); const std::vector s_val = {0, 0, 5, 5, 1}; - const std::vector e_val = {5, 5}; const double objective_value = 7650.0; - const AutoShardingSolverOutput expected_output = - {s_val, e_val, objective_value}; + const AutoShardingSolverOutput expected_output = {s_val, objective_value}; const AutoShardingSolverResult expected_result = {expected_output, false}; EXPECT_EQ(result, expected_result); } @@ -479,9 +457,8 @@ TEST(CallORToolsSolverTest, SolvesWithEquivalences) { TEST(AutoShardingEvaluatorTest, NoViolations) { const AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); const std::vector s_val = {3, 1, 2, 2, 1}; - const std::vector e_val = {14, 6}; const double objective_value = 12149.0; - const AutoShardingSolverOutput output = {s_val, e_val, objective_value}; + const AutoShardingSolverOutput output = {s_val, objective_value}; const AutoShardingSolverResult result = {output, false}; const AutoShardingEvaluation evaluation = Evaluate(request, result); @@ -502,9 +479,8 @@ TEST(AutoShardingEvaluatorTest, EvaluatesOverbudget) { request.set_memory_budget(100000); request.mutable_overbudget_coeff()->set_coeff(10.0); const std::vector s_val = {2 /* violates */, 1, 2, 2, 1}; - const std::vector e_val = {10, 6}; const double objective_value = 11138.0; - const AutoShardingSolverOutput output = {s_val, e_val, objective_value}; + const AutoShardingSolverOutput output = {s_val, objective_value}; const AutoShardingSolverResult result = {output, false}; const AutoShardingEvaluation evaluation = Evaluate(request, result); @@ -531,9 +507,8 @@ TEST(AutoShardingEvaluatorTest, EvaluatesOverbudgetWithIntervals) { request.clear_live(); AddIntervals(request.mutable_node_intervals(), node_intervals); const std::vector s_val = {2 /* violates */, 1, 2, 2, 1}; - const std::vector e_val = {10, 6}; const double objective_value = 11138.0; - const AutoShardingSolverOutput output = {s_val, e_val, objective_value}; + const AutoShardingSolverOutput output = {s_val, objective_value}; const AutoShardingSolverResult result = {output, false}; const AutoShardingEvaluation evaluation = Evaluate(request, result); @@ -563,9 +538,8 @@ TEST(AutoShardingEvaluatorTest, AddIntervals(request.mutable_node_intervals(), node_intervals); AddGroups(request.mutable_node_groups(), node_groups); const std::vector s_val = {2 /* violates */, 1, 2, 2, 1}; - const std::vector e_val = {10, 6}; const double objective_value = 11138.0; - const AutoShardingSolverOutput output = {s_val, e_val, objective_value}; + const AutoShardingSolverOutput output = {s_val, objective_value}; const AutoShardingSolverResult result = {output, false}; const AutoShardingEvaluation evaluation = Evaluate(request, result); @@ -586,9 +560,8 @@ TEST(AutoShardingEvaluatorTest, TEST(AutoShardingEvaluatorTest, ViolatesFollower) { const AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); const std::vector s_val = {3, 1, 2, 1 /* violates */, 1}; - const std::vector e_val = {14, 6}; const double objective_value = 12138.0; - const AutoShardingSolverOutput output = {s_val, e_val, objective_value}; + const AutoShardingSolverOutput output = {s_val, objective_value}; const AutoShardingSolverResult result = {output, false}; const AutoShardingEvaluation evaluation = Evaluate(request, result); @@ -608,9 +581,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesFollower) { TEST(AutoShardingEvaluatorTest, ViolatesAlias) { const AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); const std::vector s_val = {3, 1, 2, 2, 0 /* violates */}; - const std::vector e_val = {14, 6}; const double objective_value = 12138.0; - const AutoShardingSolverOutput output = {s_val, e_val, objective_value}; + const AutoShardingSolverOutput output = {s_val, objective_value}; const AutoShardingSolverResult result = {output, false}; const AutoShardingEvaluation evaluation = Evaluate(request, result); @@ -630,9 +602,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesAlias) { TEST(AutoShardingEvaluatorTest, ViolatesMemory) { const AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); const std::vector s_val = {2 /* violates */, 1, 2, 2, 1}; - const std::vector e_val = {10, 6}; const double objective_value = 11138.0; - const AutoShardingSolverOutput output = {s_val, e_val, objective_value}; + const AutoShardingSolverOutput output = {s_val, objective_value}; const AutoShardingSolverResult result = {output, false}; const AutoShardingEvaluation evaluation = Evaluate(request, result); @@ -655,9 +626,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesInfiniteCostForNode) { request.mutable_computation_costs(0)->set_costs(1, kInfinityCost); request.mutable_computation_costs(0)->set_costs(2, kInfinityCost); const std::vector s_val = {0 /* violates */, 1, 2, 2, 1}; - const std::vector e_val = {2, 6}; const double objective_value = 1e+20; - const AutoShardingSolverOutput output = {s_val, e_val, objective_value}; + const AutoShardingSolverOutput output = {s_val, objective_value}; const AutoShardingSolverResult result = {output, false}; const AutoShardingEvaluation evaluation = Evaluate(request, result); @@ -678,9 +648,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesInfiniteCostForEdge) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); request.mutable_resharding_costs(0)->set_costs(2, kInfinityCost); const std::vector s_val = {0, 1, 2, 2, 1}; - const std::vector e_val = {2 /* violates */, 6}; const double objective_value = 1e+20; - const AutoShardingSolverOutput output = {s_val, e_val, objective_value}; + const AutoShardingSolverOutput output = {s_val, objective_value}; const AutoShardingSolverResult result = {output, false}; const AutoShardingEvaluation evaluation = Evaluate(request, result); @@ -701,9 +670,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesMaxDepartures) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); request.mutable_max_departures()->set_coeff(2.0); const std::vector s_val = {3, 1, 2, 2, 1}; - const std::vector e_val = {14, 6}; const double objective_value = 12149.0; - const AutoShardingSolverOutput output = {s_val, e_val, objective_value}; + const AutoShardingSolverOutput output = {s_val, objective_value}; const AutoShardingSolverResult result = {output, false}; const AutoShardingEvaluation evaluation = Evaluate(request, result); @@ -720,31 +688,6 @@ TEST(AutoShardingEvaluatorTest, ViolatesMaxDepartures) { EXPECT_EQ(evaluation, expected_evaluation); } -TEST(AutoShardingRationalizerTest, RationalizesProperly) { - const AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); - const std::vector s_val = {0, 1, 2, 2, 1}; - const std::vector e_val = {2, 6}; - const double objective_value = 9116.0; - const AutoShardingSolverOutput output = {s_val, e_val, objective_value}; - const AutoShardingSolverResult result = {output, false}; - const std::vector s_subopt = {3, 1, 2, 2, 1}; - const std::vector e_subopt = {14, 6}; - const double subopt_value = 12149.0; - const AutoShardingSolverOutput subopt_output = - {s_subopt, e_subopt, subopt_value}; - const AutoShardingSolverResult subopt_result = {subopt_output, false}; - - const std::vector rationales = - Rationalize(request, result, subopt_result); - - const std::vector expected_rationales = { - "strategy changes for A (0 -> 3)", - "communication cost increases for A (100 -> 130)", - "computation cost increases for A (10 -> 13)", - "resharding cost increases for A and C (1200 -> 4200)"}; - EXPECT_EQ(rationales, expected_rationales); -} - TEST(ScaleRequest, ScalesProperly) { AutoShardingSolverRequest unscaled_request; const CostMatrix c = {{10000000, 11000000, 12000000, 13000000}, From 732931e78b1c304567dc6b613b9eab1e7fa0c624 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Mon, 22 Jul 2024 14:53:46 -0700 Subject: [PATCH 072/376] Integrate StableHLO at openxla/stablehlo@840c41ce PiperOrigin-RevId: 654905025 --- third_party/stablehlo/temporary.patch | 101 -------------------------- third_party/stablehlo/workspace.bzl | 4 +- 2 files changed, 2 insertions(+), 103 deletions(-) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index a6da7f82c42f12..8b137891791fe9 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -1,102 +1 @@ -diff --ruN a/stablehlo/stablehlo/reference/Tensor.cpp b/stablehlo/stablehlo/reference/Tensor.cpp ---- stablehlo/stablehlo/reference/Tensor.cpp -+++ stablehlo/stablehlo/reference/Tensor.cpp -@@ -423,7 +423,7 @@ - getType().print(os); - os << " {"; - Index idx{}; -- printHelper(os, *this, getShape(), idx, /*index=*/1); -+ printHelper(os, *this, getShape(), idx, /*indent=*/1); - os << "}"; - } - -diff --ruN a/stablehlo/stablehlo/tests/math/acos_limits.mlir b/stablehlo/stablehlo/tests/math/acos_limits.mlir ---- stablehlo/stablehlo/tests/math/acos_limits.mlir -+++ stablehlo/stablehlo/tests/math/acos_limits.mlir -@@ -0,0 +1,14 @@ -+// RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s | stablehlo-translate --interpret -+ -+func.func @main() -> (tensor, tensor>) { -+ %cst = stablehlo.constant dense<-1.000000e+00> : tensor -+ %cst_0 = stablehlo.constant dense<(-1.000000e+00,0.000000e+00)> : tensor> -+ %zero = stablehlo.constant dense<0.0> : tensor -+ %pi = stablehlo.constant dense<3.1415926535897931> : tensor -+ %complex_pi = stablehlo.complex %pi, %zero : tensor> -+ %0 = chlo.acos %cst : tensor -> tensor -+ %1 = chlo.acos %cst_0 : tensor> -> tensor> -+ check.expect_close %0, %pi, max_ulp_difference = 1 : tensor, tensor -+ check.expect_close %1, %complex_pi, max_ulp_difference = 1 : tensor>, tensor> -+ return %0, %1 : tensor, tensor> -+} -diff --ruN a/stablehlo/stablehlo/transforms/ChloDecompositionPatterns.td b/stablehlo/stablehlo/transforms/ChloDecompositionPatterns.td ---- stablehlo/stablehlo/transforms/ChloDecompositionPatterns.td -+++ stablehlo/stablehlo/transforms/ChloDecompositionPatterns.td -@@ -45,6 +45,37 @@ - //===----------------------------------------------------------------------===// - // Unary op patterns. - //===----------------------------------------------------------------------===// -+ -+// Expand acos for non-complex arguments to MHLO dialect as follows: -+// acos(x) = 2 * atan2(sqrt(1 - x^2), (1 + x)) if x != -1 -+// = pi if x == -1 -+// -+// Note: Complex decomposition is in ChloDecompositionPatternsMath.td -+def : Pat<(CHLO_AcosOp NonComplexElementType:$input), -+ (StableHLO_SelectOp -+ (StableHLO_CompareOp -+ $input, -+ (StableHLO_ConstantLike<"-1"> $input), -+ StableHLO_ComparisonDirectionValue<"NE">, -+ (STABLEHLO_DEFAULT_COMPARISON_TYPE) -+ ), -+ (StableHLO_MulOp -+ (StableHLO_ConstantLike<"2"> $input), -+ (StableHLO_Atan2Op -+ (StableHLO_SqrtOp -+ (StableHLO_SubtractOp -+ (StableHLO_ConstantLike<"1"> $input), -+ (StableHLO_MulOp $input, $input) -+ ) -+ ), -+ (StableHLO_AddOp -+ (StableHLO_ConstantLike<"1"> $input), -+ $input -+ ) -+ ) -+ ), -+ (StableHLO_ConstantLike<"M_PI"> $input) -+ )>; - - // Express `atan` as - // atan(x) = atan2(x, 1) -diff --ruN a/stablehlo/stablehlo/transforms/ChloDecompositionPatternsMath.td b/stablehlo/stablehlo/transforms/ChloDecompositionPatternsMath.td ---- stablehlo/stablehlo/transforms/ChloDecompositionPatternsMath.td -+++ stablehlo/stablehlo/transforms/ChloDecompositionPatternsMath.td -@@ -634,26 +634,6 @@ - (StableHLO_Log1pOp - (StableHLO_AddOp $am1, $sq)))), - (StableHLO_NegOp $imag)))>; -- --// Arcus cosine on real input: --// --// arccos(x) = 2 * arctan2(sqrt(1 - x * x), 1 + x) --// --// To avoid cancellation errors at abs(x) close to 1, we'll use --// --// 1 - x * x == (1 - x) * (1 + x) --// --def : Pat<(CHLO_AcosOp NonComplexElementType:$x), -- (StableHLO_MulOp -- (StableHLO_ConstantLike<"2"> $x), -- (StableHLO_Atan2Op -- (StableHLO_SqrtOp -- (StableHLO_MulOp -- (StableHLO_SubtractOp -- (StableHLO_ConstantLike<"1">:$one $x), -- $x), -- (StableHLO_AddOp:$add_one_x $one, $x))), -- $add_one_x))>; - - // Inverse hyperbolic cosine on complex input: - // diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index 48b2b101ae8a1a..03e69998dfc9ee 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "531816f07e0db010a676c23fc66fe0a1a2e2d648" - STABLEHLO_SHA256 = "5a0b6a4dbe739793f1c4ea7d117aac81edaa18e2f2fe795fc3ffe6a2e9be2ac8" + STABLEHLO_COMMIT = "840c41ceb0d13800d286a9d76d8ad00d97838d9e" + STABLEHLO_SHA256 = "f2f92695ecdb2449a3d2316015a37301c1e4768315b9e753e18b4759eebb67e8" # LINT.ThenChange(Google-internal path) tf_http_archive( From e0952ac3a511f1df8e0da0c8b2ccd1a7865a3ccf Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 22 Jul 2024 15:50:44 -0700 Subject: [PATCH 073/376] Simplifies the handling of edge strategy variables. PiperOrigin-RevId: 654923835 --- .../auto_sharding/auto_sharding_solver.cc | 43 +++---------------- 1 file changed, 7 insertions(+), 36 deletions(-) diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index 96dc9c1bb78fd3..f204ff43496d61 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -368,11 +368,7 @@ void AddMemoryTerms( // d. For all (i, j) in E, e[i, j] in {0, 1} ^ dim(e[i, j]) // e. For all (i, j) in E, e[i, j]^T * 1 == 1 // Make sure s[i] and s[j] align with e[i, j]: -// f. For all (i, j) in E, 0 <= p < dim(s[i]), -// sum_{0 <= q < dim(s[j])} e[i, j](p * dim(s[j]) + q) <= s[i](p) -// g. For all (i, j) in E, 0 <= q < dim(s[j]), -// sum_{0 <= p < dim(s[i])} e[i, j](p * dim(s[j]) + q) <= s[j](q) -// h. For all (i, j) in A and all (p, q), +// f. For all (i, j) in A and all (p, q), // s[i][p] + s[j][q] <= 1 if v[p, q] == 1.0 // Serialize parameters of the ILP problem as numpy arrays and call the python // solver. @@ -637,47 +633,22 @@ AutoShardingSolverResult CallORToolsSolver( // d. specified via "BoolVarArray" // e. - for (EdgeIdx edge_idx = 0; edge_idx < num_edges; ++edge_idx) { - if (e_follow[edge_idx] >= 0) continue; - const auto& edge = request.edges(edge_idx); - MPConstraint* constraint = solver->MakeRowConstraint( - 1.0, 1.0, - absl::StrCat("sum(e[", edge.first(), "][", edge.second(), "][*]) = 1")); - for (EdgeStrategyIdx j = 0; j < e[edge_idx].size(); ++j) { - constraint->SetCoefficient(e[edge_idx][j], 1.0); - } - } - // f. for (EdgeIdx edge_idx = 0; edge_idx < num_edges; ++edge_idx) { if (e_follow[edge_idx] >= 0) continue; const auto& edge = request.edges(edge_idx); for (NodeStrategyIdx p = 0; p < s[edge.first()].size(); ++p) { - MPConstraint* constraint = solver->MakeRowConstraint( - -MPSolver::infinity(), 0, - absl::StrCat("f for i = ", edge_idx, ", p = ", p)); - constraint->SetCoefficient(s[edge.first()][p], -1.0); for (NodeStrategyIdx q = 0; q < s[edge.second()].size(); ++q) { const EdgeStrategyIdx j = p * s[edge.second()].size() + q; + MPConstraint* constraint = solver->MakeRowConstraint( + -1.0, MPSolver::infinity(), + absl::StrCat("edge[", edge_idx, "][", j, "]")); + constraint->SetCoefficient(s[edge.first()][p], -1.0); + constraint->SetCoefficient(s[edge.second()][q], -1.0); constraint->SetCoefficient(e[edge_idx][j], 1.0); } } } - // g. - for (EdgeIdx edge_idx = 0; edge_idx < num_edges; ++edge_idx) { - if (e_follow[edge_idx] >= 0) continue; - const auto& edge = request.edges(edge_idx); - for (NodeStrategyIdx q = 0; q < s[edge.second()].size(); ++q) { - MPConstraint* constraint = solver->MakeRowConstraint( - -MPSolver::infinity(), 0, - absl::StrCat("g for i = ", edge_idx, ", q = ", q)); - constraint->SetCoefficient(s[edge.second()][q], -1.0); - for (NodeStrategyIdx p = 0; p < s[edge.first()].size(); ++p) { - const EdgeStrategyIdx j = p * s[edge.second()].size() + q; - constraint->SetCoefficient(e[edge_idx][j], 1.0); - } - } - } - // h. + // f. absl::flat_hash_set> alias_set; for (auto alias_idx = 0; alias_idx < request.aliases_size(); ++alias_idx) { const auto& raw_alias = request.aliases(alias_idx); From eb9a1b53099d9caebacb82b3a4d1edae2525fab6 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Mon, 22 Jul 2024 16:00:26 -0700 Subject: [PATCH 074/376] Use GpuContext in GpuEvent and derived classes rather than GpuExecutor. The only thing GpuExecutor was used for was to get the GpuContext. PiperOrigin-RevId: 654926533 --- xla/stream_executor/cuda/BUILD | 1 - xla/stream_executor/cuda/cuda_event.cc | 5 +---- xla/stream_executor/cuda/cuda_event.h | 5 +++-- xla/stream_executor/cuda/cuda_executor.cc | 2 +- xla/stream_executor/gpu/gpu_event.cc | 17 +++++++---------- xla/stream_executor/gpu/gpu_event.h | 9 +++++---- xla/stream_executor/rocm/BUILD | 1 - xla/stream_executor/rocm/rocm_event.cc | 4 +--- xla/stream_executor/rocm/rocm_event.h | 5 +++-- xla/stream_executor/rocm/rocm_executor.cc | 2 +- 10 files changed, 22 insertions(+), 29 deletions(-) diff --git a/xla/stream_executor/cuda/BUILD b/xla/stream_executor/cuda/BUILD index 161c13f2641bb9..96b85287ef5f49 100644 --- a/xla/stream_executor/cuda/BUILD +++ b/xla/stream_executor/cuda/BUILD @@ -492,7 +492,6 @@ cuda_only_cc_library( ":cuda_driver", "//xla/stream_executor:event", "//xla/stream_executor/gpu:gpu_event", - "//xla/stream_executor/gpu:gpu_executor_header", "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", "@local_config_cuda//cuda:cuda_headers", diff --git a/xla/stream_executor/cuda/cuda_event.cc b/xla/stream_executor/cuda/cuda_event.cc index 2bc31791c4fc50..c9c5ee79c4fc40 100644 --- a/xla/stream_executor/cuda/cuda_event.cc +++ b/xla/stream_executor/cuda/cuda_event.cc @@ -20,15 +20,12 @@ limitations under the License. #include "third_party/gpus/cuda/include/cuda.h" #include "xla/stream_executor/cuda/cuda_driver.h" #include "xla/stream_executor/event.h" -#include "xla/stream_executor/gpu/gpu_event.h" -#include "xla/stream_executor/gpu/gpu_executor.h" namespace stream_executor { namespace gpu { Event::Status CudaEvent::PollForStatus() { - absl::StatusOr status = - QueryEvent(parent()->gpu_context(), gpu_event()); + absl::StatusOr status = QueryEvent(context(), gpu_event()); if (!status.ok()) { LOG(ERROR) << "Error polling for event status: " << status.status().message(); diff --git a/xla/stream_executor/cuda/cuda_event.h b/xla/stream_executor/cuda/cuda_event.h index 3d91064b817c47..3115ded266c784 100644 --- a/xla/stream_executor/cuda/cuda_event.h +++ b/xla/stream_executor/cuda/cuda_event.h @@ -18,14 +18,15 @@ limitations under the License. #include "xla/stream_executor/event.h" #include "xla/stream_executor/gpu/gpu_event.h" -#include "xla/stream_executor/gpu/gpu_executor.h" namespace stream_executor::gpu { +class GpuContext; + // This class implements Event::PollForStatus for CUDA devices. class CudaEvent : public GpuEvent { public: - explicit CudaEvent(GpuExecutor *executor) : GpuEvent(executor) {} + explicit CudaEvent(GpuContext *context) : GpuEvent(context) {} Event::Status PollForStatus() override; }; diff --git a/xla/stream_executor/cuda/cuda_executor.cc b/xla/stream_executor/cuda/cuda_executor.cc index c05a9189b21d60..d0fb598d9fe8ed 100644 --- a/xla/stream_executor/cuda/cuda_executor.cc +++ b/xla/stream_executor/cuda/cuda_executor.cc @@ -722,7 +722,7 @@ absl::Status FillBlockDimLimit(GpuDeviceHandle device, absl::StatusOr> GpuExecutor::CreateGpuEvent( bool allow_timing) { - auto gpu_event = std::make_unique(this); + auto gpu_event = std::make_unique(gpu_context()); TF_RETURN_IF_ERROR(gpu_event->Init(allow_timing)); return std::move(gpu_event); } diff --git a/xla/stream_executor/gpu/gpu_event.cc b/xla/stream_executor/gpu/gpu_event.cc index e2aec087ef49e0..036d38f4271e9a 100644 --- a/xla/stream_executor/gpu/gpu_event.cc +++ b/xla/stream_executor/gpu/gpu_event.cc @@ -20,39 +20,36 @@ limitations under the License. #include "absl/base/casts.h" #include "absl/status/status.h" #include "xla/stream_executor/gpu/gpu_driver.h" -#include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_types.h" namespace stream_executor { namespace gpu { -GpuEvent::GpuEvent(GpuExecutor* parent) - : parent_(parent), gpu_event_(nullptr) {} +GpuEvent::GpuEvent(GpuContext* context) + : context_(context), gpu_event_(nullptr) {} GpuEvent::~GpuEvent() { Destroy().IgnoreError(); } absl::Status GpuEvent::Init(bool allow_timing) { - return GpuDriver::InitEvent(parent_->gpu_context(), &gpu_event_, + return GpuDriver::InitEvent(context_, &gpu_event_, allow_timing ? GpuDriver::EventFlags::kDefault : GpuDriver::EventFlags::kDisableTiming); } absl::Status GpuEvent::Destroy() { - return GpuDriver::DestroyEvent(parent_->gpu_context(), &gpu_event_); + return GpuDriver::DestroyEvent(context_, &gpu_event_); } absl::Status GpuEvent::Record(GpuStreamHandle stream_handle) { - return GpuDriver::RecordEvent(parent_->gpu_context(), gpu_event_, - stream_handle); + return GpuDriver::RecordEvent(context_, gpu_event_, stream_handle); } GpuEventHandle GpuEvent::gpu_event() { return gpu_event_; } absl::Status GpuEvent::WaitForEventOnExternalStream(std::intptr_t stream) { - if (GpuDriver::WaitStreamOnEvent(parent_->gpu_context(), - absl::bit_cast(stream), - gpu_event_)) { + if (GpuDriver::WaitStreamOnEvent( + context_, absl::bit_cast(stream), gpu_event_)) { return absl::OkStatus(); } else { return absl::InternalError("Error waiting for event on external stream"); diff --git a/xla/stream_executor/gpu/gpu_event.h b/xla/stream_executor/gpu/gpu_event.h index 66b9bbafd7abd4..1620fceadfd0b1 100644 --- a/xla/stream_executor/gpu/gpu_event.h +++ b/xla/stream_executor/gpu/gpu_event.h @@ -25,11 +25,12 @@ limitations under the License. namespace stream_executor { namespace gpu { -class GpuExecutor; +class GpuContext; + // GpuEvent wraps a GpuEventHandle in the platform-independent Event interface. class GpuEvent : public Event { public: - explicit GpuEvent(GpuExecutor* parent); + explicit GpuEvent(GpuContext* context); ~GpuEvent() override; @@ -49,11 +50,11 @@ class GpuEvent : public Event { absl::Status WaitForEventOnExternalStream(std::intptr_t stream) override; protected: - GpuExecutor* parent() const { return parent_; } + GpuContext* context() const { return context_; } private: // The Executor used to which this object and GpuEventHandle are bound. - GpuExecutor* parent_; + GpuContext* context_; // The underlying CUDA event element. GpuEventHandle gpu_event_; diff --git a/xla/stream_executor/rocm/BUILD b/xla/stream_executor/rocm/BUILD index fc6af2700c1f7f..8a3550bad3a1f6 100644 --- a/xla/stream_executor/rocm/BUILD +++ b/xla/stream_executor/rocm/BUILD @@ -148,7 +148,6 @@ cc_library( ":rocm_driver", "//xla/stream_executor", "//xla/stream_executor/gpu:gpu_event_header", - "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_stream_header", ]), ) diff --git a/xla/stream_executor/rocm/rocm_event.cc b/xla/stream_executor/rocm/rocm_event.cc index 24ebb3470ea747..a1d719e825f699 100644 --- a/xla/stream_executor/rocm/rocm_event.cc +++ b/xla/stream_executor/rocm/rocm_event.cc @@ -16,7 +16,6 @@ limitations under the License. #include "xla/stream_executor/rocm/rocm_event.h" #include "xla/stream_executor/gpu/gpu_event.h" -#include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_stream.h" #include "xla/stream_executor/rocm/rocm_driver.h" @@ -24,8 +23,7 @@ namespace stream_executor { namespace gpu { Event::Status RocmEvent::PollForStatus() { - absl::StatusOr status = - QueryEvent(parent()->gpu_context(), gpu_event()); + absl::StatusOr status = QueryEvent(context(), gpu_event()); if (!status.ok()) { LOG(ERROR) << "Error polling for event status: " << status.status().message(); diff --git a/xla/stream_executor/rocm/rocm_event.h b/xla/stream_executor/rocm/rocm_event.h index 7652df7acc5c47..980f2804add01c 100644 --- a/xla/stream_executor/rocm/rocm_event.h +++ b/xla/stream_executor/rocm/rocm_event.h @@ -17,14 +17,15 @@ limitations under the License. #define XLA_STREAM_EXECUTOR_ROCM_ROCM_EVENT_H_ #include "xla/stream_executor/gpu/gpu_event.h" -#include "xla/stream_executor/gpu/gpu_executor.h" namespace stream_executor::gpu { +class GpuContest; + // This class implements Event::PollForStatus for ROCm devices. class RocmEvent : public GpuEvent { public: - explicit RocmEvent(GpuExecutor *executor) : GpuEvent(executor) {} + explicit RocmEvent(GpuContext *context) : GpuEvent(context) {} Event::Status PollForStatus() override; }; diff --git a/xla/stream_executor/rocm/rocm_executor.cc b/xla/stream_executor/rocm/rocm_executor.cc index 07dc7b69f8f3f7..6e46312828140c 100644 --- a/xla/stream_executor/rocm/rocm_executor.cc +++ b/xla/stream_executor/rocm/rocm_executor.cc @@ -635,7 +635,7 @@ absl::Status FillBlockDimLimit(GpuDeviceHandle device, absl::StatusOr> GpuExecutor::CreateGpuEvent( bool allow_timing) { - auto gpu_event = std::make_unique(this); + auto gpu_event = std::make_unique(gpu_context()); TF_RETURN_IF_ERROR(gpu_event->Init(allow_timing)); return std::move(gpu_event); } From ed5d24f39a11dd8afe4c737449849f40b8556de3 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 22 Jul 2024 16:12:37 -0700 Subject: [PATCH 075/376] [xla:cpu] Use InProcess collectives if run options do not provide an override Make thunk implementation consistent with current XLA PiperOrigin-RevId: 654930441 --- xla/service/cpu/runtime/BUILD | 3 +++ xla/service/cpu/runtime/thunk.cc | 16 +++++++++++----- xla/service/cpu/runtime/thunk_test.cc | 26 +++++++++++++++++++++++++- 3 files changed, 39 insertions(+), 6 deletions(-) diff --git a/xla/service/cpu/runtime/BUILD b/xla/service/cpu/runtime/BUILD index d48cc0271093a9..c1f1d317a9f68e 100644 --- a/xla/service/cpu/runtime/BUILD +++ b/xla/service/cpu/runtime/BUILD @@ -108,6 +108,9 @@ xla_cc_test( srcs = ["thunk_test.cc"], deps = [ ":thunk", + "//xla:executable_run_options", + "//xla/service/cpu:collectives_interface", + "//xla/service/cpu:cpu_executable_run_options", "@com_google_absl//absl/status", "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", diff --git a/xla/service/cpu/runtime/thunk.cc b/xla/service/cpu/runtime/thunk.cc index a9cb0ca993eff7..7810401df56900 100644 --- a/xla/service/cpu/runtime/thunk.cc +++ b/xla/service/cpu/runtime/thunk.cc @@ -95,13 +95,19 @@ Thunk::CollectiveExecuteParams::Create( ? run_options->device_ordinal() : run_options->stream()->parent()->device_ordinal(); + // Default implementation of a collectives interface that can execute + // collective operations within the same process. + static CollectivesInterface* in_process_collectives = + new runtime::InProcessCollectives(); + // If CPU executable run options are set, use the collectives interface - // provided by the executable run options. Otherwise, use the in-process - // collectives interface. - static auto* in_process_collectives = new runtime::InProcessCollectives(); + // provided by the executable run options if it is set. Otherwise, use the + // in-process collectives interface. + const CpuExecutableRunOptions* cpu_run_options = + run_options->cpu_executable_run_options(); CollectivesInterface* collectives = - run_options->cpu_executable_run_options() - ? run_options->cpu_executable_run_options()->collectives() + cpu_run_options && cpu_run_options->collectives() + ? cpu_run_options->collectives() : in_process_collectives; return CollectiveExecuteParams{run_options->run_id(), device_ordinal, diff --git a/xla/service/cpu/runtime/thunk_test.cc b/xla/service/cpu/runtime/thunk_test.cc index ca1ee928ea5a50..b761c509a31373 100644 --- a/xla/service/cpu/runtime/thunk_test.cc +++ b/xla/service/cpu/runtime/thunk_test.cc @@ -16,8 +16,11 @@ limitations under the License. #include "xla/service/cpu/runtime/thunk.h" #include -#include +#include "xla/executable_run_options.h" +#include "xla/service/cpu/collectives_interface.h" +#include "xla/service/cpu/cpu_executable_run_options.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla::cpu { @@ -112,5 +115,26 @@ TEST(ThunkTest, ExecuteSession) { EXPECT_EQ(session.num_workers(), 2); } +TEST(ThunkTest, CollectiveExecuteParams) { + ExecutableRunOptions run_options; + run_options.set_device_ordinal(0); + + // Collectives interface initialized with a default implementation. + TF_ASSERT_OK_AND_ASSIGN(auto params, + Thunk::CollectiveExecuteParams::Create(&run_options)); + EXPECT_NE(params.collectives, nullptr); + + // Test forwarding collectives interface from CpuExecutableRunOptions. + CpuExecutableRunOptions cpu_run_options; + cpu_run_options.set_collectives( + reinterpret_cast(0x12345678)); + run_options.set_cpu_executable_run_options(&cpu_run_options); + + TF_ASSERT_OK_AND_ASSIGN(params, + Thunk::CollectiveExecuteParams::Create(&run_options)); + EXPECT_EQ(params.collectives, + reinterpret_cast(0x12345678)); +} + } // namespace } // namespace xla::cpu From cf18419de3566dbaa078988aef8001a64777d8e2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 22 Jul 2024 16:40:47 -0700 Subject: [PATCH 076/376] [XLA] Reintroduce convert(constant) to constant rewrite with stronger conditions: do the rewrite only when the user count of the constant is 1 and use_convert_constant_folding is enabled PiperOrigin-RevId: 654938716 --- xla/service/algebraic_simplifier.cc | 16 +++++++++ xla/service/algebraic_simplifier.h | 9 +++++ xla/service/algebraic_simplifier_test.cc | 44 ++++++++++++++++++++++-- 3 files changed, 67 insertions(+), 2 deletions(-) diff --git a/xla/service/algebraic_simplifier.cc b/xla/service/algebraic_simplifier.cc index c0c4fe8ba5181b..2ef7b352a6afe7 100644 --- a/xla/service/algebraic_simplifier.cc +++ b/xla/service/algebraic_simplifier.cc @@ -5199,6 +5199,22 @@ absl::Status AlgebraicSimplifierVisitor::HandleConvert( convert->mutable_operand(0)->mutable_operand(0)); } + // Try to replace convert(constant) with a constant of the right type to begin + // with. Disallow moving sub-byte types since they may not be supported for + // some ops. + HloInstruction* constant; + if (options_.use_convert_constant_folding() && + Match(convert, m::Convert(m::Constant(&constant))) && + primitive_util::BitWidth(dest_type) <= + primitive_util::BitWidth(src_type) && + constant->user_count() == 1 && primitive_util::BitWidth(dest_type) >= 8) { + TF_ASSIGN_OR_RETURN(Literal dest_literal, + constant->literal().Convert(dest_type)); + VLOG(10) << "Replacing convert(constant) with constant"; + return ReplaceWithNewInstruction( + convert, HloInstruction::CreateConstant(std::move(dest_literal))); + } + return TryRemoveUpcastAndDowncastSurroundingBinaryOp(convert); } diff --git a/xla/service/algebraic_simplifier.h b/xla/service/algebraic_simplifier.h index 1d1134998c709a..3340d248ed43cf 100644 --- a/xla/service/algebraic_simplifier.h +++ b/xla/service/algebraic_simplifier.h @@ -109,6 +109,14 @@ class AlgebraicSimplifierOptions { return associative_reordering_threshold_; } + void set_use_convert_constant_folding(bool use_convert_constant_folding) { + use_convert_constant_folding_ = use_convert_constant_folding; + } + + bool use_convert_constant_folding() const { + return use_convert_constant_folding_; + } + // Enable dot simplification on platforms where it is profitable. void set_enable_dot_strength_reduction(bool enable_dot_strength_reduction) { enable_dot_strength_reduction_ = enable_dot_strength_reduction; @@ -293,6 +301,7 @@ class AlgebraicSimplifierOptions { bool minmax_propagate_nan_{true}; bool enable_unconditional_reduce_of_concat_replacement_{true}; bool use_associative_reordering_{false}; + bool use_convert_constant_folding_{false}; bool executing_on_cpu_{false}; double associative_reordering_threshold_{2.0}; Metadata metadata_; diff --git a/xla/service/algebraic_simplifier_test.cc b/xla/service/algebraic_simplifier_test.cc index 1ba5bde63420fc..c3dadde7fe93a6 100644 --- a/xla/service/algebraic_simplifier_test.cc +++ b/xla/service/algebraic_simplifier_test.cc @@ -115,7 +115,7 @@ const char* arb_sign_ops[] = {"constant(-0.0)", "select(pred0, p0, a1)"}; // clang-format on -// Test that the result of particular oprations is always non-negative +// Test that the result of particular operations is always non-negative TEST_F(AlgebraicSimplifierTest, IsNonNegative_Op) { for (const auto* op : non_neg_ops) { const auto kModuleStr = absl::StrFormat(R"( @@ -136,7 +136,7 @@ TEST_F(AlgebraicSimplifierTest, IsNonNegative_Op) { } } -// Test that the result of particular oprations might be negative +// Test that the result of particular operations might be negative TEST_F(AlgebraicSimplifierTest, IsNonNegative_Op_NegativeTestCase) { for (const auto op : arb_sign_ops) { const auto kModuleStr = absl::StrFormat(R"( @@ -11687,5 +11687,45 @@ ENTRY main.1 { HloOpcode::kParameter); } +TEST_F(AlgebraicSimplifierTest, RemoveConvertConstant) { + const std::string hlo_string = R"( + HloModule module + + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT r = f32[] add(p0, p1) + } + + ENTRY test { + a = f32[32,64] parameter(0) + b = s32[] constant(0) + c = f32[] convert(b) + ROOT reduce = f32[32] reduce(a, c), + dimensions={1}, to_apply=add + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + default_options_.set_use_convert_constant_folding(true); + EXPECT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + HloInstruction* root = m->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Reduce(m::Parameter(0), + m::Constant().WithShape(F32, {})))); +} + +TEST_F(AlgebraicSimplifierTest, KeepInt4ConvertConstant) { + const std::string hlo_string = R"( + HloModule module + + ENTRY test { + a = s8[] constant(0) + ROOT b = s4[] convert(a) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + default_options_.set_use_convert_constant_folding(true); + ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); +} + } // namespace } // namespace xla From a8c50f73dcb2bd16c7edd711696454f3b85fb5e3 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Mon, 22 Jul 2024 17:38:03 -0700 Subject: [PATCH 077/376] Remove unused cuda_stream.h and associated unused functions. PiperOrigin-RevId: 654954786 --- xla/service/gpu/BUILD | 1 - xla/stream_executor/cuda/BUILD | 13 --------- xla/stream_executor/cuda/cuda_stream.h | 37 -------------------------- xla/stream_executor/gpu/gpu_stream.h | 5 +--- xla/xla.bzl | 1 - 5 files changed, 1 insertion(+), 56 deletions(-) delete mode 100644 xla/stream_executor/cuda/cuda_stream.h diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 2d6fb98e359e97..18eed308d08ef0 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -757,7 +757,6 @@ cc_library( ":make_batch_pointers", ]) + if_cuda_is_configured([ "//xla/stream_executor/cuda:cublas_plugin", - "//xla/stream_executor/cuda:cuda_stream", "//xla/stream_executor/cuda:cudnn_plugin", "//xla/stream_executor/cuda:cufft_plugin", "//xla/stream_executor/cuda:stream_executor_cuda", diff --git a/xla/stream_executor/cuda/BUILD b/xla/stream_executor/cuda/BUILD index 96b85287ef5f49..38475a360a04bc 100644 --- a/xla/stream_executor/cuda/BUILD +++ b/xla/stream_executor/cuda/BUILD @@ -283,7 +283,6 @@ cuda_only_cc_library( ":cuda_executor", ":cuda_helpers", ":cuda_platform_id", - ":cuda_stream", "//xla:shape_util", "//xla:status_macros", "//xla:types", @@ -404,7 +403,6 @@ cuda_only_cc_library( ":cuda_driver", ":cuda_executor", ":cuda_platform_id", - ":cuda_stream", ":cudnn_frontend_helpers", "//xla/stream_executor", "//xla/stream_executor:data_type", @@ -498,17 +496,6 @@ cuda_only_cc_library( ], ) -cuda_only_cc_library( - name = "cuda_stream", - srcs = [], - hdrs = ["cuda_stream.h"], - deps = [ - "//xla/stream_executor", - "//xla/stream_executor:blas", - "//xla/stream_executor/gpu:gpu_stream", - ], -) - cc_library( name = "ptx_compiler_support", srcs = ["ptx_compiler_support.cc"], diff --git a/xla/stream_executor/cuda/cuda_stream.h b/xla/stream_executor/cuda/cuda_stream.h deleted file mode 100644 index 7e651b45d0e6fa..00000000000000 --- a/xla/stream_executor/cuda/cuda_stream.h +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright 2015 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Defines the GpuStream type - the CUDA-specific implementation of the generic -// StreamExecutor Stream interface. - -#ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_STREAM_H_ -#define XLA_STREAM_EXECUTOR_CUDA_CUDA_STREAM_H_ - -#include "xla/stream_executor/blas.h" -#include "xla/stream_executor/gpu/gpu_stream.h" - -namespace stream_executor { -namespace cuda { - -using CUDAStream = gpu::GpuStream; - -inline CUDAStream* AsCUDAStream(Stream* stream) { - return gpu::AsGpuStream(stream); -} - -} // namespace cuda -} // namespace stream_executor - -#endif // XLA_STREAM_EXECUTOR_CUDA_CUDA_STREAM_H_ diff --git a/xla/stream_executor/gpu/gpu_stream.h b/xla/stream_executor/gpu/gpu_stream.h index 984084e062b636..0f18c4e10b98cd 100644 --- a/xla/stream_executor/gpu/gpu_stream.h +++ b/xla/stream_executor/gpu/gpu_stream.h @@ -94,12 +94,9 @@ class GpuStream : public StreamCommon { // into the NVIDIA library causes difficult-to-understand faults). GpuStreamHandle gpu_stream() const { DCHECK(gpu_stream_ != nullptr); - return const_cast(gpu_stream_); + return gpu_stream_; } - // TODO(timshen): Migrate away and remove this function. - GpuStreamHandle cuda_stream() const { return gpu_stream(); } - GpuExecutor* parent() const { return parent_; } absl::Status WaitFor(Stream* other) override; absl::Status WaitFor(Event* event) override; diff --git a/xla/xla.bzl b/xla/xla.bzl index 8197dd1bc459e4..912f17dffd9c41 100644 --- a/xla/xla.bzl +++ b/xla/xla.bzl @@ -57,7 +57,6 @@ _XLA_SHARED_OBJECT_SENSITIVE_DEPS = if_static(extra_deps = [], otherwise = [ "@tsl//tsl/protobuf:protos_all_cc_impl", ]) + if_cuda_is_configured([ Label("//xla/stream_executor/cuda:all_runtime"), - Label("//xla/stream_executor/cuda:cuda_stream"), Label("//xla/stream_executor/cuda:stream_executor_cuda"), ]) + if_rocm_is_configured([ Label("//xla/stream_executor/gpu:gpu_stream"), From 28aae0e0b286336081fe86fcf65e080acf5a1f7f Mon Sep 17 00:00:00 2001 From: Seher Ellis Date: Mon, 22 Jul 2024 19:03:25 -0700 Subject: [PATCH 078/376] [XLA] Support forward-sink-pipelining collectives when they are used by multiple DUS ops. PiperOrigin-RevId: 654975412 --- xla/service/collective_pipeliner.cc | 714 ++++++++++++------ xla/service/collective_pipeliner_test.cc | 640 +++++++++++++++- .../collective_pipeliner_execution_test.cc | 95 +++ 3 files changed, 1182 insertions(+), 267 deletions(-) diff --git a/xla/service/collective_pipeliner.cc b/xla/service/collective_pipeliner.cc index 8dbadb7d99fbbe..ccfacab79229ce 100644 --- a/xla/service/collective_pipeliner.cc +++ b/xla/service/collective_pipeliner.cc @@ -290,16 +290,24 @@ bool CollectSimpleDependencies(HloInstruction* i, // If this level 0 we require the unique dynamic update slice to feed directly // into the root instruction. If this is level > 1 then we require that the // unique dynamic_update slice is inserted using the index created in the -// previous level. -std::pair> +// previous level. In the kForwardSink mode, if the value to be pushed has +// multiple dynamic update slices in its user subtree, we will return all of +// those dynamic update slices along with all of the formatting ops between the +// value and the dynamic update slices. +std::pair, + std::vector> CheckStoreIntoSliceIsCompatible(HloInstruction* instr, const HloComputation* while_body, int64_t level_to_operate_on, bool multi_uses_pipelining, - HloPredicate acceptable_formatting) { + HloPredicate acceptable_formatting, + bool multi_dyn_updates = false) { + std::pair, + std::vector> + empty_pair{{}, {}}; if ((!multi_uses_pipelining && instr->user_count() != 1) || instr->operand_count() != 1 || instr->HasControlDependencies()) { - return std::make_pair(nullptr, std::vector{}); + return empty_pair; } // Set to collect instructions that have been already added. absl::flat_hash_set added_instructions; @@ -353,58 +361,54 @@ CheckStoreIntoSliceIsCompatible(HloInstruction* instr, } return false; }; - HloDynamicUpdateSliceInstruction* final_slice_insertion = nullptr; + absl::flat_hash_set final_slice_set; + std::vector final_slice_insertions; std::vector> stack; - absl::flat_hash_map formatting_map; stack.push_back(std::make_pair(folded_instr, 0)); - // Post order traversal to discover formatting instructions. + // Post order traversal to discover the dynamic update slices. while (!stack.empty()) { auto& data = stack.back(); - HloInstruction* instr = data.first; - if (data.second == 0 && instr != folded_instr) { - formatting_map[instr] = 0; - } - if (data.second == instr->user_count()) { + HloInstruction* inst = data.first; + if (data.second == inst->user_count()) { stack.pop_back(); continue; } - HloInstruction* next_user = instr->users()[data.second++]; + HloInstruction* next_user = inst->users()[data.second++]; if (is_final_slice_insertion(next_user)) { - if ((final_slice_insertion != nullptr && - final_slice_insertion != next_user) || - next_user->user_count() != 1 || next_user->operand(1) != instr) { - return std::make_pair(nullptr, std::vector{}); + if (next_user->user_count() != 1 || next_user->operand(1) != inst) { + return empty_pair; + } + if (final_slice_set.contains(next_user)) { + continue; + } + if (!multi_dyn_updates && !final_slice_insertions.empty()) { + return empty_pair; } - final_slice_insertion = Cast(next_user); + final_slice_insertions.push_back( + Cast(next_user)); + final_slice_set.insert(next_user); continue; } if (!is_acceptable_user(next_user)) { - return std::make_pair(nullptr, std::vector{}); + return empty_pair; } if (added_instructions.insert(next_user).second) { stack.push_back(std::make_pair(next_user, 0)); } } - if (final_slice_insertion == nullptr) { - return std::make_pair(nullptr, std::vector{}); - } - for (auto& op : formatting_map) { - for (const HloInstruction* operand : final_slice_insertion->operands()) { - if (formatting_map.count(operand)) { - ++op.second; - } - } + if (final_slice_insertions.empty()) { + return empty_pair; } stack.push_back(std::make_pair(folded_instr, 0)); added_instructions.clear(); - // Post order traversal to determine the insert instruction order. + // Post order traversal to discover the formatting ops. while (!stack.empty()) { auto& data = stack.back(); HloInstruction* instr = data.first; if (data.second == 0 && instr != folded_instr) { if (!CollectSimpleDependencies(instr, formatting_ops, added_instructions)) { - return std::make_pair(nullptr, std::vector{}); + return empty_pair; } formatting_ops.push_back(instr); } @@ -414,22 +418,13 @@ CheckStoreIntoSliceIsCompatible(HloInstruction* instr, } HloInstruction* next_user = instr->users()[data.second++]; if (is_final_slice_insertion(next_user)) { - if ((final_slice_insertion != nullptr && - final_slice_insertion != next_user) || - next_user->user_count() != 1 || next_user->operand(1) != instr) { - return std::make_pair(nullptr, std::vector{}); - } - final_slice_insertion = Cast(next_user); - continue; - } - if (--formatting_map[next_user] > 0) { continue; } if (added_instructions.insert(next_user).second) { stack.push_back(std::make_pair(next_user, 0)); } } - return std::make_pair(final_slice_insertion, formatting_ops); + return std::make_pair(final_slice_insertions, formatting_ops); } bool IsLoopIterator(const HloInstruction* instr, @@ -444,10 +439,11 @@ bool IsLoopIterator(const HloInstruction* instr, // Scavenge operands that are dependencies not included in the ops set and that // aren't the source_op passed as input parameter and return them in a vector. std::vector CollectDependenciesToPipeline( - const HloInstruction* source_op, absl::Span ops) { + absl::Span source_ops, + absl::Span ops) { absl::flat_hash_set formatting_set(ops.begin(), ops.end()); - formatting_set.insert(source_op); + formatting_set.insert(source_ops.begin(), source_ops.end()); std::vector to_return; absl::flat_hash_set already_inserted; for (const HloInstruction* op : ops) { @@ -605,15 +601,25 @@ std::vector MapNewOperands( return new_operands; } -// Collect information regarding movement of data either backward or forward -// through loop iterations. Except collective_to_move every other information -// here can be empty/null/-1 to indicate absence. +// Information regarding the movement of data for the pipelining directions: +// (i) kBackward: pushed to the previous iteration, +// (ii) kForward: pushed to the next iteration, and +// (iii) kForwardSink: completely pushed outside of the loop. +// collectives_to_move has only a single collective for (i) and (ii), but can +// have multiple collectives for (iii). Similarly, dynamic_update_slices can +// have multiple instructions for only (iii). In that case, sliced_idx for each +// dynamic-update-slice is the same, so we store one value to represent all of +// them. Output_indices[i] represents where dynamic_update_slices[i] is +// in the original while tuple. formatting_ops are the instructions between the +// collective(s) to be pushed and the respective dynamic-update-slice(s). Empty +// or -1 indicates the absence of the respective information. The only mandatory +// field is collectives_to_move. struct WhileMoveInfo { - HloInstruction* collective_to_move; - HloDynamicUpdateSliceInstruction* dynamic_update_slice; + std::vector collectives_to_move; + std::vector dynamic_update_slices; std::vector formatting_ops; int64_t sliced_idx; - int64_t output_idx; + std::vector output_indices; }; // Set channel_id of instruction to next available to avoid collisions. @@ -656,7 +662,7 @@ absl::StatusOr CloneBackwardChain( LoopVariantParameterInfo* loop_variant_parameter_info = nullptr) { std::vector to_clone(move_info.formatting_ops.begin(), move_info.formatting_ops.end()); - to_clone.push_back(move_info.collective_to_move); + to_clone.push_back(move_info.collectives_to_move[0]); HloInstruction* last_cloned = nullptr; for (auto* chain_op : to_clone) { // Do not clone a loop iterator or an op that is already cloned. @@ -733,6 +739,40 @@ class WhileLoopAnalysis { const absl::flat_hash_map& parameter_gtes_count, const absl::flat_hash_map& index_ranges) const; + + // Merges the new collective (instr) with the existing one stored in + // move_infos_[indices_to_merge[0]]. indices_to_merge.size() should be 1. + // This is done by adding the formating ops of the new collective and the new + // collective itself to the formatting ops of the existing collective. + + void MergeIntoExistingCollectivesForward( + HloInstruction* instr, std::vector formatting_ops, + std::vector dyn_updates, + std::vector indices_to_merge, + absl::flat_hash_map instruction_order); + // Merges the new collective (inst) and the existing collectives in + // indices_to_merge into a single entry in move_infos_. This is done because + // they mutually share at least one dynamic-update-slice so their dynamic + // update slices and formatting ops become inseparable. The smallest index in + // indices_to_merge is picked to hold the merged entry at the end and other + // entries in indices_to_merge are removed from move_infos_. + void MergeIntoExistingCollectivesForwardSink( + HloInstruction* instr, std::vector formatting_ops, + std::vector dyn_updates, + int64_t sliced_idx, std::vector output_indices, + std::vector indices_to_merge, + absl::flat_hash_map + index_per_dyn_update_slice, + absl::flat_hash_map instruction_order); + void MergeIntoExistingCollectives( + HloInstruction* instr, std::vector formatting_ops, + std::vector dyn_updates, + int64_t sliced_idx, std::vector output_indices, + std::vector indices_to_merge, + absl::flat_hash_map + index_per_dyn_update_slice, + absl::flat_hash_map instruction_order, + CollectivePipeliner::PipeliningDirection direction); void CollectCollectivesToMove( int64_t level_to_operate_on, CollectivePipeliner::PipeliningDirection direction, @@ -900,6 +940,7 @@ WhileLoopAnalysis::IsSupportedDynamicUpdateSlice( "slices being inserted or slice dim is not 0. slice_dim = " << *sliced_dim << " loop count = " << loop_iteration_count_->GetUnsignedValue(); + return std::nullopt; } if (!process_different_sized_options_) { if (!formatting_ops.empty()) { @@ -910,7 +951,7 @@ WhileLoopAnalysis::IsSupportedDynamicUpdateSlice( return std::nullopt; } auto dependencies_to_pipeline = CollectDependenciesToPipeline( - instr, absl::MakeConstSpan(formatting_ops)); + absl::MakeConstSpan({instr}), absl::MakeConstSpan(formatting_ops)); bool skip_because_not_same_size = false; // If any instruction in the dependency chain is not of the same size // then we abort for this instruction. @@ -986,6 +1027,137 @@ WhileLoopAnalysis::IsSupportedDynamicUpdateSlice( return std::make_pair(*sliced_dim, *output_idx); } +void WhileLoopAnalysis::MergeIntoExistingCollectivesForward( + HloInstruction* instr, std::vector formatting_ops, + std::vector dyn_updates, + std::vector indices_to_merge, + absl::flat_hash_map instruction_order) { + CHECK_EQ(indices_to_merge.size(), 1); + CHECK_EQ(dyn_updates.size(), 1); + int64_t target_idx = indices_to_merge[0]; + CHECK_EQ(move_infos_[target_idx].dynamic_update_slices.size(), 1); + CHECK_EQ(move_infos_[target_idx].collectives_to_move.size(), 1); + HloDynamicUpdateSliceInstruction* dyn_update = dyn_updates[0]; + CHECK_EQ(move_infos_[target_idx].dynamic_update_slices[0], dyn_update) + << "Not the same dynamic-update-slice for converging entry"; + absl::flat_hash_set existing_entry_instrs( + move_infos_[target_idx].formatting_ops.begin(), + move_infos_[target_idx].formatting_ops.end()); + existing_entry_instrs.insert(move_infos_[target_idx].collectives_to_move[0]); + // If instr is already in the set then this instruction is already + // in formatting-ops of the other one, so its already pipelined. + if (existing_entry_instrs.count(instr)) { + return; + } + move_infos_[target_idx].formatting_ops.push_back(instr); + for (auto* op : formatting_ops) { + if (!existing_entry_instrs.count(op)) { + move_infos_[target_idx].formatting_ops.push_back(op); + } + } + absl::c_sort(move_infos_[target_idx].formatting_ops, + [&](const HloInstruction* a, const HloInstruction* b) { + return instruction_order[a] < instruction_order[b]; + }); +} + +void WhileLoopAnalysis::MergeIntoExistingCollectivesForwardSink( + HloInstruction* instr, std::vector formatting_ops, + std::vector dyn_updates, + int64_t sliced_idx, std::vector output_indices, + std::vector indices_to_merge, + absl::flat_hash_map + index_per_dyn_update_slice, + absl::flat_hash_map instruction_order) { + CHECK(!indices_to_merge.empty()); + // Always pick the smallest group index to absorb the others. + const int64_t target_idx = *absl::c_min_element(indices_to_merge); + absl::flat_hash_set existing_formatting_ops( + move_infos_[target_idx].formatting_ops.begin(), + move_infos_[target_idx].formatting_ops.end()); + absl::flat_hash_set existing_collectives_to_move( + move_infos_[target_idx].collectives_to_move.begin(), + move_infos_[target_idx].collectives_to_move.end()); + absl::flat_hash_set existing_dyn_updates( + move_infos_[target_idx].dynamic_update_slices.begin(), + move_infos_[target_idx].dynamic_update_slices.end()); + + auto merge_entry_to_target = + [&](std::vector collectives_to_merge, + std::vector& formatting_ops_to_merge, + std::vector& dyn_updates_to_merge, + int64_t sliced_idx_to_merge, + std::vector& output_indices_to_merge) { + for (HloInstruction* op : collectives_to_merge) { + if (!existing_collectives_to_move.count(op)) { + move_infos_[target_idx].collectives_to_move.push_back(op); + } + } + for (HloInstruction* op : formatting_ops_to_merge) { + if (!existing_formatting_ops.count(op)) { + move_infos_[target_idx].formatting_ops.push_back(op); + } + } + for (int64_t i = 0; i < dyn_updates_to_merge.size(); ++i) { + HloDynamicUpdateSliceInstruction* dyn_update = + dyn_updates_to_merge[i]; + index_per_dyn_update_slice[dyn_update] = target_idx; + if (!existing_dyn_updates.count(dyn_update)) { + move_infos_[target_idx].dynamic_update_slices.push_back(dyn_update); + CHECK_EQ(sliced_idx_to_merge, move_infos_[target_idx].sliced_idx); + move_infos_[target_idx].output_indices.push_back( + output_indices_to_merge[i]); + } + } + }; + + // First merge the existing entries among themselves. + for (int64_t idx : indices_to_merge) { + if (idx == target_idx) { + continue; + } + // Merge idx to target_idx and delete idx. + merge_entry_to_target( + move_infos_[idx].collectives_to_move, move_infos_[idx].formatting_ops, + move_infos_[idx].dynamic_update_slices, move_infos_[idx].sliced_idx, + move_infos_[idx].output_indices); + move_infos_.erase(move_infos_.begin() + idx); + } + + // Now merge the current entry into the existing target entry. + merge_entry_to_target({instr}, formatting_ops, dyn_updates, sliced_idx, + output_indices); + + absl::c_sort(move_infos_[target_idx].formatting_ops, + [&](const HloInstruction* a, const HloInstruction* b) { + return instruction_order[a] < instruction_order[b]; + }); +} + +void WhileLoopAnalysis::MergeIntoExistingCollectives( + HloInstruction* instr, std::vector formatting_ops, + std::vector dyn_updates, + int64_t sliced_idx, std::vector output_indices, + std::vector indices_to_merge, + absl::flat_hash_map + index_per_dyn_update_slice, + absl::flat_hash_map instruction_order, + CollectivePipeliner::PipeliningDirection direction) { + if (direction == CollectivePipeliner::PipeliningDirection::kForwardSink) { + MergeIntoExistingCollectivesForwardSink( + instr, formatting_ops, dyn_updates, sliced_idx, output_indices, + indices_to_merge, index_per_dyn_update_slice, instruction_order); + return; + } + if (direction == CollectivePipeliner::PipeliningDirection::kForward) { + MergeIntoExistingCollectivesForward(instr, formatting_ops, dyn_updates, + indices_to_merge, instruction_order); + return; + } + CHECK(false) << "Backward pipelining is not supported in " + "MergeIntoExistingCollectives "; +} + void WhileLoopAnalysis::CollectCollectivesToMove( int64_t level_to_operate_on, CollectivePipeliner::PipeliningDirection direction, @@ -1054,66 +1226,86 @@ void WhileLoopAnalysis::CollectCollectivesToMove( } if (direction == CollectivePipeliner::PipeliningDirection::kForward || direction == CollectivePipeliner::PipeliningDirection::kForwardSink) { - auto [dyn_update, formatting_ops] = CheckStoreIntoSliceIsCompatible( + auto [dyn_updates, formatting_ops] = CheckStoreIntoSliceIsCompatible( instr, while_body, level_to_operate_on, pipeline_use_tree_, - acceptable_formatting); - if (dyn_update == nullptr) { + acceptable_formatting, + /*multi_dyn_updates=*/direction == + CollectivePipeliner::PipeliningDirection::kForwardSink); + if (dyn_updates.empty()) { VLOG(5) - << "Skipping " << instr->ToString() - << " because update users > 1 or single user is not the root of " - "computation"; + << "Skipping " << instr->name() + << " because storing into slice is not compatible with pipelining"; continue; } - std::optional> maybe_dus_info = - IsSupportedDynamicUpdateSlice(dyn_update, instr, formatting_ops, - direction, level_to_operate_on, - parameter_gtes_count, index_ranges); - if (!maybe_dus_info.has_value()) { - continue; + CHECK(direction != CollectivePipeliner::PipeliningDirection::kForward || + dyn_updates.size() == 1); + + // Collect the information for each dynamic-update-slice. Skip the + // collectives that have at least one unsupported dynamic-update-slice. + int64_t sliced_idx = -1; + std::vector output_indices; + bool skip_instr = false; + bool not_first_dyn_update = false; + for (HloDynamicUpdateSliceInstruction* dyn_update : dyn_updates) { + std::optional> maybe_dus_info = + IsSupportedDynamicUpdateSlice(dyn_update, instr, formatting_ops, + direction, level_to_operate_on, + parameter_gtes_count, index_ranges); + if (!maybe_dus_info.has_value()) { + VLOG(5) << "Skipping " << instr->name() << " because " + << dyn_update->name() << " is not supported"; + skip_instr = true; + break; + } + output_indices.push_back(maybe_dus_info->second); + if (not_first_dyn_update) { + // Dyn updates should not be writing into the same buffer. + CHECK_NE(dyn_update->operand(0), dyn_updates[0]->operand(0)); + // Dyn updates should have the same slice index. + CHECK_EQ(sliced_idx, maybe_dus_info->first); + } else { + sliced_idx = maybe_dus_info->first; + } + not_first_dyn_update = true; } - int64_t sliced_dim = maybe_dus_info->first; - int64_t output_idx = maybe_dus_info->second; - auto merge_as_formatting = - [this, &instruction_order]( - absl::flat_hash_map::iterator it, - HloInstruction* instr, HloInstruction* dyn_upd, - absl::Span formatting_ops) { - CHECK_EQ(move_infos_[it->second].dynamic_update_slice, dyn_upd) - << "Not the same dynamic-update-slice for converging entry"; - absl::flat_hash_set existing_entry_instrs( - move_infos_[it->second].formatting_ops.begin(), - move_infos_[it->second].formatting_ops.end()); - existing_entry_instrs.insert( - move_infos_[it->second].collective_to_move); - // If instr is already in the set then this instruction is already - // in formatting-ops of the other one, so its already pipelined. - if (existing_entry_instrs.count(instr)) { - return; - } - move_infos_[it->second].formatting_ops.push_back(instr); - for (auto* op : formatting_ops) { - if (!existing_entry_instrs.count(op)) { - move_infos_[it->second].formatting_ops.push_back(op); - } - } - absl::c_sort(move_infos_[it->second].formatting_ops, - [&](const HloInstruction* a, const HloInstruction* b) { - return instruction_order[a] < instruction_order[b]; - }); - }; - auto it = index_per_dyn_update_slice.find(dyn_update); - if (it != index_per_dyn_update_slice.end()) { - // Merge stuff with existing entry. - merge_as_formatting(it, instr, dyn_update, formatting_ops); + if (skip_instr) { continue; } - index_per_dyn_update_slice[dyn_update] = move_infos_.size(); - absl::c_sort(formatting_ops, - [&](const HloInstruction* a, const HloInstruction* b) { - return instruction_order[a] < instruction_order[b]; - }); - move_infos_.push_back({instr, dyn_update, std::move(formatting_ops), - sliced_dim, output_idx}); + CHECK_NE(sliced_idx, -1); + // First find the other collective groups that share at least one + // dynamic-update-slice with the current collective. + std::vector indices_to_merge; + for (HloDynamicUpdateSliceInstruction* dyn_update : dyn_updates) { + if (index_per_dyn_update_slice.find(dyn_update) != + index_per_dyn_update_slice.end()) { + int64_t index = index_per_dyn_update_slice[dyn_update]; + if (!absl::c_linear_search(indices_to_merge, index)) { + indices_to_merge.push_back(index); + } + } + } + // Merge with the existing group(s) if common instructions are found. + if (!indices_to_merge.empty()) { + MergeIntoExistingCollectives( + instr, formatting_ops, dyn_updates, sliced_idx, output_indices, + indices_to_merge, index_per_dyn_update_slice, instruction_order, + direction); + } else { + // This group is isolated from existing groups, so it should be inserted + // as a new entry to move_infos_. + absl::c_sort(formatting_ops, + [&](const HloInstruction* a, const HloInstruction* b) { + return instruction_order[a] < instruction_order[b]; + }); + for (HloDynamicUpdateSliceInstruction* dyn_update : dyn_updates) { + index_per_dyn_update_slice[dyn_update] = move_infos_.size(); + } + move_infos_.push_back({{instr}, + dyn_updates, + std::move(formatting_ops), + sliced_idx, + std::move(output_indices)}); + } } else { CHECK_EQ(direction, CollectivePipeliner::PipeliningDirection::kBackward); auto chain_collected = CollectChainsToPushBackwards( @@ -1128,7 +1320,7 @@ void WhileLoopAnalysis::CollectCollectivesToMove( continue; } move_infos_.push_back( - WhileMoveInfo{instr, nullptr, std::move(*chain_collected), -1, -1}); + WhileMoveInfo{{instr}, {}, std::move(*chain_collected), {}, {}}); } if (move_infos_.size() >= max_pipelining_per_loop_) { break; @@ -1139,9 +1331,11 @@ void WhileLoopAnalysis::CollectCollectivesToMove( } dus_index_map_.clear(); for (auto& to_move : move_infos_) { - HloInstruction* dus_index = to_move.dynamic_update_slice->mutable_operand( - to_move.dynamic_update_slice->first_index_operand_number() + - to_move.sliced_idx); + CHECK_EQ(to_move.dynamic_update_slices.size(), 1); + HloInstruction* dus_index = + to_move.dynamic_update_slices[0]->mutable_operand( + to_move.dynamic_update_slices[0]->first_index_operand_number() + + to_move.sliced_idx); auto it = dus_index_map_.find(dus_index); int64_t dus_index_tuple_position = dus_index_map_.size(); if (it != dus_index_map_.end()) { @@ -1232,11 +1426,9 @@ bool IsLoopInvariant( // Compute a shape that can hold a concatenation of tensors of shape base_shape. Shape ComputeFullOutputShape(const WhileMoveInfo& move_info, const Shape& base_shape) { + HloDynamicUpdateSliceInstruction* dus = move_info.dynamic_update_slices[0]; return ShapeUtil::PrependMajorDimension( - move_info.dynamic_update_slice->operand(0) - ->shape() - .dimensions()[move_info.sliced_idx], - base_shape); + dus->operand(0)->shape().dimensions()[move_info.sliced_idx], base_shape); } // Create zero of base type ptype and broadcast it to shape. @@ -1424,18 +1616,21 @@ absl::Status TransformLoopForward( int64_t count = 0; // Add all-reduces to duplicate into a set. for (auto& to_move : loop_analysis.GetMoveInfos()) { - to_skip_set.insert(to_move.collective_to_move); + CHECK_EQ(to_move.dynamic_update_slices.size(), 1); + to_skip_set.insert(to_move.collectives_to_move.front()); if (!to_move.formatting_ops.empty()) { formatting_map[to_move.formatting_ops.back()] = - to_move.collective_to_move; - } - const Shape& output_shape = to_move.formatting_ops.empty() - ? to_move.collective_to_move->shape() - : to_move.formatting_ops.back()->shape(); - if (!reuse_output_buffer(to_move.collective_to_move) || - output_shape != to_move.collective_to_move->operand(0)->shape()) { + to_move.collectives_to_move.front(); + } + const Shape& output_shape = + to_move.formatting_ops.empty() + ? to_move.collectives_to_move.front()->shape() + : to_move.formatting_ops.back()->shape(); + if (!reuse_output_buffer(to_move.collectives_to_move.front()) || + output_shape != + to_move.collectives_to_move.front()->operand(0)->shape()) { moves_requiring_special_output.push_back(count); - to_skip_set.insert(to_move.dynamic_update_slice); + to_skip_set.insert(to_move.dynamic_update_slices.front()); } ++count; } @@ -1532,7 +1727,7 @@ absl::Status TransformLoopForward( for (int i = 0; i < moves_requiring_special_output.size(); ++i) { HloInstruction* collective = loop_analysis.GetMoveInfos()[moves_requiring_special_output[i]] - .collective_to_move; + .collectives_to_move.front(); moves_requiring_special_output_to_idx[moves_requiring_special_output[i]] = operands_indices_count + i; new_parameter_shapes[operands_indices_count + i] = @@ -1545,7 +1740,8 @@ absl::Status TransformLoopForward( for (auto& move_info : loop_analysis.GetMoveInfos()) { auto pipelined_instrs = CollectDependenciesToPipeline( - move_info.collective_to_move, absl::MakeSpan(move_info.formatting_ops)); + absl::MakeConstSpan(move_info.collectives_to_move), + absl::MakeSpan(move_info.formatting_ops)); for (auto* pipelined : pipelined_instrs) { is_output_instruction[pipelined] = new_init_operands.size(); new_parameter_shapes.push_back(pipelined->shape()); @@ -1636,8 +1832,8 @@ absl::Status TransformLoopForward( const InstructionMap& pipelined_values_map, const WhileMoveInfo& move_info) -> absl::StatusOr { HloInstruction* processed = stacked_data->parent()->AddInstruction( - move_info.collective_to_move->CloneWithNewOperands( - move_info.collective_to_move->shape(), {stacked_data})); + move_info.collectives_to_move.front()->CloneWithNewOperands( + move_info.collectives_to_move.front()->shape(), {stacked_data})); UpdateInstructionChannelId(processed, next_channel_id); if (insert_non_alias_custom_call) { HloInstruction* level = @@ -1650,7 +1846,7 @@ absl::Status TransformLoopForward( } InstructionMap cloned_map = pipelined_values_map; - cloned_map[move_info.collective_to_move] = processed; + cloned_map[move_info.collectives_to_move.front()] = processed; for (auto* formatting_op : move_info.formatting_ops) { auto new_operands = MapNewOperands(formatting_op->operands(), cloned_map); processed = stacked_data->parent()->AddInstruction( @@ -1668,20 +1864,20 @@ absl::Status TransformLoopForward( HloInstruction* dus_index) -> absl::StatusOr { HloComputation* computation = stacked_data->parent(); const Shape& slice_target_shape = - move_info.collective_to_move->operand(0)->shape(); + move_info.collectives_to_move.front()->operand(0)->shape(); HloInstruction* sliced_data = data_to_slice; + HloDynamicUpdateSliceInstruction* dyn_update = + move_info.dynamic_update_slices.front(); PrimitiveType element_type = - move_info.dynamic_update_slice - ->operand( - move_info.dynamic_update_slice->first_index_operand_number() + - move_info.sliced_idx) + dyn_update + ->operand(dyn_update->first_index_operand_number() + + move_info.sliced_idx) ->shape() .element_type(); HloInstruction* zero = computation->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); std::vector indices( - move_info.dynamic_update_slice->operand_count() - - move_info.dynamic_update_slice->first_index_operand_number(), + dyn_update->operand_count() - dyn_update->first_index_operand_number(), zero); indices[move_info.sliced_idx] = dus_index; if (slice_target_shape != data_to_slice->shape()) { @@ -1699,24 +1895,23 @@ absl::Status TransformLoopForward( sliced_data, process_slice(sliced_data, pipelined_values_map, move_info)); return computation->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - move_info.dynamic_update_slice->shape(), stacked_data, sliced_data, - indices)); + dyn_update->shape(), stacked_data, sliced_data, indices)); }; for (int i = 0; i < new_loop_analysis.GetMoveInfos().size(); ++i) { auto& move_info = new_loop_analysis.GetMoveInfos()[i]; + HloDynamicUpdateSliceInstruction* dyn_update = + move_info.dynamic_update_slices.front(); std::vector loop_output_to_replace; HloInstruction* parameter_instr = new_while_body->parameter_instructions()[0]; for (auto* user : new_while_loop->users()) { - if (user->tuple_index() != move_info.output_idx) { + if (user->tuple_index() != move_info.output_indices[0]) { continue; } loop_output_to_replace.push_back(user); } - const HloInstruction* dus_index_curr_iteration = - move_info.dynamic_update_slice->operand( - move_info.dynamic_update_slice->first_index_operand_number() + - move_info.sliced_idx); + const HloInstruction* dus_index_curr_iteration = dyn_update->operand( + dyn_update->first_index_operand_number() + move_info.sliced_idx); const int64_t offset_for_index = new_loop_analysis.GetDUSIndex(dus_index_curr_iteration) + initial_inputs; @@ -1736,24 +1931,22 @@ absl::Status TransformLoopForward( HloInstruction* output_dus_idx = loop_computation->AddInstruction(HloInstruction::CreateGetTupleElement( index_shape, new_while_loop, offset_for_index)); - HloInstruction* input_stacked_data = - move_info.dynamic_update_slice->mutable_operand(0); + HloInstruction* input_stacked_data = dyn_update->mutable_operand(0); HloInstruction* output_stacked_data = loop_computation->AddInstruction(HloInstruction::CreateGetTupleElement( - move_info.dynamic_update_slice->shape(), new_while_loop, - move_info.output_idx)); + dyn_update->shape(), new_while_loop, move_info.output_indices[0])); HloInstruction* input_data_to_slice = input_stacked_data; HloInstruction* output_data_to_slice = output_stacked_data; auto it = moves_requiring_special_output_to_idx.find(i); if (it != moves_requiring_special_output_to_idx.end()) { input_data_to_slice = new_while_body->AddInstruction(HloInstruction::CreateGetTupleElement( - move_info.collective_to_move->operand(0)->shape(), + move_info.collectives_to_move.front()->operand(0)->shape(), parameter_instr, it->second)); output_data_to_slice = loop_computation->AddInstruction( HloInstruction::CreateGetTupleElement( - move_info.collective_to_move->operand(0)->shape(), new_while_loop, - it->second)); + move_info.collectives_to_move.front()->operand(0)->shape(), + new_while_loop, it->second)); } TF_ASSIGN_OR_RETURN(input_stacked_data, extract_and_process_slice( @@ -1778,19 +1971,17 @@ absl::Status TransformLoopForward( auto* new_peeled_dus = input_stacked_data; if (it == moves_requiring_special_output_to_idx.end()) { new_peeled_dus = insert_slice( - move_info.collective_to_move->mutable_operand(0), + move_info.collectives_to_move.front()->mutable_operand(0), move_info.sliced_idx, - move_info.dynamic_update_slice->operand_count() - - move_info.dynamic_update_slice->first_index_operand_number(), - move_info.dynamic_update_slice->mutable_operand( - move_info.dynamic_update_slice->first_index_operand_number() + - move_info.sliced_idx), + dyn_update->operand_count() - + dyn_update->first_index_operand_number(), + dyn_update->mutable_operand(dyn_update->first_index_operand_number() + + move_info.sliced_idx), input_stacked_data); } + TF_RETURN_IF_ERROR(dyn_update->ReplaceAllUsesWith(new_peeled_dus)); TF_RETURN_IF_ERROR( - move_info.dynamic_update_slice->ReplaceAllUsesWith(new_peeled_dus)); - TF_RETURN_IF_ERROR(new_while_body->RemoveInstructionAndUnusedOperands( - move_info.dynamic_update_slice)); + new_while_body->RemoveInstructionAndUnusedOperands(dyn_update)); TF_RETURN_IF_ERROR(replace_instructions_with( absl::MakeSpan(loop_output_to_replace), output_stacked_data)); } @@ -1858,46 +2049,63 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, std::vector new_root_operands; absl::flat_hash_set indices_to_insert; const int64_t operands_indices_count = loop_init->operand_count(); - const int64_t new_loop_tuple_operand_count = operands_indices_count; absl::flat_hash_map> replacements; - new_parameter_shapes.resize(new_loop_tuple_operand_count); - new_root_operands.resize(new_loop_tuple_operand_count); - new_init_operands.resize(new_loop_tuple_operand_count); + new_parameter_shapes.resize(operands_indices_count); + new_root_operands.resize(operands_indices_count); + new_init_operands.resize(operands_indices_count); absl::flat_hash_set original_to_move_indices; // Initialize data structures with information about the outputs that need to // be sunk. + VLOG(1) << "Initial size for " << body_computation->name() << ": " + << operands_indices_count; + absl::flat_hash_map collective_to_new_tuple_index; for (auto& to_move : loop_analysis.GetMoveInfos()) { - HloInstruction* collective = to_move.collective_to_move; - Shape shape = - ComputeFullOutputShape(to_move, collective->operand(0)->shape()); - new_init_operands[to_move.output_idx] = - CreateZero(loop_computation, shape, shape.element_type()); - new_parameter_shapes[to_move.output_idx] = shape; - original_to_move_indices.insert(to_move.output_idx); - indices_to_insert.insert(to_move.output_idx); - new_root_operands[to_move.output_idx] = collective->mutable_operand(0); + for (HloInstruction* collective : to_move.collectives_to_move) { + Shape shape = + ComputeFullOutputShape(to_move, collective->operand(0)->shape()); + new_init_operands.push_back( + CreateZero(loop_computation, shape, shape.element_type())); + new_parameter_shapes.push_back(shape); + collective_to_new_tuple_index[collective] = new_root_operands.size(); + indices_to_insert.insert(new_root_operands.size()); + new_root_operands.push_back(collective->mutable_operand(0)); + } + CHECK_EQ(to_move.dynamic_update_slices.size(), + to_move.output_indices.size()); + for (int64_t i = 0; i < to_move.dynamic_update_slices.size(); ++i) { + int64_t output_idx = to_move.output_indices[i]; + original_to_move_indices.insert(output_idx); + } } // Initialize the data structures for output indices that aren't modified. for (int i = 0; i < loop_parameter->shape().tuple_shapes().size(); ++i) { if (original_to_move_indices.contains(i)) { + new_parameter_shapes[i] = loop_parameter->shape().tuple_shapes(i); + new_init_operands[i] = loop_init->mutable_operand(i); continue; } new_parameter_shapes[i] = loop_parameter->shape().tuple_shapes(i); new_init_operands[i] = loop_init->mutable_operand(i); new_root_operands[i] = while_body->root_instruction()->mutable_operand(i); } - + VLOG(1) << "Size of " << body_computation->name() + << " after adding collectives: " << new_root_operands.size(); // Collect instructions that are necessary for the execution of the sunk // instructions. If they are loop invariant they are stored as is, otherwise // the version for each iteration is accumulated in a buffer. + absl::flat_hash_set added_pipelined; for (auto& move_info : loop_analysis.GetMoveInfos()) { auto pipelined_instrs = CollectDependenciesToPipeline( - move_info.collective_to_move, absl::MakeSpan(move_info.formatting_ops)); + absl::MakeSpan(move_info.collectives_to_move), + absl::MakeSpan(move_info.formatting_ops)); for (auto* pipelined : pipelined_instrs) { if (pipelined->opcode() == HloOpcode::kConstant) { continue; } + if (added_pipelined.contains(pipelined)) { + continue; + } const bool is_loop_invariant = IsLoopInvariant(pipelined, invariant_cache); is_output_instruction[pipelined] = new_init_operands.size(); @@ -1907,6 +2115,7 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, CreateZero(loop_computation, pipelined->shape(), pipelined->shape().element_type())); new_root_operands.push_back(pipelined); + added_pipelined.insert(pipelined); continue; } Shape expanded_shape = @@ -1919,13 +2128,13 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, ShapeUtil::PrependMajorDimension(1, pipelined->shape()); HloInstruction* reshaped = body_computation->AddInstruction( HloInstruction::CreateReshape(extra_trivial_dim_shape, pipelined)); + Shape index_shape = + move_info.dynamic_update_slices.front()->index_shapes()[0]; std::vector indices( expanded_shape.dimensions_size(), - CreateZero(body_computation, - move_info.dynamic_update_slice->index_shapes()[0], - move_info.dynamic_update_slice->index_shapes()[0] - .element_type())); - indices[0] = move_info.dynamic_update_slice->index_operands()[0]; + CreateZero(body_computation, index_shape, + index_shape.element_type())); + indices[0] = move_info.dynamic_update_slices.front()->index_operands()[0]; HloInstruction* input = body_computation->AddInstruction(HloInstruction::CreateCustomCall( expanded_shape, @@ -1936,56 +2145,61 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, HloInstruction::CreateDynamicUpdateSlice(expanded_shape, input, reshaped, indices)); new_root_operands.push_back(reshaped); + added_pipelined.insert(pipelined); } } + VLOG(1) << "Size of " << body_computation->name() + << " after adding dependencies: " << new_root_operands.size(); std::unique_ptr new_parameter = HloInstruction::CreateParameter( 0, ShapeUtil::MakeTupleShape(new_parameter_shapes), absl::StrCat("sink_", loop_parameter->name())); // Insert inputs to the collective we are sinking in slices for the loop. for (auto& to_move : loop_analysis.GetMoveInfos()) { - if (!indices_to_insert.contains(to_move.output_idx)) { - continue; + for (HloInstruction* collective : to_move.collectives_to_move) { + int64_t new_tuple_index = collective_to_new_tuple_index[collective]; + HloInstruction* collective_operand = collective->mutable_operand(0); + HloInstruction* to_insert = + body_computation->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::PrependMajorDimension(1, collective_operand->shape()), + collective_operand)); + Shape expanded_shape = + ComputeFullOutputShape(to_move, collective_operand->shape()); + HloInstruction* input = + body_computation->AddInstruction(HloInstruction::CreateCustomCall( + expanded_shape, + {body_computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0((int32_t)new_tuple_index)))}, + "PlaceHolder")); + // All dyn update slices in this move_info have the same indices so it is + // safe to use the first one to create the indices. + HloDynamicUpdateSliceInstruction* dyn_update = + to_move.dynamic_update_slices[0]; + std::vector indices( + expanded_shape.dimensions_size(), + CreateZero(body_computation, dyn_update->index_shapes()[0], + dyn_update->index_shapes()[0].element_type())); + indices[0] = dyn_update->index_operands()[0]; + to_insert = body_computation->AddInstruction( + HloInstruction::CreateDynamicUpdateSlice(expanded_shape, input, + to_insert, indices)); + new_root_operands[new_tuple_index] = to_insert; } - HloInstruction* to_insert = - body_computation->AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::PrependMajorDimension( - 1, new_root_operands[to_move.output_idx]->shape()), - new_root_operands[to_move.output_idx])); - Shape expanded_shape = ComputeFullOutputShape( - to_move, new_root_operands[to_move.output_idx]->shape()); - HloInstruction* input = - body_computation->AddInstruction(HloInstruction::CreateCustomCall( - expanded_shape, - {body_computation->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR0((int32_t)to_move.output_idx)))}, - "PlaceHolder")); - std::vector indices( - expanded_shape.dimensions_size(), - CreateZero( - body_computation, to_move.dynamic_update_slice->index_shapes()[0], - to_move.dynamic_update_slice->index_shapes()[0].element_type())); - indices[0] = to_move.dynamic_update_slice->index_operands()[0]; - to_insert = body_computation->AddInstruction( - HloInstruction::CreateDynamicUpdateSlice(expanded_shape, input, - to_insert, indices)); - new_root_operands[to_move.output_idx] = to_insert; } - std::unique_ptr new_root_instr = - HloInstruction::CreateTuple(new_root_operands); // Mark for removal (by setting replacement entry to nullptr) the users of the // old parameters we are replacing for the loops. All the computation tree // for those should be not used in the new loop. for (auto* p_user : body_computation->parameter_instructions()[0]->users()) { CHECK_EQ(p_user->opcode(), HloOpcode::kGetTupleElement); const int64_t tuple_idx = p_user->tuple_index(); - if (!indices_to_insert.contains(tuple_idx)) { + if (!original_to_move_indices.contains(tuple_idx)) { continue; } replacements[p_user] = HloInstruction::CreateGetTupleElement(new_parameter.get(), tuple_idx); std::vector stack(p_user->users().begin(), p_user->users().end()); + new_root_operands[tuple_idx] = replacements[p_user].get(); while (!stack.empty()) { auto* u = stack.back(); stack.pop_back(); @@ -1998,6 +2212,8 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, } } } + std::unique_ptr new_root_instr = + HloInstruction::CreateTuple(new_root_operands); replacements[body_computation->parameter_instruction(0)] = std::move(new_parameter); replacements[body_computation->root_instruction()] = @@ -2037,6 +2253,7 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, TF_RETURN_IF_ERROR(output->ReplaceOperandWith(0, new_param)); TF_RETURN_IF_ERROR( old_operand_param->parent()->RemoveInstruction(old_operand_param)); + // TODO(sacer): Consider relaxing this to all inserted operands. if (insert_non_alias_custom_call && original_to_move_indices.contains(i)) { auto* old_operand = output->mutable_operand(1); auto* custom_call = @@ -2049,18 +2266,24 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, HloInstruction* new_while = loop_computation->AddInstruction(HloInstruction::CreateWhile( new_init->shape(), cloned_cond, cloned_body, new_init)); + // Create the new tuple with the original while tuple size. std::vector new_output_tuple; - new_output_tuple.resize(new_root_operands.size(), nullptr); + new_output_tuple.resize(operands_indices_count, nullptr); // Reproduce computation to the output after the loop on the full shape. for (auto& to_move : loop_analysis.GetMoveInfos()) { InstructionMap pipelined_map; - HloInstruction* to_sink = loop_computation->AddInstruction( - HloInstruction::CreateGetTupleElement(new_while, to_move.output_idx)); + for (int64_t i = 0; i < to_move.collectives_to_move.size(); ++i) { + HloInstruction* collective = to_move.collectives_to_move[i]; + int64_t gte_index = collective_to_new_tuple_index[collective]; + HloInstruction* to_sink = loop_computation->AddInstruction( + HloInstruction::CreateGetTupleElement(new_while, gte_index)); + pipelined_map[collective->mutable_operand(0)] = to_sink; + } const int64_t new_dim_limit = - to_move.dynamic_update_slice->shape().dimensions(0); - pipelined_map[to_move.collective_to_move->mutable_operand(0)] = to_sink; + to_move.dynamic_update_slices[0]->shape().dimensions(0); auto pipelined_instrs = CollectDependenciesToPipeline( - to_move.collective_to_move, absl::MakeSpan(to_move.formatting_ops)); + absl::MakeSpan(to_move.collectives_to_move), + absl::MakeSpan(to_move.formatting_ops)); for (auto* original_pipelined : pipelined_instrs) { if (original_pipelined->opcode() == HloOpcode::kConstant) { continue; @@ -2086,13 +2309,14 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, } } // Cloning the main instruction - HloInstruction* pipelined_instr_cloned = loop_computation->AddInstruction( - to_move.collective_to_move->CloneWithNewOperands( - ComputeFullOutputShape(to_move, - to_move.collective_to_move->shape()), - {to_sink})); - UpdateInstructionChannelId(pipelined_instr_cloned, next_channel_id); - pipelined_map[to_move.collective_to_move] = pipelined_instr_cloned; + for (HloInstruction* collective : to_move.collectives_to_move) { + HloInstruction* pipelined_instr_cloned = + loop_computation->AddInstruction(collective->CloneWithNewOperands( + ComputeFullOutputShape(to_move, collective->shape()), + {pipelined_map[collective->mutable_operand(0)]})); + UpdateInstructionChannelId(pipelined_instr_cloned, next_channel_id); + pipelined_map[collective] = pipelined_instr_cloned; + } absl::flat_hash_set to_add_batch_set; auto collect_operands = [&pipelined_map, &to_add_batch_set, loop_computation, @@ -2123,9 +2347,6 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, } return operands; }; - absl::flat_hash_set formatting_ops_set( - to_move.formatting_ops.begin(), to_move.formatting_ops.end()); - std::vector stack(1, to_move.collective_to_move); for (auto* current : to_move.formatting_ops) { if (IsLoopInvariant(current, invariant_cache)) { continue; @@ -2281,21 +2502,24 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, } CHECK(false) << "Unsupported instruction " << formatting_op->ToString(); } - HloInstruction* inserted_operand = - to_move.dynamic_update_slice->mutable_operand(1); - CHECK(pipelined_map.contains(inserted_operand)) - << "Expected to be processed"; - HloInstruction* expanded_inserted = pipelined_map[inserted_operand]; - if (!ShapeUtil::Compatible(expanded_inserted->shape(), - to_move.dynamic_update_slice->shape())) { - expanded_inserted = - loop_computation->AddInstruction(HloInstruction::CreateReshape( - to_move.dynamic_update_slice->shape(), expanded_inserted)); + for (int64_t i = 0; i < to_move.output_indices.size(); ++i) { + HloDynamicUpdateSliceInstruction* d_update = + to_move.dynamic_update_slices[i]; + HloInstruction* inserted_operand = d_update->mutable_operand(1); + CHECK(pipelined_map.contains(inserted_operand)) + << "Expected to be processed"; + HloInstruction* expanded_inserted = pipelined_map[inserted_operand]; + if (!ShapeUtil::Compatible(expanded_inserted->shape(), + d_update->shape())) { + expanded_inserted = + loop_computation->AddInstruction(HloInstruction::CreateReshape( + d_update->shape(), expanded_inserted)); + } + new_output_tuple[to_move.output_indices[i]] = expanded_inserted; } - new_output_tuple[to_move.output_idx] = expanded_inserted; } // Create new loop tuple replacement. - for (int i = 0; i < new_while->shape().tuple_shapes_size(); ++i) { + for (int64_t i = 0; i < operands_indices_count; ++i) { if (new_output_tuple[i] != nullptr) { continue; } @@ -2350,7 +2574,8 @@ static absl::Status TransformLoopBackward( int64_t count = 0; // Add instructions to duplicate into a set. for (auto& to_move : loop_analysis.GetMoveInfos()) { - HloInstruction* instr = to_move.collective_to_move; + CHECK_EQ(to_move.collectives_to_move.size(), 1); + HloInstruction* instr = to_move.collectives_to_move[0]; collective_to_move_map[instr] = count; is_pipelined_instruction.insert(instr); is_pipelined_instruction.insert(to_move.formatting_ops.begin(), @@ -2428,8 +2653,9 @@ static absl::Status TransformLoopBackward( for (int i = 0; i < loop_analysis.GetMoveInfos().size(); ++i) { const int64_t idx = i + loop_parameter->shape().tuple_shapes_size(); new_parameter_shapes[idx] = - loop_analysis.GetMoveInfos()[i].collective_to_move->shape(); - new_root_operands[idx] = loop_analysis.GetMoveInfos()[i].collective_to_move; + loop_analysis.GetMoveInfos()[i].collectives_to_move[0]->shape(); + new_root_operands[idx] = + loop_analysis.GetMoveInfos()[i].collectives_to_move[0]; TF_ASSIGN_OR_RETURN( new_init_operands[idx], CloneBackwardChain(*while_loop->parent(), @@ -2665,7 +2891,7 @@ absl::StatusOr CollectivePipeliner::RunPipeliner( VLOG(1) << "Pipelining on direction: " << GetPipelineDirectionString(config_.pipelining_direction); for (HloInstruction* instruction : while_loop_instructions) { - VLOG(1) << "While: " << instruction->ToString(); + VLOG(1) << "While: " << instruction->name(); WhileLoopAnalysis loop_analysis( instruction, config_.max_pipelining_per_loop, config_.pipeline_use_tree, config_.process_different_sized_ops); @@ -2688,12 +2914,22 @@ absl::StatusOr CollectivePipeliner::RunPipeliner( transformed_instructions += loop_analysis.GetMoveInfos().size(); VLOG(1) << "Found Collectives to optimize"; if (VLOG_IS_ON(1)) { + int64_t id = 0; for (auto& to_move : loop_analysis.GetMoveInfos()) { - VLOG(1) << "\t" << to_move.collective_to_move->ToString(); - if (to_move.dynamic_update_slice) { - VLOG(1) << "\t" << to_move.dynamic_update_slice->ToString(); + VLOG(1) << "Move info id: " << id++ << " with " + << to_move.collectives_to_move.size() << " collectives " + << to_move.dynamic_update_slices.size() + << " dynamic update slices" << to_move.formatting_ops.size() + << " formatting ops"; + for (HloInstruction* collective : to_move.collectives_to_move) { + VLOG(1) << "\t" << collective->name(); + } + for (int64_t i = 0; i < to_move.dynamic_update_slices.size(); ++i) { + HloDynamicUpdateSliceInstruction* dyn_update = + to_move.dynamic_update_slices[i]; + VLOG(1) << "\t\t" << dyn_update->name(); + VLOG(1) << "\t\t" << to_move.output_indices[i]; } - VLOG(1) << "\t" << to_move.output_idx; } } if (config_.pipelining_direction == PipeliningDirection::kForward) { diff --git a/xla/service/collective_pipeliner_test.cc b/xla/service/collective_pipeliner_test.cc index a2342fc96a796d..ab56cc1903cbcf 100644 --- a/xla/service/collective_pipeliner_test.cc +++ b/xla/service/collective_pipeliner_test.cc @@ -2958,37 +2958,621 @@ ENTRY entry { /*acceptable_formatting=*/HloPredicateIsNotOp) .value()); XLA_VLOG_LINES(1, module->ToString()); - // Checks if i has at least one operand whose operand is a custom-call with - // target "SunkByPreviousStep". - std::function has_operands_operand_custom_call = - [&](HloInstruction* i) -> bool { - for (const HloInstruction* operand : i->operands()) { - if (absl::c_any_of( - operand->operands(), [](const HloInstruction* operands_operand) { - return operands_operand->IsCustomCall("SunkByPreviousStep"); - })) { - return true; + // Return the closest all-reduce in the user subtree rooted at instruction i. + std::function find_all_reduce = + [&](const HloInstruction* i) -> const HloInstruction* { + std::queue queue; + queue.push(i); + absl::flat_hash_set visited; + while (!queue.empty()) { + const HloInstruction* curr_inst = queue.front(); + queue.pop(); + for (HloInstruction* operand : curr_inst->operands()) { + if (operand->opcode() == HloOpcode::kAllReduce) { + return operand; + } + if (visited.insert(operand).second) { + queue.push(operand); + } } } - return false; + return nullptr; }; - // The following two tuples should have the above function return true. - // - one tuple in the entry computation. - // - the root tuple of the while loop. - int64_t num_desired_tuples = 0; - for (HloInstruction* inst : module->entry_computation()->instructions()) { - if (inst->opcode() == HloOpcode::kTuple && - has_operands_operand_custom_call(inst)) { - num_desired_tuples++; - continue; - } - if (inst->opcode() == HloOpcode::kWhile) { - CHECK(has_operands_operand_custom_call( - inst->while_body()->root_instruction())); - num_desired_tuples++; - } - } - CHECK_EQ(num_desired_tuples, 2); + // Check if root has the two all-reduces in the operand subtree where one is + // an ancestor of the other. + const HloInstruction* all_reduce1 = + find_all_reduce(module->entry_computation()->root_instruction()); + EXPECT_NE(all_reduce1, nullptr); + const HloInstruction* all_reduce2 = find_all_reduce(all_reduce1); + EXPECT_NE(all_reduce2, nullptr); + EXPECT_THAT(all_reduce2, op::AllReduce(op::GetTupleElement(op::While()))); +} + +TEST_F(CollectivePipelinerTest, ForwardSinkFirstDimNotMatchingLoopCount) { + constexpr absl::string_view hlo_string = R"( +HloModule module + +add { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +while_cond { + param = (s32[], bf16[5,8,128], bf16[3,8,128]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + constant.1 = s32[] constant(3) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT +} + +while_body { + param = (s32[], bf16[5,8,128], bf16[3,8,128]) parameter(0) + get-tuple-element.394 = s32[] get-tuple-element(param), index=0 + get-tuple-element.395 = bf16[5,8,128] get-tuple-element(param), index=1 + get-tuple-element.35 = bf16[3,8,128] get-tuple-element(param), index=2 + constant.2557 = s32[] constant(1) + add.230 = s32[] add(get-tuple-element.394, constant.2557) + constant.2559 = s32[] constant(3) + subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394) + constant.2560 = s32[] constant(-1) + add.231 = s32[] add(subtract.139, constant.2560) + constant.2561 = s32[] constant(0) + compare.747 = pred[] compare(add.231, constant.2561), direction=LT + constant.2562 = s32[] constant(2) + add.232 = s32[] add(subtract.139, constant.2562) + select.1348 = s32[] select(compare.747, add.232, add.231) + dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.35, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128} + mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99) + ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=1 + c = bf16[] custom-call(), custom_call_target="Boh" + b = bf16[1,8,128] broadcast(c), dimensions={} + a = bf16[1,8,128] add(ar.1, b) + dynamic-update-slice.35 = bf16[5,8,128] dynamic-update-slice(get-tuple-element.395, a, select.1348, constant.2561, constant.2561) + ROOT tuple = (s32[], bf16[5,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.35), control-predecessors={select.1348} +} + +ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[5,8,128] parameter(0) + p1 = bf16[3,8,128] parameter(1) + tuple = (s32[], bf16[5,8,128], bf16[3,8,128]) tuple(c0, p0, p1) + while = (s32[], bf16[5,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body + ROOT gte1 = bf16[5,8,128] get-tuple-element(while), index=1 +} +)"; + auto module = ParseAndReturnUnverifiedModule(hlo_string, config_).value(); + EXPECT_FALSE(RunOptimizer(module.get(), /*last_run=*/true, + /*level_to_operate_on=*/0, + /*pipeline_use_tree=*/true, + /*process_different_sized_ops=*/true, + CollectivePipeliner::kForwardSink) + .value()); +} + +TEST_F(CollectivePipelinerTest, ForwardSinkNotFirstDim) { + constexpr absl::string_view hlo_string = R"( +HloModule module + +add { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +while_cond { + param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + constant.1 = s32[] constant(3) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT +} + +while_body { + param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0) + get-tuple-element.394 = s32[] get-tuple-element(param), index=0 + get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1 + get-tuple-element.35 = bf16[3,8,128] get-tuple-element(param), index=2 + constant.2557 = s32[] constant(1) + add.230 = s32[] add(get-tuple-element.394, constant.2557) + constant.2559 = s32[] constant(3) + subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394) + constant.2560 = s32[] constant(-1) + add.231 = s32[] add(subtract.139, constant.2560) + constant.2561 = s32[] constant(0) + compare.747 = pred[] compare(add.231, constant.2561), direction=LT + constant.2562 = s32[] constant(2) + add.232 = s32[] add(subtract.139, constant.2562) + select.1348 = s32[] select(compare.747, add.232, add.231) + dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.35, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128} + mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99) + ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=1 + %c = bf16[] custom-call(), custom_call_target="Boh" + %b = bf16[1,8,128] broadcast(c), dimensions={} + %a = bf16[1,8,128] add(ar.1, b) + dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, a, constant.2561, select.1348, constant.2561) + ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.35), control-predecessors={select.1348} +} + +ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[3,8,128] parameter(0) + tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0) + while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body + ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1 +} +)"; + auto module = ParseAndReturnUnverifiedModule(hlo_string, config_).value(); + EXPECT_FALSE(RunOptimizer(module.get(), /*last_run=*/true, + /*level_to_operate_on=*/0, + /*pipeline_use_tree=*/true, + /*process_different_sized_ops=*/true, + CollectivePipeliner::kForwardSink) + .value()); +} + +TEST_F(CollectivePipelinerTest, CollectiveWithMultipleDUS) { + constexpr absl::string_view hlo_string = R"( +HloModule module + +add { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +add.1 { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +while_cond { + param = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + constant.1 = s32[] constant(3) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT +} + +while_body { + param = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) parameter(0) + get-tuple-element.394 = s32[] get-tuple-element(param), index=0 + get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1 + get-tuple-element.396 = bf16[3,8,128] get-tuple-element(param), index=2 + get-tuple-element.35 = bf16[3,8,128] get-tuple-element(param), index=3 + constant.2557 = s32[] constant(1) + add.230 = s32[] add(get-tuple-element.394, constant.2557) + constant.2559 = s32[] constant(3) + subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394) + constant.2560 = s32[] constant(-1) + add.231 = s32[] add(subtract.139, constant.2560) + constant.2561 = s32[] constant(0) + compare.747 = pred[] compare(add.231, constant.2561), direction=LT + constant.2562 = s32[] constant(2) + add.232 = s32[] add(subtract.139, constant.2562) + select.1348 = s32[] select(compare.747, add.232, add.231) + dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.35, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128} + mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99) + ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=1 + b.1 = bf16[1,8,128,32] broadcast(ar.1), dimensions={0,1,2} + constant = bf16[] constant(0) + reduce = bf16[1,8,128] reduce(b.1, constant), dimensions={3}, to_apply=add.1 + dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, reduce, select.1348, constant.2561, constant.2561) + c2 = bf16[] constant(2.0) + bc = bf16[1,8,128] broadcast(c2) + mul2 = bf16[1,8,128] multiply(ar.1, bc) + mul3 = bf16[1,8,128] multiply(mul2, ar.1) + mul4 = bf16[1,8,128] multiply(mul3, mul) + dynamic-update-slice.36 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.396, mul4, select.1348, constant.2561, constant.2561) + ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, dynamic-update-slice.36, get-tuple-element.35) +} + +ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[3,8,128] parameter(0) + tuple = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0, p0) + while = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body + ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1 +} +)"; + auto module = ParseAndReturnUnverifiedModule(hlo_string, config_).value(); + EXPECT_TRUE(RunOptimizer(module.get(), /*last_run=*/true, + /*level_to_operate_on=*/0, + /*pipeline_use_tree=*/true, + /*process_different_sized_ops=*/true, + CollectivePipeliner::kForwardSink) + .value()); + XLA_VLOG_LINES(1, module->ToString()); + const HloInstruction* while_instr = + FindInstruction(module.get(), HloOpcode::kWhile); + EXPECT_TRUE( + absl::c_any_of(while_instr->users(), [](const HloInstruction* user) { + return absl::c_any_of( + user->users(), [](const HloInstruction* user_user) { + return user_user->opcode() == HloOpcode::kAllReduce; + }); + })); + EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(), + HloOpcode::kGetTupleElement); + const HloInstruction* new_tuple = + module->entry_computation()->root_instruction()->operand(0); + EXPECT_EQ(new_tuple->opcode(), HloOpcode::kTuple); + // There should be two reshapes in this tuple (replacing the two + // dynamic-update-slices). + EXPECT_EQ(absl::c_count_if(new_tuple->operands(), + [](const HloInstruction* operand) { + return operand->opcode() == HloOpcode::kReshape; + }), + 2); +} + +TEST_F(CollectivePipelinerTest, CollectiveWithMultipleDUSNotLastRun) { + constexpr absl::string_view hlo_string = R"( +HloModule module + +add { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +add.1 { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +while_cond { + param = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + constant.1 = s32[] constant(3) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT +} + +while_body { + param = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) parameter(0) + get-tuple-element.394 = s32[] get-tuple-element(param), index=0 + get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1 + get-tuple-element.396 = bf16[3,8,128] get-tuple-element(param), index=2 + get-tuple-element.35 = bf16[3,8,128] get-tuple-element(param), index=3 + constant.2557 = s32[] constant(1) + add.230 = s32[] add(get-tuple-element.394, constant.2557) + constant.2559 = s32[] constant(3) + subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394) + constant.2560 = s32[] constant(-1) + add.231 = s32[] add(subtract.139, constant.2560) + constant.2561 = s32[] constant(0) + compare.747 = pred[] compare(add.231, constant.2561), direction=LT + constant.2562 = s32[] constant(2) + add.232 = s32[] add(subtract.139, constant.2562) + select.1348 = s32[] select(compare.747, add.232, add.231) + dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.35, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128} + mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99) + ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=1 + b.1 = bf16[1,8,128,32] broadcast(ar.1), dimensions={0,1,2} + constant = bf16[] constant(0) + reduce = bf16[1,8,128] reduce(b.1, constant), dimensions={3}, to_apply=add.1 + dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, reduce, select.1348, constant.2561, constant.2561) + c2 = bf16[] constant(2.0) + bc = bf16[1,8,128] broadcast(c2) + mul2 = bf16[1,8,128] multiply(ar.1, bc) + mul3 = bf16[1,8,128] multiply(mul2, ar.1) + mul4 = bf16[1,8,128] multiply(mul3, mul) + dynamic-update-slice.36 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.396, mul4, select.1348, constant.2561, constant.2561) + ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, dynamic-update-slice.36, get-tuple-element.35) +} + +ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[3,8,128] parameter(0) + tuple = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0, p0) + while = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body + ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1 +} +)"; + auto module = ParseAndReturnUnverifiedModule(hlo_string, config_).value(); + EXPECT_TRUE(RunOptimizer(module.get(), /*last_run=*/false, + /*level_to_operate_on=*/0, + /*pipeline_use_tree=*/true, + /*process_different_sized_ops=*/true, + CollectivePipeliner::kForwardSink) + .value()); + XLA_VLOG_LINES(1, module->ToString()); + const HloInstruction* while_instr = + FindInstruction(module.get(), HloOpcode::kWhile); + EXPECT_TRUE( + absl::c_any_of(while_instr->users(), [](const HloInstruction* user) { + return absl::c_any_of( + user->users(), [](const HloInstruction* user_user) { + return user_user->opcode() == HloOpcode::kAllReduce; + }); + })); + EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(), + HloOpcode::kGetTupleElement); + const HloInstruction* new_tuple = + module->entry_computation()->root_instruction()->operand(0); + EXPECT_EQ(new_tuple->opcode(), HloOpcode::kTuple); + // There should be two reshapes in this tuple (replacing the two + // dynamic-update-slices). + EXPECT_EQ(absl::c_count_if(new_tuple->operands(), + [](const HloInstruction* operand) { + return operand->opcode() == HloOpcode::kReshape; + }), + 2); +} + +TEST_F(CollectivePipelinerTest, CollectiveWithMultipleDUSSameBuffer) { + constexpr absl::string_view hlo_string = R"( +HloModule module + +add { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +add.1 { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +while_cond { + param = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + constant.1 = s32[] constant(3) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT +} + +while_body { + param = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) parameter(0) + get-tuple-element.394 = s32[] get-tuple-element(param), index=0 + get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1 + get-tuple-element.35 = bf16[3,8,128] get-tuple-element(param), index=2 + constant.2557 = s32[] constant(1) + add.230 = s32[] add(get-tuple-element.394, constant.2557) + constant.2559 = s32[] constant(3) + subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394) + constant.2560 = s32[] constant(-1) + add.231 = s32[] add(subtract.139, constant.2560) + constant.2561 = s32[] constant(0) + compare.747 = pred[] compare(add.231, constant.2561), direction=LT + constant.2562 = s32[] constant(2) + add.232 = s32[] add(subtract.139, constant.2562) + select.1348 = s32[] select(compare.747, add.232, add.231) + dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.35, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128} + mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99) + ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=1 + b.1 = bf16[1,8,128,32] broadcast(ar.1), dimensions={0,1,2} + constant = bf16[] constant(0) + reduce = bf16[1,8,128] reduce(b.1, constant), dimensions={3}, to_apply=add.1 + dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, reduce, select.1348, constant.2561, constant.2561) + c2 = bf16[] constant(2.0) + bc = bf16[1,8,128] broadcast(c2) + mul2 = bf16[1,8,128] multiply(ar.1, bc) + mul3 = bf16[1,8,128] multiply(mul2, ar.1) + mul4 = bf16[1,8,128] multiply(mul3, mul) + dynamic-update-slice.36 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, mul4, select.1348, constant.2561, constant.2561) + ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, dynamic-update-slice.36, get-tuple-element.35) +} + +ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[3,8,128] parameter(0) + tuple = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0, p0) + while = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body + ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1 +} +)"; + auto module = ParseAndReturnUnverifiedModule(hlo_string, config_).value(); + EXPECT_FALSE(RunOptimizer(module.get(), /*last_run=*/true, + /*level_to_operate_on=*/0, + /*pipeline_use_tree=*/true, + /*process_different_sized_ops=*/true, + CollectivePipeliner::kForwardSink) + .value()); +} + +TEST_F(CollectivePipelinerTest, MergeTwoCollectivesEachWithTwoDUS) { + constexpr absl::string_view hlo_string = R"( +HloModule module + +add { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +add.1 { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +while_cond { + param = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + constant.1 = s32[] constant(3) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT +} + +while_body { + param = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) parameter(0) + get-tuple-element.394 = s32[] get-tuple-element(param), index=0 + get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1 + get-tuple-element.396 = bf16[3,8,128] get-tuple-element(param), index=2 + get-tuple-element.397 = bf16[3,8,128] get-tuple-element(param), index=3 + get-tuple-element.398 = bf16[3,8,128] get-tuple-element(param), index=4 + get-tuple-element.35 = bf16[3,8,128] get-tuple-element(param), index=5 + get-tuple-element.36 = bf16[3,8,128] get-tuple-element(param), index=6 + constant.2557 = s32[] constant(1) + add.230 = s32[] add(get-tuple-element.394, constant.2557) + constant.2559 = s32[] constant(3) + subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394) + constant.2560 = s32[] constant(-1) + add.231 = s32[] add(subtract.139, constant.2560) + constant.2561 = s32[] constant(0) + compare.747 = pred[] compare(add.231, constant.2561), direction=LT + constant.2562 = s32[] constant(2) + add.232 = s32[] add(subtract.139, constant.2562) + select.1348 = s32[] select(compare.747, add.232, add.231) + + // ar.1 is used by dynamic-update-slice.35 and dynamic-update-slice.36 + dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.35, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128} + mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99) + ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=1 + b.1 = bf16[1,8,128,32] broadcast(ar.1), dimensions={0,1,2} + constant = bf16[] constant(0) + reduce = bf16[1,8,128] reduce(b.1, constant), dimensions={3}, to_apply=add.1 + dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, reduce, select.1348, constant.2561, constant.2561) + c2 = bf16[] constant(2.0) + bc = bf16[1,8,128] broadcast(c2) + mul2 = bf16[1,8,128] multiply(ar.1, bc) + mul3 = bf16[1,8,128] multiply(mul2, ar.1) + mul4 = bf16[1,8,128] multiply(mul3, mul) + dynamic-update-slice.36 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.396, mul4, select.1348, constant.2561, constant.2561) + + // ar.1 is used by dynamic-update-slice.37 and dynamic-update-slice.38 + // dynamic-update-slice.37 actually uses both ar.1 and ar.2 + dynamic-slice.100 = bf16[1,8,128] dynamic-slice(get-tuple-element.36, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128} + mul.1 = bf16[1,8,128] multiply(dynamic-slice.100, dynamic-slice.99) + ar.2 = bf16[1,8,128] all-reduce(mul.1), replica_groups={}, to_apply=add, channel_id=1 + divide = bf16[1,8,128] divide(ar.1, ar.2) + dynamic-update-slice.37 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.397, divide, select.1348, constant.2561, constant.2561) + mul.2 = bf16[1,8,128] multiply(ar.2, ar.2) + abs = bf16[1,8,128] abs(mul.2) + dynamic-update-slice.38 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.398, abs, select.1348, constant.2561, constant.2561) + ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, dynamic-update-slice.36, dynamic-update-slice.37, dynamic-update-slice.38, get-tuple-element.35, get-tuple-element.36) +} + +ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[3,8,128] parameter(0) + p1 = bf16[3,8,128] parameter(1) + p2 = bf16[3,8,128] parameter(2) + p3 = bf16[3,8,128] parameter(3) + p4 = bf16[3,8,128] parameter(4) + p5 = bf16[3,8,128] parameter(5) + + tuple = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p1, p2, p3, p4, p5) + while = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body + ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1 +} +)"; + auto module = ParseAndReturnUnverifiedModule(hlo_string, config_).value(); + EXPECT_TRUE(RunOptimizer(module.get(), /*last_run=*/true, + /*level_to_operate_on=*/0, + /*pipeline_use_tree=*/true, + /*process_different_sized_ops=*/true, + CollectivePipeliner::kForwardSink) + .value()); + XLA_VLOG_LINES(1, module->ToString()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::GetTupleElement(op::Tuple( + op::GetTupleElement(op::While()), op::Reshape(op::Reduce()), + op::Reshape(op::Multiply()), op::Reshape(op::Divide()), + op::Reshape(op::Abs()), op::GetTupleElement(op::While()), + op::GetTupleElement(op::While())))); +} + +TEST_F(CollectivePipelinerTest, MergeTwoCollectivesEachWithTwoDUSNotLastRun) { + constexpr absl::string_view hlo_string = R"( +HloModule module + +add { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +add.1 { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +while_cond { + param = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + constant.1 = s32[] constant(3) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT +} + +while_body { + param = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) parameter(0) + get-tuple-element.394 = s32[] get-tuple-element(param), index=0 + get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1 + get-tuple-element.396 = bf16[3,8,128] get-tuple-element(param), index=2 + get-tuple-element.397 = bf16[3,8,128] get-tuple-element(param), index=3 + get-tuple-element.398 = bf16[3,8,128] get-tuple-element(param), index=4 + get-tuple-element.35 = bf16[3,8,128] get-tuple-element(param), index=5 + get-tuple-element.36 = bf16[3,8,128] get-tuple-element(param), index=6 + constant.2557 = s32[] constant(1) + add.230 = s32[] add(get-tuple-element.394, constant.2557) + constant.2559 = s32[] constant(3) + subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394) + constant.2560 = s32[] constant(-1) + add.231 = s32[] add(subtract.139, constant.2560) + constant.2561 = s32[] constant(0) + compare.747 = pred[] compare(add.231, constant.2561), direction=LT + constant.2562 = s32[] constant(2) + add.232 = s32[] add(subtract.139, constant.2562) + select.1348 = s32[] select(compare.747, add.232, add.231) + + // ar.1 is used by dynamic-update-slice.35 and dynamic-update-slice.36 + dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.35, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128} + mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99) + ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=1 + b.1 = bf16[1,8,128,32] broadcast(ar.1), dimensions={0,1,2} + constant = bf16[] constant(0) + reduce = bf16[1,8,128] reduce(b.1, constant), dimensions={3}, to_apply=add.1 + dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, reduce, select.1348, constant.2561, constant.2561) + c2 = bf16[] constant(2.0) + bc = bf16[1,8,128] broadcast(c2) + mul2 = bf16[1,8,128] multiply(ar.1, bc) + mul3 = bf16[1,8,128] multiply(mul2, ar.1) + mul4 = bf16[1,8,128] multiply(mul3, mul) + dynamic-update-slice.36 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.396, mul4, select.1348, constant.2561, constant.2561) + + // ar.1 is used by dynamic-update-slice.37 and dynamic-update-slice.38 + // dynamic-update-slice.37 actually uses both ar.1 and ar.2 + dynamic-slice.100 = bf16[1,8,128] dynamic-slice(get-tuple-element.36, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128} + mul.1 = bf16[1,8,128] multiply(dynamic-slice.100, dynamic-slice.99) + ar.2 = bf16[1,8,128] all-reduce(mul.1), replica_groups={}, to_apply=add, channel_id=1 + divide = bf16[1,8,128] divide(ar.1, ar.2) + dynamic-update-slice.37 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.397, divide, select.1348, constant.2561, constant.2561) + mul.2 = bf16[1,8,128] multiply(ar.2, ar.2) + abs = bf16[1,8,128] abs(mul.2) + dynamic-update-slice.38 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.398, abs, select.1348, constant.2561, constant.2561) + ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, dynamic-update-slice.36, dynamic-update-slice.37, dynamic-update-slice.38, get-tuple-element.35, get-tuple-element.36) +} + +ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[3,8,128] parameter(0) + p1 = bf16[3,8,128] parameter(1) + p2 = bf16[3,8,128] parameter(2) + p3 = bf16[3,8,128] parameter(3) + p4 = bf16[3,8,128] parameter(4) + p5 = bf16[3,8,128] parameter(5) + + tuple = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p1, p2, p3, p4, p5) + while = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body + ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1 +} +)"; + auto module = ParseAndReturnUnverifiedModule(hlo_string, config_).value(); + EXPECT_TRUE(RunOptimizer(module.get(), /*last_run=*/false, + /*level_to_operate_on=*/0, + /*pipeline_use_tree=*/true, + /*process_different_sized_ops=*/true, + CollectivePipeliner::kForwardSink) + .value()); + XLA_VLOG_LINES(1, module->ToString()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::GetTupleElement(op::Tuple( + op::GetTupleElement(op::While()), op::Reshape(op::Reduce()), + op::Reshape(op::Multiply()), op::Reshape(op::Divide()), + op::Reshape(op::Abs()), op::GetTupleElement(op::While()), + op::GetTupleElement(op::While())))); } } // namespace diff --git a/xla/tests/collective_pipeliner_execution_test.cc b/xla/tests/collective_pipeliner_execution_test.cc index d06981fafa7267..c9175b653da09e 100644 --- a/xla/tests/collective_pipeliner_execution_test.cc +++ b/xla/tests/collective_pipeliner_execution_test.cc @@ -1298,5 +1298,100 @@ ENTRY entry { ErrorSpec{0.1, 0.1})); } +TEST_F(CollectivePipelinerExecutionTest, MergeTwoCollectivesEachWithTwoDUS) { + constexpr absl::string_view hlo_string = R"( +HloModule module + +add.1 { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +while_cond { + param = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + constant.1 = s32[] constant(3) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT +} + +while_body { + param = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) parameter(0) + get-tuple-element.394 = s32[] get-tuple-element(param), index=0 + get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1 + get-tuple-element.396 = bf16[3,8,128] get-tuple-element(param), index=2 + get-tuple-element.397 = bf16[3,8,128] get-tuple-element(param), index=3 + get-tuple-element.398 = bf16[3,8,128] get-tuple-element(param), index=4 + get-tuple-element.35 = bf16[3,8,128] get-tuple-element(param), index=5 + get-tuple-element.36 = bf16[3,8,128] get-tuple-element(param), index=6 + constant.2557 = s32[] constant(1) + add.230 = s32[] add(get-tuple-element.394, constant.2557) + constant.2559 = s32[] constant(3) + subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394) + constant.2560 = s32[] constant(-1) + add.231 = s32[] add(subtract.139, constant.2560) + constant.2561 = s32[] constant(0) + compare.747 = pred[] compare(add.231, constant.2561), direction=LT + constant.2562 = s32[] constant(2) + add.232 = s32[] add(subtract.139, constant.2562) + select.1348 = s32[] select(compare.747, add.232, add.231) + + // ar.1 is used by dynamic-update-slice.35 and dynamic-update-slice.36 + dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.35, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128} + mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99) + ar.1 = bf16[1,8,128] negate(mul) + b.1 = bf16[1,8,128,32] broadcast(ar.1), dimensions={0,1,2} + constant = bf16[] constant(0) + reduce = bf16[1,8,128] reduce(b.1, constant), dimensions={3}, to_apply=add.1 + dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, reduce, select.1348, constant.2561, constant.2561) + c2 = bf16[] constant(2.0) + bc = bf16[1,8,128] broadcast(c2) + mul2 = bf16[1,8,128] multiply(ar.1, bc) + mul3 = bf16[1,8,128] multiply(mul2, ar.1) + mul4 = bf16[1,8,128] multiply(mul3, mul) + dynamic-update-slice.36 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.396, mul4, select.1348, constant.2561, constant.2561) + + // ar.1 is used by dynamic-update-slice.37 and dynamic-update-slice.38 + // dynamic-update-slice.37 actually uses both ar.1 and ar.2 + dynamic-slice.100 = bf16[1,8,128] dynamic-slice(get-tuple-element.36, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128} + mul.1 = bf16[1,8,128] multiply(dynamic-slice.100, dynamic-slice.99) + ar.2 = bf16[1,8,128] exponential(mul.1) + divide = bf16[1,8,128] divide(ar.1, ar.2) + dynamic-update-slice.37 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.397, divide, select.1348, constant.2561, constant.2561) + mul.2 = bf16[1,8,128] multiply(ar.2, ar.2) + abs = bf16[1,8,128] abs(mul.2) + dynamic-update-slice.38 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.398, abs, select.1348, constant.2561, constant.2561) + ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, dynamic-update-slice.36, dynamic-update-slice.37, dynamic-update-slice.38, get-tuple-element.35, get-tuple-element.36) +} + +ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[3,8,128] parameter(0) + p1 = bf16[3,8,128] parameter(1) + p2 = bf16[3,8,128] parameter(2) + p3 = bf16[3,8,128] parameter(3) + p4 = bf16[3,8,128] parameter(4) + p5 = bf16[3,8,128] parameter(5) + + tuple = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p1, p2, p3, p4, p5) + ROOT while = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body +} +)"; + auto module = ParseAndReturnUnverifiedModule(hlo_string).value(); + auto module2 = ParseAndReturnUnverifiedModule(hlo_string).value(); + + EXPECT_TRUE( + RunOptimizer(module.get(), /*last_run=*/true, 0, + /*should_process=*/ + HloPredicateIsOp, + CollectivePipeliner::PipeliningDirection::kForwardSink, + /*pipeline_use_tree=*/true) + .value()); + XLA_VLOG_LINES(1, module->ToString()); + XLA_VLOG_LINES(1, module2->ToString()); + EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(module2), + ErrorSpec{0.1, 0.1})); +} + } // namespace } // namespace xla From 30de9c37522f9030e8bb8c1a1afdb6c94691200c Mon Sep 17 00:00:00 2001 From: Jaroslav Sevcik Date: Tue, 23 Jul 2024 00:34:40 -0700 Subject: [PATCH 079/376] PR #14900: [PJRT:GPU] Propagate arg and result info from MLIR to XLA Compile method Imported from GitHub PR https://github.com/openxla/xla/pull/14900 In MLIR flavor of the PjRtStreamExecutorClient::Compile method, we now transfer the argument layouts and result layout from MLIR code to compile options. If compile options already specified argument layouts, we ignore layouts from MLIR. We also make sure that the argument/result layouts are preserved when SPMD needs to canonicalize layouts after resharding parameters and/or layouts. Copybara import of the project: -- bbe0015acb39a22de293049a35f6ea847c63c720 by Jaroslav Sevcik : Revert "Reverts 5b619ac97f0b15cdadf1eb67ac2d5234a17dbfea" This reverts commit e21e3e0165c83ee659f4d681ac606b9fc6ad4172. -- b33b94463837ff35b3f30e9b5727ebe57e324a90 by Jaroslav Sevcik : Fix and test -- 9c4f6e7c5de7101a9658f0807ba48d3baa8e32a6 by Jaroslav Sevcik : Make sure the callback only mutates local data -- 179adcb1e1930f0a4912b907f3c6d8eef553efed by Jaroslav Sevcik : Bake argument layouts into options, use them in canonicalization callback -- 1042f7ee78f5ee8d45ccfd57f67d3657b5a9cf07 by Jaroslav Sevcik : Change shardings to only use 2 GPUs Merging this change closes #14900 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/14900 from jaro-sevcik:preserve-argument-layouts-on-canonicalization 1042f7ee78f5ee8d45ccfd57f67d3657b5a9cf07 PiperOrigin-RevId: 655053793 --- xla/pjrt/gpu/BUILD | 1 + xla/pjrt/gpu/se_gpu_pjrt_client_test.cc | 192 ++++++++++++++++++++++++ xla/pjrt/pjrt_stream_executor_client.cc | 39 ++++- 3 files changed, 229 insertions(+), 3 deletions(-) diff --git a/xla/pjrt/gpu/BUILD b/xla/pjrt/gpu/BUILD index 2bffeac11e34d6..b7e06b628fc49c 100644 --- a/xla/pjrt/gpu/BUILD +++ b/xla/pjrt/gpu/BUILD @@ -157,6 +157,7 @@ xla_cc_test( "//xla/ffi", "//xla/ffi:ffi_api", "//xla/pjrt:host_memory_spaces", + "//xla/pjrt:mlir_to_hlo", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_executable", "//xla/pjrt:pjrt_future", diff --git a/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc b/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc index c839312edd0909..a664a11352a023 100644 --- a/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc +++ b/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc @@ -43,6 +43,7 @@ limitations under the License. #include "xla/pjrt/distributed/in_memory_key_value_store.h" #include "xla/pjrt/gpu/gpu_topology.h" #include "xla/pjrt/host_memory_spaces.h" +#include "xla/pjrt/mlir_to_hlo.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_future.h" @@ -1149,5 +1150,196 @@ TEST(StreamExecutorGpuClientTest, EXPECT_EQ(memory_kinds[0][1], "pinned_host"); } +TEST(StreamExecutorGpuClientTest, MlirParameterHostMemorySpaceIsSetInHlo) { + constexpr char kMlirH2D[] = + R"( + func.func public @main(%arg0: tensor<8x2xi32> { + mhlo.layout_mode = "{1,0}", + mhlo.memory_kind = "pinned_host", + mhlo.sharding = "{devices=[2,2]<=[4]}" + }) -> (tensor<8x2xi32> { + jax.result_info = "", + mhlo.layout_mode = "default", + mhlo.memory_kind = "device", + mhlo.sharding = "{devices=[2,2]<=[4]}"}) { + %0 = stablehlo.custom_call @annotate_device_placement(%arg0) { + has_side_effect = true, + mhlo.frontend_attributes = {_xla_buffer_placement = "device"} + } : (tensor<8x2xi32>) -> tensor<8x2xi32> + return %0 : tensor<8x2xi32> + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); + + mlir::MLIRContext context; + TF_ASSERT_OK_AND_ASSIGN(auto module, + xla::ParseMlirModuleString(kMlirH2D, context)); + + TF_ASSERT_OK_AND_ASSIGN(auto executable, client->Compile(*module, {})); + TF_ASSERT_OK_AND_ASSIGN(auto modules, executable->GetHloModules()); + + auto first_param_layout = + modules[0]->entry_computation_layout().parameter_layout(0).layout(); + EXPECT_EQ(first_param_layout.memory_space(), Layout::kHostMemorySpace); + auto result_layout = + modules[0]->entry_computation_layout().result_layout().layout(); + EXPECT_EQ(result_layout.memory_space(), Layout::kDefaultMemorySpace); +} + +TEST(StreamExecutorGpuClientTest, MlirResultHostMemorySpaceIsSetInHlo) { + constexpr char kMlirD2H[] = + R"( + func.func public @main(%arg0: tensor<8x2xi32> { + mhlo.layout_mode = "{1,0}", + mhlo.memory_kind = "device", + mhlo.sharding = "{devices=[2,2]<=[4]}" + }) -> (tensor<8x2xi32> { + jax.result_info = "", + mhlo.layout_mode = "default", + mhlo.memory_kind = "pinned_host", + mhlo.sharding = "{devices=[2,2]<=[4]}"}) { + %0 = stablehlo.custom_call @annotate_device_placement(%arg0) { + has_side_effect = true, + mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"} + } : (tensor<8x2xi32>) -> tensor<8x2xi32> + return %0 : tensor<8x2xi32> + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); + + mlir::MLIRContext context; + TF_ASSERT_OK_AND_ASSIGN(auto module, + xla::ParseMlirModuleString(kMlirD2H, context)); + + TF_ASSERT_OK_AND_ASSIGN(auto executable, client->Compile(*module, {})); + TF_ASSERT_OK_AND_ASSIGN(auto modules, executable->GetHloModules()); + + auto first_param_layout = + modules[0]->entry_computation_layout().parameter_layout(0).layout(); + EXPECT_EQ(first_param_layout.memory_space(), Layout::kDefaultMemorySpace); + auto result_layout = + modules[0]->entry_computation_layout().result_layout().layout(); + EXPECT_EQ(result_layout.memory_space(), Layout::kHostMemorySpace); +} + +TEST(StreamExecutorGpuClientTest, MlirParameterLayoutIsSetInHlo) { + constexpr char kMlirWithParameterLayout[] = + R"( + func.func public @main(%arg0: tensor<2x2x2xi32> { + mhlo.layout_mode = "{0, 2, 1}" + }) -> (tensor<2x2x2xi32> { + jax.result_info = "", + mhlo.layout_mode = "default"}) { + return %arg0 : tensor<2x2x2xi32> + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); + + mlir::MLIRContext context; + TF_ASSERT_OK_AND_ASSIGN(auto module, xla::ParseMlirModuleString( + kMlirWithParameterLayout, context)); + + TF_ASSERT_OK_AND_ASSIGN(auto executable, client->Compile(*module, {})); + TF_ASSERT_OK_AND_ASSIGN(auto modules, executable->GetHloModules()); + + auto first_param_layout = + modules[0]->entry_computation_layout().parameter_layout(0).layout(); + EXPECT_EQ(first_param_layout, Layout({0, 2, 1})); +} + +TEST(StreamExecutorGpuClientTest, MlirParameterLayoutFromOptionsIsSetInHlo) { + constexpr char kMlirCopy[] = + R"( + func.func public @main(%arg0: tensor<2x2x2xi32> { + mhlo.layout_mode = "default" + }) -> (tensor<2x2x2xi32> { + jax.result_info = "", + mhlo.layout_mode = "default"}) { + return %arg0 : tensor<2x2x2xi32> + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); + + mlir::MLIRContext context; + TF_ASSERT_OK_AND_ASSIGN(auto module, + xla::ParseMlirModuleString(kMlirCopy, context)); + + xla::CompileOptions options; + options.argument_layouts = { + {ShapeUtil::MakeShapeWithDenseLayout(S32, {2, 2, 2}, {0, 2, 1})}}; + TF_ASSERT_OK_AND_ASSIGN(auto executable, client->Compile(*module, options)); + TF_ASSERT_OK_AND_ASSIGN(auto modules, executable->GetHloModules()); + + auto first_param_layout = + modules[0]->entry_computation_layout().parameter_layout(0).layout(); + EXPECT_EQ(first_param_layout, Layout({0, 2, 1})); +} + +TEST(StreamExecutorGpuClientTest, + MlirResultHostMemorySpaceIsSetInHloWithShardingPropagation) { + constexpr absl::string_view mlir_mul_explicit_sharding_layout_and_memory = + R"mlir( + module @jit_f attributes { + mhlo.num_partitions = 2 : i32, + mhlo.num_replicas = 1 : i32 + } { + func.func public @main(%arg0: tensor<8x2xi32> { + mhlo.layout_mode = "{1,0}", + mhlo.memory_kind = "device", + mhlo.sharding = "{devices=[1,2]<=[2]}" + }) -> (tensor<8x2xi32> { + jax.result_info = "", + mhlo.layout_mode = "{0,1}", + mhlo.memory_kind = "pinned_host" + }) { + %c = stablehlo.constant dense<2> : tensor + %0 = stablehlo.broadcast_in_dim %c, dims = [] + : (tensor) -> tensor<8x2xi32> + %1 = stablehlo.multiply %arg0, %0 : tensor<8x2xi32> + %2 = stablehlo.custom_call @Sharding(%1) { + mhlo.sharding = "{devices=[1,2]<=[2]}" + } : (tensor<8x2xi32>) -> tensor<8x2xi32> + %3 = stablehlo.custom_call @annotate_device_placement(%2) { + has_side_effect = true, + mhlo.frontend_attributes = { + _xla_buffer_placement = "pinned_host" + } + } : (tensor<8x2xi32>) -> tensor<8x2xi32> + return %3 : tensor<8x2xi32> + } + })mlir"; + + mlir::MLIRContext context; + TF_ASSERT_OK_AND_ASSIGN( + auto module, xla::ParseMlirModuleString( + mlir_mul_explicit_sharding_layout_and_memory, context)); + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); + + xla::CompileOptions options; + options.executable_build_options.set_num_partitions(2) + .set_use_spmd_partitioning(true) + .set_allow_spmd_sharding_propagation_to_output({true}); + + TF_ASSERT_OK_AND_ASSIGN(auto executable, client->Compile(*module, options)); + TF_ASSERT_OK_AND_ASSIGN(auto modules, executable->GetHloModules()); + + auto first_param_layout = + modules[0]->entry_computation_layout().parameter_layout(0).layout(); + EXPECT_EQ(first_param_layout.memory_space(), Layout::kDefaultMemorySpace); + auto result_layout = + modules[0]->entry_computation_layout().result_layout().layout(); + EXPECT_EQ(result_layout, + Layout({0, 1}).set_memory_space(Layout::kHostMemorySpace)); +} + } // namespace } // namespace xla diff --git a/xla/pjrt/pjrt_stream_executor_client.cc b/xla/pjrt/pjrt_stream_executor_client.cc index 464ba2ef10359c..9d3c820550a3b7 100644 --- a/xla/pjrt/pjrt_stream_executor_client.cc +++ b/xla/pjrt/pjrt_stream_executor_client.cc @@ -3400,11 +3400,13 @@ PjRtStreamExecutorClient::GetExecutableExtras(CompileOptions* options) { build_options.set_device_allocator(allocator()); } - auto layout_callback = [local_client = client()](const HloModule& module) + auto layout_callback = [local_client = client(), + options](const HloModule& module) -> absl::StatusOr, Shape>> { - ExecutableBuildOptions build_options; + ExecutableBuildOptions build_options = options->executable_build_options; std::vector argument_layout_pointers; - std::optional> argument_layouts; + std::optional> argument_layouts = + options->argument_layouts; Shape result_layout; TF_RETURN_IF_ERROR(DetermineArgumentLayoutsFromCompileOptions( XlaComputation(module.ToProto()), @@ -3526,6 +3528,37 @@ PjRtStreamExecutorClient::Compile(mlir::ModuleOp module, exec_build_options.has_debug_options() ? exec_build_options.debug_options().xla_use_shardy() : false)); + + // If the compile options specify argument layout, then let's + // fall back to using the options to determine layouts. + if (options.argument_layouts) { + return Compile(xla_computation, options); + } + + TF_ASSIGN_OR_RETURN(std::vector arg_layout_modes, + GetArgLayoutModes(module)); + TF_ASSIGN_OR_RETURN(std::vector out_layout_modes, + GetOutputLayoutModes(module)); + TF_ASSIGN_OR_RETURN(std::vector arg_memory_spaces, + GetArgMemoryKinds(module)); + TF_ASSIGN_OR_RETURN(std::vector out_memory_spaces, + GetOutputMemoryKinds(module)); + + // This call will update result_layout in options.executable_build_options + // (in addition to returning the argument layouts). + TF_ASSIGN_OR_RETURN(auto arg_layouts_and_pointers, + LayoutModesToXla( + xla_computation, arg_layout_modes, out_layout_modes, + arg_memory_spaces, out_memory_spaces, + [this](Shape shape) -> absl::StatusOr { + return this->client() + ->backend() + .transfer_manager() + ->ChooseCompactLayoutForShape(shape); + }, + options.executable_build_options)); + + options.argument_layouts = arg_layouts_and_pointers.first; return Compile(xla_computation, options); } From 77ed71827f607bc12b940c6fb8f45fe8162ef9cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Bana=C5=9B?= Date: Tue, 23 Jul 2024 02:10:13 -0700 Subject: [PATCH 080/376] [XLA:CPU] Support `add-dependency` in thunks runtime. PiperOrigin-RevId: 655081835 --- xla/service/cpu/thunk_emitter.cc | 1 + xla/tests/BUILD | 1 + xla/tests/token_hlo_test.cc | 36 +++++++++++++++++++++++++++++++- 3 files changed, 37 insertions(+), 1 deletion(-) diff --git a/xla/service/cpu/thunk_emitter.cc b/xla/service/cpu/thunk_emitter.cc index 028b30c4e1dd33..4ae71bff91f88a 100644 --- a/xla/service/cpu/thunk_emitter.cc +++ b/xla/service/cpu/thunk_emitter.cc @@ -152,6 +152,7 @@ absl::StatusOr ThunkEmitter::EmitHloInstruction( // No-op operations that are used to provide more metadata about the HLO // dataflow graph. case HloOpcode::kAfterAll: // Defines an execution order. + case HloOpcode::kAddDependency: // Defines an execution order. case HloOpcode::kDomain: // Defines an HLO domain. case HloOpcode::kOptimizationBarrier: // Prevents moving ops past barrier. return ThunkSequence::Empty(); diff --git a/xla/tests/BUILD b/xla/tests/BUILD index 230eb6c8772cd0..701ed1fca2abb0 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -1827,6 +1827,7 @@ xla_test( xla_test( name = "token_hlo_test", srcs = ["token_hlo_test.cc"], + tags = ["test_xla_cpu_thunks"], deps = [ ":hlo_test_base", ":literal_test_util", diff --git a/xla/tests/token_hlo_test.cc b/xla/tests/token_hlo_test.cc index 8991e97916a753..cb6d6f7c3b8bb4 100644 --- a/xla/tests/token_hlo_test.cc +++ b/xla/tests/token_hlo_test.cc @@ -180,7 +180,7 @@ ENTRY %TokenInConditional (param.3: pred[]) -> s32[] { } } -XLA_TEST_F(TokenHloTest, AddDependency) { +XLA_TEST_F(TokenHloTest, AddDependencyOfParameter) { if (IsMlirLoweringEnabled()) { // This test generates invalid HLO. The after-all op only takes tokens. GTEST_SKIP() << "Invalid HLO unsupported by MLIR"; @@ -212,6 +212,40 @@ ENTRY %AddDependency (p0: f32[], p1: f32[]) -> f32[] { EXPECT_EQ(expected, ExecuteNoHloPasses(std::move(module), {&p0, &p1})); } +XLA_TEST_F(TokenHloTest, AddDependencyOfOperation) { + if (IsMlirLoweringEnabled()) { + // This test generates invalid HLO. The after-all op only takes tokens. + GTEST_SKIP() << "Invalid HLO unsupported by MLIR"; + } + std::string module_string = R"( +HloModule AddDependency, is_scheduled=true + +// Computes (p0 + 42) * ( -(p1 - 45)) +// where there is a dependency from the add to the negation using a token +// with after-all and add-dependency instructions. +ENTRY %AddDependency (p0: f32[], p1: f32[]) -> f32[] { + %p0 = f32[] parameter(0) + %p1 = f32[] parameter(1) + + %forty_two = f32[] constant(42.0) + %add = f32[] add(f32[] %p0, f32[] %forty_two) + %forty_five = f32[] constant(45.0) + %sub = f32[] subtract(f32[] %p1, f32[] %forty_five) + %token0 = token[] after-all(f32[] %add) + %sub_after_token = f32[] add-dependency(f32[] %sub, token[] %token0) + %neg = f32[] negate(f32[] %sub_after_token) + ROOT %product = f32[] multiply(f32[] %add, f32[] %neg) +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_string, GetModuleConfigForTest())); + auto p0 = LiteralUtil::CreateR0(10.0); + auto p1 = LiteralUtil::CreateR0(3.0); + auto expected = LiteralUtil::CreateR0(2184.0); + EXPECT_EQ(expected, ExecuteNoHloPasses(std::move(module), {&p0, &p1})); +} + XLA_TEST_F(TokenHloTest, AddDependencyOfConstant) { if (IsMlirLoweringEnabled()) { // This test generates invalid HLO. The after-all op only takes tokens. From bab3a715d1232505ae9b1e5d63468ef0c4c63ac6 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 23 Jul 2024 03:35:40 -0700 Subject: [PATCH 081/376] [xla:cpu] Disable HLO execution profile test on CPU backend Similar to GPU backend we can't generate profiling data that is anywhere close to reality, so we prefer to disable the test for now. PiperOrigin-RevId: 655102005 --- xla/tests/xla_hlo_profile_test.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/xla/tests/xla_hlo_profile_test.cc b/xla/tests/xla_hlo_profile_test.cc index 8d19996c1a9cbc..2436635dea5ef0 100644 --- a/xla/tests/xla_hlo_profile_test.cc +++ b/xla/tests/xla_hlo_profile_test.cc @@ -188,7 +188,8 @@ void ExecuteAndFetchProfile(std::string* profile_output, LocalClient* client, XLA_VLOG_LINES(4, *profile_output); } -XLA_TEST_F(HloProfileTest, DISABLED_ON_GPU(ProfileSingleComputation)) { +XLA_TEST_F(HloProfileTest, + DISABLED_ON_CPU(DISABLED_ON_GPU(ProfileSingleComputation))) { const int64_t m = 32, k = 32, n = 32; Shape lhs_shape = ShapeUtil::MakeShape(F32, {m, k}); Shape rhs_shape = ShapeUtil::MakeShape(F32, {m, k}); @@ -267,7 +268,8 @@ XLA_TEST_F(HloProfileTest, DISABLED_ON_GPU(ProfileSingleComputation)) { EXPECT_TRUE(HasTrops(tanh_profile)); } -XLA_TEST_F(HloProfileTest, DISABLED_ON_GPU(ProfileWhileComputation)) { +XLA_TEST_F(HloProfileTest, + DISABLED_ON_CPU(DISABLED_ON_GPU(ProfileWhileComputation))) { const int64_t size = 32; Shape matrix_shape = ShapeUtil::MakeShape(F32, {size, size}); Shape while_result_shape = From 4089aadec1b6186c3d537c5c3d4eba4ba89dd8d5 Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Tue, 23 Jul 2024 03:45:15 -0700 Subject: [PATCH 082/376] [XLA:GPU] Only compute tile offset indexing maps for instructions when it is needed. Computing tile offset indexing maps is very expensive and not always necessary. When we're at Cost Model/Tiling stage, we only use the indexing map to deduplicate instruction. We can predict if we'll need an indexing map for a particular instruction by computation and comparing a parts of the hash. PiperOrigin-RevId: 655103934 --- .../fusions/triton/triton_fusion_emitter.cc | 34 ++++++--- .../fusions/triton/triton_fusion_emitter.h | 2 +- .../triton_fusion_emitter_mem_utils_test.cc | 5 +- .../triton/triton_fusion_emitter_stub.cc | 2 +- xla/service/gpu/model/BUILD | 2 + .../gpu/model/symbolic_tile_analysis.cc | 75 ++++++++++++++++--- .../gpu/model/symbolic_tile_analysis.h | 8 +- .../gpu/model/symbolic_tile_analysis_test.cc | 45 +++++++---- .../gpu/model/tiled_hlo_instruction.cc | 12 ++- xla/service/gpu/model/tiled_hlo_instruction.h | 43 +++++++---- .../gpu/model/tiled_hlo_instruction_test.cc | 11 ++- 11 files changed, 174 insertions(+), 65 deletions(-) diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc index 430fbca28a4389..e21483462afeea 100644 --- a/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc +++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc @@ -906,9 +906,10 @@ absl::StatusOr EmitTiledHloInstruction( const HloInstruction* hlo = tiled_hlo.hlo(); if (fusion->IsUserOf(tiled_hlo.hlo())) { - auto make_tensor = ir_emitter_triton_internal::CreateMakeTensorPtrOp( - b, tile_multi_index, tiled_hlo, - fn.getArgument(fusion->operand_index(hlo))); + TF_ASSIGN_OR_RETURN(auto make_tensor, + ir_emitter_triton_internal::CreateMakeTensorPtrOp( + b, tile_multi_index, tiled_hlo, + fn.getArgument(fusion->operand_index(hlo)))); return EmitParameterLoad(b, make_tensor.op, make_tensor.boundary_checks); } @@ -2498,8 +2499,9 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, // Computes the base pointer offset for the given tile multi-index and hlo shape // taking into account the physical layout of the hlo buffer. -Value ComputeBasePtrOffset(ImplicitLocOpBuilder b, ValueRange tile_multi_index, - const TiledHloInstruction& tiled_hlo) { +absl::StatusOr ComputeBasePtrOffset( + ImplicitLocOpBuilder b, ValueRange tile_multi_index, + const TiledHloInstruction& tiled_hlo) { const Shape& shape = tiled_hlo.hlo()->shape(); Shape linear_shape = ShapeUtil::MakeShape(shape.element_type(), {ShapeUtil::ElementsIn(shape)}); @@ -2507,8 +2509,12 @@ Value ComputeBasePtrOffset(ImplicitLocOpBuilder b, ValueRange tile_multi_index, // Bitcast map gives an indexing map from linear index to the parameter shape // index respecting physical layout of the memory. auto bitcast_map = GetBitcastMap(shape, linear_shape, b.getContext()); + + TF_ASSIGN_OR_RETURN(IndexingMap tile_offsets_indexing, + tiled_hlo.tile_offsets_indexing()); + auto compose_indexing_maps = - ComposeIndexingMaps(tiled_hlo.tile_offsets_indexing(), bitcast_map); + ComposeIndexingMaps(tile_offsets_indexing, bitcast_map); compose_indexing_maps.Simplify(); return b.create( @@ -2541,7 +2547,7 @@ SmallVector ComputeDelinearizedTileIndex( /*symbols=*/{}, b); } -MakeTensorPtrOpAndBoundaryChecks CreateMakeTensorPtrOp( +absl::StatusOr CreateMakeTensorPtrOp( ImplicitLocOpBuilder& b, ValueRange tile_multi_index, const TiledHloInstruction& tiled_hlo, Value argument_block) { llvm::SmallVector sizes; @@ -2601,7 +2607,8 @@ MakeTensorPtrOpAndBoundaryChecks CreateMakeTensorPtrOp( // Manually compute pointer offset to avoid materialized fully parallel // dimensions in the tile. Current codegen tried to avoid size-1 dims. - Value ptr_offset = ComputeBasePtrOffset(b, tile_multi_index, tiled_hlo); + TF_ASSIGN_OR_RETURN(Value ptr_offset, + ComputeBasePtrOffset(b, tile_multi_index, tiled_hlo)); auto tile_ptr = AddPtr(b, argument_block, ptr_offset); return MakeTensorPtrOpAndBoundaryChecks{b.create( @@ -2640,7 +2647,9 @@ absl::Status EmitGeneric(mlir::OpBuilder builder, TF_ASSIGN_OR_RETURN(TiledHloComputation tiled_hlo_computation, symbolic_tile_analysis.ComputeTiledHloInstructions( - block_level_parameters.output_tile_sizes)); + block_level_parameters.output_tile_sizes, + /*constraints_are_known_satisfied=*/false, + /*compute_all_tile_offset_indexing_maps=*/true)); SmallVector tile_multi_index = ir_emitter_triton_internal::ComputeDelinearizedTileIndex( @@ -2652,9 +2661,10 @@ absl::Status EmitGeneric(mlir::OpBuilder builder, tiled_hlo_computation, fn, tile_multi_index)); const auto& tiled_hlo = *tiled_hlo_computation.GetRoot(); - auto make_tensor = ir_emitter_triton_internal::CreateMakeTensorPtrOp( - b, tile_multi_index, tiled_hlo, - fn.getArgument(computation->num_parameters())); + TF_ASSIGN_OR_RETURN(auto make_tensor, + ir_emitter_triton_internal::CreateMakeTensorPtrOp( + b, tile_multi_index, tiled_hlo, + fn.getArgument(computation->num_parameters()))); b.create(make_tensor.op, result, make_tensor.boundary_checks, mt::CacheModifier::NONE, mt::EvictionPolicy::NORMAL); diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter.h b/xla/service/gpu/fusions/triton/triton_fusion_emitter.h index eb9644808fc6b6..fe133d88f44e45 100644 --- a/xla/service/gpu/fusions/triton/triton_fusion_emitter.h +++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter.h @@ -166,7 +166,7 @@ struct MakeTensorPtrOpAndBoundaryChecks { llvm::SmallVector boundary_checks; }; -MakeTensorPtrOpAndBoundaryChecks CreateMakeTensorPtrOp( +absl::StatusOr CreateMakeTensorPtrOp( mlir::ImplicitLocOpBuilder& b, mlir::ValueRange tile_multi_index, const TiledHloInstruction& tiled_hlo, mlir::Value argument_block); } // namespace ir_emitter_triton_internal diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter_mem_utils_test.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter_mem_utils_test.cc index 3b089114c34da1..44611bda590dfa 100644 --- a/xla/service/gpu/fusions/triton/triton_fusion_emitter_mem_utils_test.cc +++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter_mem_utils_test.cc @@ -121,7 +121,8 @@ TritonMakeTensorPtrTest::CreateAndTileParameterHloInstruction( auto tiled_hlo_computation_or = symbolic_tile_analysis.ComputeTiledHloInstructions( - tile_sizes, /*constraints_are_known_satisfied=*/true); + tile_sizes, /*constraints_are_known_satisfied=*/true, + /*compute_all_tile_offset_indexing_maps=*/true); TF_EXPECT_OK(tiled_hlo_computation_or.status()); return std::make_pair(std::move(verified_hlo_module), *std::move(tiled_hlo_computation_or)); @@ -175,7 +176,7 @@ TritonMakeTensorPtrTest::CreateTestTensorPtr( return std::make_pair( std::move(triton_module), - ir_emitter_triton_internal::CreateMakeTensorPtrOp( + *ir_emitter_triton_internal::CreateMakeTensorPtrOp( b, tile_multi_index, *tiled_hlo, fn.getArgument(0))); } diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter_stub.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter_stub.cc index 89ceeb038ae025..33c1e0666dd90d 100644 --- a/xla/service/gpu/fusions/triton/triton_fusion_emitter_stub.cc +++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter_stub.cc @@ -121,7 +121,7 @@ std::string GetLibdevicePath(const HloModuleConfig& hlo_config, namespace ir_emitter_triton_internal { -MakeTensorPtrOpAndBoundaryChecks CreateMakeTensorPtrOp( +absl::StatusOr CreateMakeTensorPtrOp( mlir::ImplicitLocOpBuilder& b, mlir::ValueRange tile_multi_index, const TiledHloInstruction& tiled_hlo, mlir::Value argument_block) { return MakeTensorPtrOpAndBoundaryChecks(); diff --git a/xla/service/gpu/model/BUILD b/xla/service/gpu/model/BUILD index c9346ce36abf88..0a884ae230160c 100644 --- a/xla/service/gpu/model/BUILD +++ b/xla/service/gpu/model/BUILD @@ -633,6 +633,7 @@ cc_library( ":indexing_analysis", "//xla:util", "//xla/hlo/ir:hlo", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -710,6 +711,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", "@com_google_absl//absl/log:check", "@com_google_absl//absl/numeric:bits", "@com_google_absl//absl/status", diff --git a/xla/service/gpu/model/symbolic_tile_analysis.cc b/xla/service/gpu/model/symbolic_tile_analysis.cc index ccf2fd642fad9f..ea69eadb079bac 100644 --- a/xla/service/gpu/model/symbolic_tile_analysis.cc +++ b/xla/service/gpu/model/symbolic_tile_analysis.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" #include "absl/log/check.h" #include "absl/numeric/bits.h" #include "absl/status/status.h" @@ -40,6 +41,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -198,7 +200,9 @@ class OrderedUniquePtrValueHashSet { }; struct PtrEqual { - bool operator()(const T* lhs, const T* rhs) const { return *lhs == *rhs; } + bool operator()(const T* lhs, const T* rhs) const { + return lhs == rhs || *lhs == *rhs; + } }; // Stores non-owning pointers to the elements in the set. Elements are @@ -428,7 +432,8 @@ absl::StatusOr SymbolicTileAnalysis::ParametersSatisfyConstraints( absl::StatusOr SymbolicTileAnalysis::ComputeTiledHloInstructions( absl::Span tile_parameters, - bool constraints_are_known_satisfied) const { + bool constraints_are_known_satisfied, + bool compute_all_tile_offset_indexing_maps) const { if (!constraints_are_known_satisfied) { TF_ASSIGN_OR_RETURN(bool constraints_are_satisfied, ParametersSatisfyConstraints(tile_parameters)); @@ -439,6 +444,48 @@ SymbolicTileAnalysis::ComputeTiledHloInstructions( } } + // Offset indexing is needed to emit loads/stores and to deduplicate + // instructions. In some cases, for example in Cost Model, we need to only + // deduplicate instructions. + // + // Computing tile offset indexing maps is very expensive. This is a + // performance optimization to avoid computing tile offset indexing maps for + // instructions that are not needed. + // + // Tile offset indexing is only needed when one HLO instruction has no + // operands and multiple tiles have exactly same sizes and strides. We skip + // strides in the heuristic below, because they are rarely different. + // + // Using `compute_all_tile_offset_indexing_maps` will force to compute tile + // offset indexing maps for all instructions. + llvm::SmallPtrSet parameters_with_offset_indexing; + absl::flat_hash_map> + tile_sizes_map; + if (!compute_all_tile_offset_indexing_maps) { + absl::flat_hash_set hashes; + for (const std::unique_ptr& + symbolic_tiled_hlo : symbolic_tiled_hlo_instructions_) { + if (!symbolic_tiled_hlo->operands().empty()) { + continue; + } + + llvm::SmallVector tile_sizes = + symbolic_tiled_hlo->TileSizes(tile_parameters); + size_t hash_value = absl::HashOf(symbolic_tiled_hlo->hlo(), + absl::Span(tile_sizes)); + tile_sizes_map.emplace(symbolic_tiled_hlo.get(), std::move(tile_sizes)); + + auto [it, inserted] = hashes.insert(hash_value); + // Two SymbolicTiledHloInstructions have identical hash when looking only + // at HLO instruction pointer and tile sizes. We need to compute tile + // offset indexing maps for all tiles of this HLO instruction. + if (!inserted) { + parameters_with_offset_indexing.insert(symbolic_tiled_hlo->hlo()); + } + } + } + OutputTilingInfo output_tiling_info = ComputeOutputTilingInfo( GetRoot()->hlo()->shape().dimensions(), tile_parameters, context_); @@ -452,16 +499,26 @@ SymbolicTileAnalysis::ComputeTiledHloInstructions( for (const std::unique_ptr& symbolic_tiled_hlo : symbolic_tiled_hlo_instructions_) { - llvm::SmallVector tile_sizes = - symbolic_tiled_hlo->TileSizes(tile_parameters); + llvm::SmallVector tile_sizes; + auto it = tile_sizes_map.find(symbolic_tiled_hlo.get()); + if (it != tile_sizes_map.end()) { + tile_sizes = it->second; + } else { + tile_sizes = symbolic_tiled_hlo->TileSizes(tile_parameters); + } + llvm::SmallVector tile_strides = symbolic_tiled_hlo->TileStrides(tile_parameters); - TF_ASSIGN_OR_RETURN( - IndexingMap tile_offset_indexing, - ComputeTileOffsetIndexing( - *symbolic_tiled_hlo, output_tiling_info.output_tile_offset_indexing, - context_)); + std::optional tile_offset_indexing; + if (compute_all_tile_offset_indexing_maps || + parameters_with_offset_indexing.contains(symbolic_tiled_hlo->hlo())) { + TF_ASSIGN_OR_RETURN( + tile_offset_indexing, + ComputeTileOffsetIndexing( + *symbolic_tiled_hlo, + output_tiling_info.output_tile_offset_indexing, context_)); + } llvm::SmallVector operands; for (const SymbolicTiledHloInstruction* operand : diff --git a/xla/service/gpu/model/symbolic_tile_analysis.h b/xla/service/gpu/model/symbolic_tile_analysis.h index 8cf04f0a6fafc0..df56d2325dd641 100644 --- a/xla/service/gpu/model/symbolic_tile_analysis.h +++ b/xla/service/gpu/model/symbolic_tile_analysis.h @@ -69,9 +69,15 @@ class SymbolicTileAnalysis { // By default, `ComputetiledHloInstructions` performs a check that the // constraints are satisfied by the chosen tiled parameters. Setting // `constraints_are_known_satisfied` to true bypasses this check. + // + // If `compute_all_tile_offset_indexing_maps == true`, all + // TiledHloInstructions will have tile offset indexing maps set. Otherwise, + // the indexing maps will be set only for instructions that have equal hash to + // deduplicate them. absl::StatusOr ComputeTiledHloInstructions( absl::Span tile_parameters, - bool constraints_are_known_satisfied = false) const; + bool constraints_are_known_satisfied = false, + bool compute_all_tile_offset_indexing_maps = false) const; // Returns the tiled root instruction. const SymbolicTiledHloInstruction* GetRoot() const { diff --git a/xla/service/gpu/model/symbolic_tile_analysis_test.cc b/xla/service/gpu/model/symbolic_tile_analysis_test.cc index a0e964ff77f4f4..0a359415a153d0 100644 --- a/xla/service/gpu/model/symbolic_tile_analysis_test.cc +++ b/xla/service/gpu/model/symbolic_tile_analysis_test.cc @@ -65,8 +65,9 @@ MATCHER_P3(MatchTiledHloInstructionImpl, tile_sizes, tile_strides, result_listener) && ExplainMatchResult(ElementsAreArray(tile_strides), arg.tile_strides(), result_listener) && - ExplainMatchResult(MatchIndexingMap(tile_offsets_indexing), - arg.tile_offsets_indexing(), result_listener); + ExplainMatchResult( + IsOkAndHolds(MatchIndexingMap(tile_offsets_indexing)), + arg.tile_offsets_indexing(), result_listener); } Matcher MatchTiledHloInstruction( @@ -122,13 +123,17 @@ ENTRY main { std::optional analysis = TryAnalyzeModule(module.get()); ASSERT_TRUE(analysis.has_value()); - TF_ASSERT_OK_AND_ASSIGN( - TiledHloComputation tiled_hlo_computation, - analysis->ComputeTiledHloInstructions(/*tile_parameters=*/{1, 10})); + TF_ASSERT_OK_AND_ASSIGN(TiledHloComputation tiled_hlo_computation, + analysis->ComputeTiledHloInstructions( + /*tile_parameters=*/{1, 10}, + /*constraints_are_known_satisfied=*/false, + /*compute_all_tile_offset_indexing_maps=*/true)); const TiledHloInstruction* root = tiled_hlo_computation.GetRoot(); - EXPECT_THAT(root->tile_offsets_indexing(), MatchIndexingMap(R"( + EXPECT_THAT(*root, MatchTiledHloInstruction(/*tile_sizes=*/{1, 10}, + /*tile_strides=*/{1, 1}, + /*tile_offsets_indexing=*/R"( (d0, d1) -> (d0, d1 * 10) domain: d0 in [0, 2) @@ -266,9 +271,11 @@ ENTRY main { std::optional analysis = TryAnalyzeModule(module.get()); ASSERT_TRUE(analysis.has_value()); - TF_ASSERT_OK_AND_ASSIGN( - TiledHloComputation tiled_hlo_computation, - analysis->ComputeTiledHloInstructions(/*tile_parameters=*/{2, 4, 2})); + TF_ASSERT_OK_AND_ASSIGN(TiledHloComputation tiled_hlo_computation, + analysis->ComputeTiledHloInstructions( + /*tile_parameters=*/{2, 4, 2}, + /*constraints_are_known_satisfied=*/false, + /*compute_all_tile_offset_indexing_maps=*/true)); const TiledHloInstruction* root = tiled_hlo_computation.GetRoot(); @@ -311,9 +318,11 @@ ENTRY main { std::optional analysis = TryAnalyzeModule(module.get()); ASSERT_TRUE(analysis.has_value()); - TF_ASSERT_OK_AND_ASSIGN( - TiledHloComputation tiled_hlo_computation, - analysis->ComputeTiledHloInstructions(/*tile_parameters=*/{2, 2})); + TF_ASSERT_OK_AND_ASSIGN(TiledHloComputation tiled_hlo_computation, + analysis->ComputeTiledHloInstructions( + /*tile_parameters=*/{2, 2}, + /*constraints_are_known_satisfied=*/false, + /*compute_all_tile_offset_indexing_maps=*/true)); const TiledHloInstruction* root = tiled_hlo_computation.GetRoot(); const TiledHloInstruction* p0_from_slice0 = root->operand(0)->operand(0); @@ -779,9 +788,11 @@ ENTRY main { std::optional analysis = TryAnalyzeModule(module.get()); ASSERT_TRUE(analysis.has_value()); - TF_ASSERT_OK_AND_ASSIGN( - TiledHloComputation tiled_hlo_computation, - analysis->ComputeTiledHloInstructions(/*tile_parameters=*/{1, 1})); + TF_ASSERT_OK_AND_ASSIGN(TiledHloComputation tiled_hlo_computation, + analysis->ComputeTiledHloInstructions( + /*tile_parameters=*/{1, 1}, + /*constraints_are_known_satisfied=*/false, + /*compute_all_tile_offset_indexing_maps=*/true)); EXPECT_THAT(*tiled_hlo_computation.GetRoot(), MatchTiledHloInstruction( @@ -831,7 +842,9 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN( TiledHloComputation tiled_hlo_computation, - analysis->ComputeTiledHloInstructions(/*tile_parameters=*/{1, 1})); + analysis->ComputeTiledHloInstructions( + /*tile_parameters=*/{1, 1}, /*constraints_are_known_satisfied=*/false, + /*compute_all_tile_offset_indexing_maps=*/true)); const TiledHloInstruction* dynamic_slice = tiled_hlo_computation.GetRoot()->operand(0); diff --git a/xla/service/gpu/model/tiled_hlo_instruction.cc b/xla/service/gpu/model/tiled_hlo_instruction.cc index 4a2543970a1df1..707bbe8de171b7 100644 --- a/xla/service/gpu/model/tiled_hlo_instruction.cc +++ b/xla/service/gpu/model/tiled_hlo_instruction.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include "absl/memory/memory.h" @@ -39,7 +40,7 @@ TiledHloInstruction::Create( llvm::SmallVector operands, llvm::SmallVector tile_sizes, llvm::SmallVector tile_strides, - IndexingMap tile_offsets_indexing) { + std::optional tile_offsets_indexing) { int rank = hlo->shape().rank(); if (tile_sizes.size() != rank) { @@ -56,11 +57,12 @@ TiledHloInstruction::Create( tile_strides.size(), ", hlo = ", hlo->ToString())); } - if (tile_offsets_indexing.GetAffineMap().getNumResults() != rank) { + if (tile_offsets_indexing.has_value() && + tile_offsets_indexing->GetAffineMap().getNumResults() != rank) { return absl::InvalidArgumentError(absl::StrFormat( "tile_offsets_indexing must have the same number of results as the " "rank of the hlo shape. tile_offsets_indexing = %s, hlo = %s", - tile_offsets_indexing.ToString(), hlo->ToString())); + tile_offsets_indexing->ToString(), hlo->ToString())); } return absl::WrapUnique(new TiledHloInstruction( @@ -73,7 +75,9 @@ std::string TiledHloInstruction::ToString() const { ss << "\thlo: " << hlo_->ToString() << "\n"; ss << "\ttile_sizes: (" << absl::StrJoin(tile_sizes_, ", ") << ")\n"; ss << "\ttile_strides: (" << absl::StrJoin(tile_strides_, ", ") << ")\n"; - ss << "\ttile_offsets_indexing: " << tile_offsets_indexing_; + ss << "\ttile_offsets_indexing: " + << (tile_offsets_indexing_.has_value() ? tile_offsets_indexing_->ToString() + : "nullopt"); return ss.str(); } diff --git a/xla/service/gpu/model/tiled_hlo_instruction.h b/xla/service/gpu/model/tiled_hlo_instruction.h index 86c7969d06560e..5ac17dd84aa833 100644 --- a/xla/service/gpu/model/tiled_hlo_instruction.h +++ b/xla/service/gpu/model/tiled_hlo_instruction.h @@ -18,10 +18,13 @@ limitations under the License. #include #include +#include #include #include #include +#include "absl/log/check.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" #include "llvm/ADT/SmallVector.h" @@ -52,8 +55,7 @@ class TiledHloInstruction { llvm::SmallVector operands, llvm::SmallVector tile_sizes, llvm::SmallVector tile_strides, - - IndexingMap tile_offsets_indexing); + std::optional tile_offsets_indexing); // Returns the original HLO instruction. const HloInstruction* hlo() const { return hlo_; } @@ -81,8 +83,16 @@ class TiledHloInstruction { // a form of `(d0, d1, ...) -> (tile_offset0, tile_offset1, ...)`. The number // of input dimensions is equal to the rank of output tile of the computation. // The number of tile offsets is equal to the rank of the tiled hlo. - const IndexingMap& tile_offsets_indexing() const { - return tile_offsets_indexing_; + // + // The indexing map is not computed by default. + absl::StatusOr tile_offsets_indexing() const { + if (!tile_offsets_indexing_.has_value()) { + return absl::FailedPreconditionError( + "tile_offsets_indexing was not computed. It is likely that " + "`compute_all_tile_offset_indexing_maps` should be set to true in " + "`SymbolicTileAnalysis::ComputeTiledHloInstructions`."); + } + return *tile_offsets_indexing_; } std::string ToString() const; @@ -98,7 +108,7 @@ class TiledHloInstruction { llvm::SmallVector operands, llvm::SmallVector tile_sizes, llvm::SmallVector tile_strides, - IndexingMap tile_offsets_indexing) + std::optional tile_offsets_indexing) : hlo_(hlo), operands_(std::move(operands)), tile_sizes_(std::move(tile_sizes)), @@ -115,16 +125,24 @@ class TiledHloInstruction { llvm::SmallVector tile_sizes_; llvm::SmallVector tile_strides_; - // Indexing map for tile offsets. - IndexingMap tile_offsets_indexing_; + // See comment for `tile_offsets_indexing()`. + std::optional tile_offsets_indexing_; }; inline bool operator==(const TiledHloInstruction& lhs, const TiledHloInstruction& rhs) { - return lhs.hlo() == rhs.hlo() && lhs.tile_sizes() == rhs.tile_sizes() && - lhs.tile_strides() == rhs.tile_strides() && - lhs.operands() == rhs.operands() && - lhs.tile_offsets_indexing() == rhs.tile_offsets_indexing(); + if (lhs.hlo() != rhs.hlo() || lhs.tile_sizes() != rhs.tile_sizes() || + lhs.tile_strides() != rhs.tile_strides()) { + return false; + } + + if (lhs.operands().empty() && rhs.operands().empty()) { + // Tile offsets indexing is guaranteed to be computed only if tile sizes are + // different and the instruction has no operands. + return lhs.tile_offsets_indexing() == rhs.tile_offsets_indexing(); + } + + return lhs.operands() == rhs.operands(); } inline bool operator!=(const TiledHloInstruction& lhs, @@ -142,8 +160,7 @@ H AbslHashValue(H h, const TiledHloInstruction& tiled_hlo_instruction) { absl::Span(tiled_hlo_instruction.tile_sizes()), absl::Span(tiled_hlo_instruction.tile_strides()), absl::Span( - tiled_hlo_instruction.operands()), - tiled_hlo_instruction.tile_offsets_indexing()); + tiled_hlo_instruction.operands())); } } // namespace gpu diff --git a/xla/service/gpu/model/tiled_hlo_instruction_test.cc b/xla/service/gpu/model/tiled_hlo_instruction_test.cc index d49a666a89e462..75e19273d3641f 100644 --- a/xla/service/gpu/model/tiled_hlo_instruction_test.cc +++ b/xla/service/gpu/model/tiled_hlo_instruction_test.cc @@ -25,7 +25,6 @@ limitations under the License. #include "xla/service/gpu/model/indexing_test_utils.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -43,7 +42,7 @@ TEST_F(TiledHloInstructionTest, TileSizesAndStridesShouldMatchHloShapeRank) { /*parameter_number=*/0, ShapeUtil::MakeShape(PrimitiveType::F32, {32, 64}), "p0"); - IndexingMap block_id_to_tile_offsets_indexing = IndexingMap::FromTensorSizes( + IndexingMap tile_offsets_indexing = IndexingMap::FromTensorSizes( ParseAffineMap("(d0) -> (d0 floordiv 16, (d0 mod 16) * 16)", &mlir_context_), /*dim_upper_bounds=*/{8}, @@ -51,14 +50,14 @@ TEST_F(TiledHloInstructionTest, TileSizesAndStridesShouldMatchHloShapeRank) { EXPECT_THAT(TiledHloInstruction::Create( hlo.get(), /*operands=*/{}, /*tile_sizes=*/{16}, - /*tile_strides=*/{1, 1}, block_id_to_tile_offsets_indexing) + /*tile_strides=*/{1, 1}, tile_offsets_indexing) .status() .message(), HasSubstr("Number of tile sizes must be equal to the rank")); EXPECT_THAT(TiledHloInstruction::Create( hlo.get(), /*operands=*/{}, /*tile_sizes=*/{16, 16}, - /*tile_strides=*/{1, 1, 1}, block_id_to_tile_offsets_indexing) + /*tile_strides=*/{1, 1, 1}, tile_offsets_indexing) .status() .message(), HasSubstr("Number of tile strides must be equal to the rank")); @@ -70,7 +69,7 @@ TEST_F(TiledHloInstructionTest, /*parameter_number=*/0, ShapeUtil::MakeShape(PrimitiveType::F32, {32, 64}), "p0"); - IndexingMap tile_offsets_indexing1 = IndexingMap::FromTensorSizes( + IndexingMap tile_offsets_indexing = IndexingMap::FromTensorSizes( ParseAffineMap("(d0, d1) -> (2 * d0)", &mlir_context_), /*dim_upper_bounds=*/{2, 4}, /*symbol_upper_bounds=*/{}); @@ -78,7 +77,7 @@ TEST_F(TiledHloInstructionTest, EXPECT_THAT( TiledHloInstruction::Create( hlo.get(), /*operands=*/{}, /*tile_sizes=*/{16, 16}, - /*tile_strides=*/{1, 1}, tile_offsets_indexing1) + /*tile_strides=*/{1, 1}, tile_offsets_indexing) .status() .message(), HasSubstr( From 46f6c4764bd774a14d5cfe7f36e40d44f9147c88 Mon Sep 17 00:00:00 2001 From: TJ Xu Date: Tue, 23 Jul 2024 04:18:10 -0700 Subject: [PATCH 083/376] PR #15163: [NVIDIA GPU] Skip processing for trip count 1 in loop double buffer unrolling Imported from GitHub PR https://github.com/openxla/xla/pull/15163 A minor improvement to loop double buffer unrolling pass to skip processing for loops with trip count =1 Copybara import of the project: -- f6bccd4612dde10cd020141f804523a75d9c84a2 by TJ Xu : Skip processing for trip count 1 -- 19514326f83ee91f712ffbd74cec90b19db77df7 by TJ Xu : added a test case Merging this change closes #15163 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/15163 from Tixxx:tixxx/double_buffer_skip-1 19514326f83ee91f712ffbd74cec90b19db77df7 PiperOrigin-RevId: 655111494 --- .../gpu/double_buffer_loop_unrolling.cc | 2 +- .../gpu/double_buffer_loop_unrolling_test.cc | 40 +++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/xla/service/gpu/double_buffer_loop_unrolling.cc b/xla/service/gpu/double_buffer_loop_unrolling.cc index 1901aee38e50fc..9cd9113c176f83 100644 --- a/xla/service/gpu/double_buffer_loop_unrolling.cc +++ b/xla/service/gpu/double_buffer_loop_unrolling.cc @@ -533,7 +533,7 @@ absl::StatusOr DoubleBufferLoopUnrolling::Run( for (HloInstruction* while_instr : while_instrs) { TF_ASSIGN_OR_RETURN(WhileLoopBackendConfig config, while_instr->backend_config()); - if (!config.has_known_trip_count()) { + if (!config.has_known_trip_count() || config.known_trip_count().n() == 1) { VLOG(2) << while_instr->ToString() << " doesn't have exact trip count, skipping loop unrolling " "for now"; diff --git a/xla/service/gpu/double_buffer_loop_unrolling_test.cc b/xla/service/gpu/double_buffer_loop_unrolling_test.cc index e1a2786728b88d..8fed3192b08598 100644 --- a/xla/service/gpu/double_buffer_loop_unrolling_test.cc +++ b/xla/service/gpu/double_buffer_loop_unrolling_test.cc @@ -1198,6 +1198,46 @@ ENTRY main { )")); } +TEST_F(GpuLoopDoubleBufferTransformerTest, + WhileLoopWithTripCount1ShouldBeSkipped) { + const char* const kModuleString = R"( +HloModule loop_unrolling_skipped +condition_nested { + input_tuple = (s32[]) parameter(0) + cond = s32[] get-tuple-element(input_tuple), index=0 + trip_count = s32[] constant(0) + ROOT done = pred[] compare(cond, trip_count), direction=LT +} +body_nested { + input_tuple = (s32[]) parameter(0) + cond = s32[] get-tuple-element(input_tuple), index=0 + one = s32[] constant(1) + cond_plus_1 = s32[] add(cond, one) + ROOT output = (s32[]) tuple(cond_plus_1) +} +condition { + input_tuple = (s32[]) parameter(0) + cond = s32[] get-tuple-element(input_tuple), index=0 + trip_count = s32[] constant(0) + ROOT done = pred[] compare(cond, trip_count), direction=LT +} +body { + input_tuple = (s32[]) parameter(0) + ROOT output = (s32[]) while(input_tuple), condition=condition_nested, body=body_nested, backend_config={"known_trip_count":{"n":"1"}} +} +ENTRY main { + param_0 = (s32[]) parameter(0) + ROOT while = (s32[]) while(param_0), condition=condition, body=body, backend_config={"known_trip_count":{"n":"1"}} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kModuleString)); + DoubleBufferLoopUnrolling double_buffer( + DoubleBufferLoopUnrolling::UnrollStrategy::kFullUnroll); + // The processing of the loop should be completely skipped. + EXPECT_THAT(double_buffer.Run(module.get()), IsOkAndHolds(false)); +} + } // namespace } // namespace gpu } // namespace xla From af9882015f8668d370ff230e86897709cf9d2aee Mon Sep 17 00:00:00 2001 From: Shanbin Ke Date: Tue, 23 Jul 2024 04:26:23 -0700 Subject: [PATCH 084/376] PR #15149: [XLA:GPU] print cudnn frontend check_support error message Imported from GitHub PR https://github.com/openxla/xla/pull/15149 Copybara import of the project: -- 0e5c2fa7e539fe97f7c5b8aadfc3fedd38ad4232 by cjkkkk : init -- 8565f356e82549e05b9d46e8a92f3e49a4370f3b by cjkkkk : update function signature Merging this change closes #15149 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/15149 from Cjkkkk:print_cudnn_check_support_message 8565f356e82549e05b9d46e8a92f3e49a4370f3b PiperOrigin-RevId: 655113338 --- xla/service/gpu/cudnn_fusion_compiler.cc | 5 +---- xla/stream_executor/cuda/cuda_dnn.cc | 19 +++++-------------- xla/stream_executor/cuda/cuda_dnn.h | 2 +- xla/stream_executor/dnn.h | 2 +- 4 files changed, 8 insertions(+), 20 deletions(-) diff --git a/xla/service/gpu/cudnn_fusion_compiler.cc b/xla/service/gpu/cudnn_fusion_compiler.cc index 01fd722438f333..f9ae751ef6949b 100644 --- a/xla/service/gpu/cudnn_fusion_compiler.cc +++ b/xla/service/gpu/cudnn_fusion_compiler.cc @@ -600,10 +600,7 @@ absl::StatusOr PrepareGraph( if (!graph.has_value()) { return absl::InternalError("Construction of cuDNN graph failed."); } - TF_ASSIGN_OR_RETURN(bool supported, graph->Prepare(dnn_support)); - if (!supported) { - return absl::InternalError("cuDNN graph is not supported."); - } + TF_RETURN_IF_ERROR(graph->Prepare(dnn_support)); return *graph; } diff --git a/xla/stream_executor/cuda/cuda_dnn.cc b/xla/stream_executor/cuda/cuda_dnn.cc index b7dd5fdd4085c4..3898d1e3ea002a 100644 --- a/xla/stream_executor/cuda/cuda_dnn.cc +++ b/xla/stream_executor/cuda/cuda_dnn.cc @@ -5201,10 +5201,7 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( .set_uid(CudnnfMHAUid::P_ID); } CudnnGraph cudnnGraph(std::move(graph)); - TF_ASSIGN_OR_RETURN(bool supported, cudnnGraph.Prepare(dnn_support)); - if (!supported) { - return absl::InternalError("cuDNN graph is not supported."); - } + TF_RETURN_IF_ERROR(cudnnGraph.Prepare(dnn_support)); TF_RETURN_IF_ERROR(cudnnGraph.Build(dnn_support, std::nullopt)); if (VLOG_IS_ON(4)) { @@ -5420,10 +5417,7 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( .set_data_type(ioDataType); CudnnGraph cudnnGraph(std::move(graph)); - TF_ASSIGN_OR_RETURN(bool supported, cudnnGraph.Prepare(dnn_support)); - if (!supported) { - return absl::InternalError("cuDNN graph is not supported."); - } + TF_RETURN_IF_ERROR(cudnnGraph.Prepare(dnn_support)); TF_RETURN_IF_ERROR(cudnnGraph.Build(dnn_support, std::nullopt)); if (VLOG_IS_ON(4)) { @@ -8374,18 +8368,15 @@ absl::StatusOr> CudnnSupport::DeserializeGraph( return std::make_unique(std::move(graph)); } -absl::StatusOr CudnnGraph::Prepare(dnn::DnnSupport& dnn_support) { +absl::Status CudnnGraph::Prepare(dnn::DnnSupport& dnn_support) { const CudnnSupport& cudnn_support = static_cast(dnn_support); TF_ASSIGN_OR_RETURN(auto cudnn, cudnn_support.cudnn_->GetLocalHandle()); RETURN_IF_CUDNN_FRONTEND_ERROR(graph_.validate()); RETURN_IF_CUDNN_FRONTEND_ERROR(graph_.build_operation_graph(cudnn->handle())); RETURN_IF_CUDNN_FRONTEND_ERROR( graph_.create_execution_plans({cudnn_frontend::HeurMode_t::A})); - if (auto result = graph_.check_support(cudnn->handle()); result.is_bad()) { - VLOG(3) << result.get_message(); - return false; - } - return true; + RETURN_IF_CUDNN_FRONTEND_ERROR(graph_.check_support(cudnn->handle())); + return absl::OkStatus(); } absl::Status CudnnGraph::Build(dnn::DnnSupport& dnn_support, diff --git a/xla/stream_executor/cuda/cuda_dnn.h b/xla/stream_executor/cuda/cuda_dnn.h index 6967f43dff5a5f..52086938d5a30f 100644 --- a/xla/stream_executor/cuda/cuda_dnn.h +++ b/xla/stream_executor/cuda/cuda_dnn.h @@ -60,7 +60,7 @@ class CudnnGraph : public dnn::DnnGraph { explicit CudnnGraph(cudnn_frontend::graph::Graph&& graph) : graph_(std::move(graph)) {} // Prepares a graph and checks whether it is generally supported. - absl::StatusOr Prepare(dnn::DnnSupport&) override; + absl::Status Prepare(dnn::DnnSupport&) override; // Builds single plan of the graph with given ID. absl::Status Build(dnn::DnnSupport&, std::optional plan_id) override; // Builds all the plans diff --git a/xla/stream_executor/dnn.h b/xla/stream_executor/dnn.h index 8c017a3756be4b..72f4603b4d3a04 100644 --- a/xla/stream_executor/dnn.h +++ b/xla/stream_executor/dnn.h @@ -1261,7 +1261,7 @@ class DnnGraph { // anything else unexpected), // false on expected ones (graph is valid but not supported), // true on success. - virtual absl::StatusOr Prepare(DnnSupport&) = 0; + virtual absl::Status Prepare(DnnSupport&) = 0; virtual absl::Status Build(DnnSupport&, std::optional plan_id) = 0; virtual absl::Status Execute(Stream& stream, absl::Span operands) const = 0; From ddba514b5bdfbc25e73e37071f8ad80930ca2c5d Mon Sep 17 00:00:00 2001 From: Philipp Hack Date: Tue, 23 Jul 2024 05:25:14 -0700 Subject: [PATCH 085/376] PR #15210: Require Matching Types in Layer Norm Fusion Imported from GitHub PR https://github.com/openxla/xla/pull/15210 Disables the fusion of layer norm patterns when the input and output types are not the same. Copybara import of the project: -- 4872f4413abbbc5e330a36f1c2ce5a5264a61fd7 by Philipp Hack : Disables layer norm fusion when the input and output types differ. Merging this change closes #15210 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/15210 from philipphack:u_layer_type_mismatch_xla 4872f4413abbbc5e330a36f1c2ce5a5264a61fd7 PiperOrigin-RevId: 655126293 --- xla/service/gpu/cudnn_norm_rewriter.cc | 9 ++-- xla/service/gpu/cudnn_norm_rewriter_test.cc | 58 +++++++++++++++++++++ 2 files changed, 63 insertions(+), 4 deletions(-) diff --git a/xla/service/gpu/cudnn_norm_rewriter.cc b/xla/service/gpu/cudnn_norm_rewriter.cc index 25b976cc380911..5e78f4864ec334 100644 --- a/xla/service/gpu/cudnn_norm_rewriter.cc +++ b/xla/service/gpu/cudnn_norm_rewriter.cc @@ -948,10 +948,11 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { return absl::OkStatus(); } - // Verify the element types. The types and shapes of the scale and bias - // must match. - if (!CompatibleElementType(x.Instr()) || !CompatibleElementType(instr) || - !CompatibleElementType(scale) || !CompatibleElementType(bias) || + // Verify the element types. The element types of input and output and the + // shapes of scale and bias must match. + if (!CompatibleElementType(instr) || !CompatibleElementType(scale) || + !CompatibleElementType(bias) || + !ShapeUtil::SameElementType(instr->shape(), x.Instr()->shape()) || !ShapeUtil::Equal(scale->shape(), bias->shape())) { VLOG(1) << "Layer norm input types or shapes not supported."; return absl::OkStatus(); diff --git a/xla/service/gpu/cudnn_norm_rewriter_test.cc b/xla/service/gpu/cudnn_norm_rewriter_test.cc index de0791d1fa5929..754563a535c23b 100644 --- a/xla/service/gpu/cudnn_norm_rewriter_test.cc +++ b/xla/service/gpu/cudnn_norm_rewriter_test.cc @@ -599,6 +599,64 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D3IncorrectScaleBroadcast) { TestNorm(hlo_text, optimized_hlo); } +TEST_F(CudnnNormRewriterTest, LayerNorm4D3InputOutputTypeMismatch) { +#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) + GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; +#endif + if (!(GetCudaComputeCapability().major == + se::CudaComputeCapability::AMPERE) && + !(GetCudaComputeCapability().major == + se::CudaComputeCapability::HOPPER)) { + GTEST_SKIP() + << "Layer norm kernels require Ampere or Hopper architectures."; + } + const char* hlo_text = R"( + HloModule test + + apply { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT c = f32[] add(a,b) + } + + ENTRY test { + input = f16[2,4,6,8] parameter(0) + input_f32 = f32[2,4,6,8] convert(input) + input_square = f32[2,4,6,8] multiply(input_f32, input_f32) + c0 = f32[] constant(0) + input_square_sum = f32[2,4,6] reduce(input_square, c0), dimensions={3}, to_apply=apply + r_nelems = f32[] constant(0.125) + r_nelems_bcast = f32[2,4,6] broadcast(r_nelems), dimensions={} + input_square_mean = f32[2,4,6] multiply(input_square_sum, r_nelems_bcast) + input_sum = f32[2,4,6] reduce(input_f32, c0), dimensions={3}, to_apply=apply + input_mean = f32[2,4,6] multiply(input_sum, r_nelems_bcast) + input_mean_square = f32[2,4,6] multiply(input_mean, input_mean) + variance = f32[2,4,6] subtract(input_square_mean, input_mean_square) + epsilon = f32[] constant(0.001) + epsilon_bcast = f32[2,4,6] broadcast(epsilon), dimensions={} + variance_plus_epsilon = f32[2,4,6] add(variance, epsilon_bcast) + norm_factor = f32[2,4,6] rsqrt(variance_plus_epsilon) + norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,1,2} + input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,1,2} + input_center = f32[2,4,6,8] subtract(input_f32, input_mean_bcast) + norm = f32[2,4,6,8] multiply(norm_factor_bcast, input_center) + scale = f32[8] parameter(1) + scale_bcast = f32[2,4,6,8] broadcast(scale), dimensions={3} + norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast) + bias = f32[8] parameter(2) + bias_bcast = f32[2,4,6,8] broadcast(bias), dimensions={3} + ROOT out = f32[2,4,6,8] add(norm_scale, bias_bcast) + })"; + + const char* optimized_hlo = R"( + +; CHECK-LABEL: ENTRY %test ({{.*}}: f16[2,4,6,8], {{.*}}: f32[8], {{.*}}: f32[8]) -> f32[2,4,6,8] { +; CHECK-NOT: custom_call_target="__cudnn$norm" + )"; + + TestNorm(hlo_text, optimized_hlo); +} + TEST_F(CudnnNormRewriterTest, LayerNormTrain2D1) { #if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; From 4e56ae4cafd5a8a2849325f8e2ee0607554b54a2 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Tue, 23 Jul 2024 05:31:15 -0700 Subject: [PATCH 086/376] #sdy Initial set of changes to allow for lowering to the Shardy dialect. The OpenXLA project is working on an open source, MLIR, named-axis based propagation (and in the future SP StableHLO 3. StableHLO with Shardy propagation 4. StableHLO with Shardy partitioning 5. StableHLO -> HLO 6. XLA optimizations The following test: ```py def test_sdy_lowering(self): mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) s = jax.sharding.NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) @partial(jax.jit, out_shardings=s) def f(x): return x * 2 print(f.lower(arr).as_text()) ``` outputs: ``` module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { sdy.mesh @mesh = <"x"=4, "y"=2> func.func public @main(%arg0: tensor<8x2xi64> {mhlo.layout_mode = "{1,0}", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<8x2xi64> {jax.result_info = "", mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) { %c = stablehlo.constant dense<2> : tensor %0 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<8x2xi64> %1 = stablehlo.multiply %arg0, %0 : tensor<8x2xi64> return %1 : tensor<8x2xi64> } } ``` Shardy will be hidden behind the `jax_use_shardy_partitioner` flag initially before becoming enabled by default in the future. PiperOrigin-RevId: 655127611 --- xla/python/xla_client.py | 2 +- xla/python/xla_compiler.cc | 5 ++++- xla/python/xla_extension/__init__.pyi | 2 ++ 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py index e63b058f5ab24c..b0f0264162eb15 100644 --- a/xla/python/xla_client.py +++ b/xla/python/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 277 +_version = 278 # Version number for MLIR:Python components. mlir_api_version = 57 diff --git a/xla/python/xla_compiler.cc b/xla/python/xla_compiler.cc index 3be39c7e43eb65..2259083a2da478 100644 --- a/xla/python/xla_compiler.cc +++ b/xla/python/xla_compiler.cc @@ -1199,7 +1199,10 @@ void BuildXlaCompilerSubmodule(nb::module_& m) { &DebugOptions::xla_gpu_dump_autotune_logs_to, [](DebugOptions* self, std::string value) { self->set_xla_gpu_dump_autotune_logs_to(value); - }); + }) + // TODO(b/352486192): Move this to `ExecutableBuildOptions`. + .def_prop_rw("xla_use_shardy", &DebugOptions::xla_use_shardy, + &DebugOptions::set_xla_use_shardy); nb::class_(m, "ExecutableBuildOptions") .def(nb::init<>()) diff --git a/xla/python/xla_extension/__init__.pyi b/xla/python/xla_extension/__init__.pyi index 2505d9c01ea737..e19bf8546491ab 100644 --- a/xla/python/xla_extension/__init__.pyi +++ b/xla/python/xla_extension/__init__.pyi @@ -318,6 +318,8 @@ class DebugOptions: xla_gpu_dump_autotune_results_to: str xla_gpu_load_autotune_results_from: str xla_gpu_dump_autotune_logs_to: str + # TODO(b/352486192): Move this to `ExecutableBuildOptions`. + xla_use_shardy: bool class CompiledMemoryStats: generated_code_size_in_bytes: int From 359ea704c4ad0cbfc5ee59a4487f1f0cc2db4601 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 23 Jul 2024 06:40:16 -0700 Subject: [PATCH 087/376] [IFRT] Introduce pass to merge reshards with the same src and dst. This optimizes the reshards to use a single RPC call instead of multiple RPC calls. It currently merges only when the src is func BlockArg or an ifrt Op with `outputs` of type IfrtArrayType, and for any destination. PiperOrigin-RevId: 655145015 --- xla/python/ifrt/ir/tests/BUILD | 1 + .../ifrt/ir/tests/ifrt_merge_reshards.mlir | 77 ++++++ xla/python/ifrt/ir/transforms/BUILD | 1 + .../ir/transforms/ifrt_merge_reshards_pass.cc | 221 ++++++++++++++++++ xla/python/ifrt/ir/transforms/passes.h | 4 + xla/python/ifrt/ir/transforms/passes.td | 32 +++ 6 files changed, 336 insertions(+) create mode 100644 xla/python/ifrt/ir/tests/ifrt_merge_reshards.mlir create mode 100644 xla/python/ifrt/ir/transforms/ifrt_merge_reshards_pass.cc diff --git a/xla/python/ifrt/ir/tests/BUILD b/xla/python/ifrt/ir/tests/BUILD index 7b727f0f6a02a1..f68d8b09142df1 100644 --- a/xla/python/ifrt/ir/tests/BUILD +++ b/xla/python/ifrt/ir/tests/BUILD @@ -11,6 +11,7 @@ lit_test_suite( srcs = enforce_glob( [ "ifrt_duplicated_callee_elimination.mlir", + "ifrt_merge_reshards.mlir", "ifrt_verify_donation.mlir", "ifrt_verify_sharding_specified.mlir", "spmd_expansion.mlir", diff --git a/xla/python/ifrt/ir/tests/ifrt_merge_reshards.mlir b/xla/python/ifrt/ir/tests/ifrt_merge_reshards.mlir new file mode 100644 index 00000000000000..4f8f0e20bc60cc --- /dev/null +++ b/xla/python/ifrt/ir/tests/ifrt_merge_reshards.mlir @@ -0,0 +1,77 @@ +// RUN: ifrt-opt %s -ifrt-merge-reshards | FileCheck %s + +#sharding = #ifrt.sharding_param<2 to [0] on 2> +!array0 = !ifrt.array, #sharding, [0,1]> +!array1 = !ifrt.array, #sharding, [2,3]> + +// CHECK-LABEL: @merge_reshards_of_call_results +func.func @merge_reshards_of_call_results(%arg0: !array0, %arg1: !array0) + -> (!array1, !array1) attributes {ifrt.function} { +// CHECK-NEXT: %[[CALL:.*]]:2, %{{.*}} = ifrt.Call @identity(%arg0, %arg1) +// CHECK-NEXT: %[[MERGED:.*]]:2, %{{.*}} = ifrt.Reshard(%[[CALL]]#0, %[[CALL]]#1) +// CHECK-NEXT: return %[[MERGED]]#0, %[[MERGED]]#1 + %0:2, %ctrl_0 = ifrt.Call @identity(%arg0, %arg1) on devices [0,1] + : (!array0, !array0) -> (!array0, !array0) + %1, %ctrl_1 = ifrt.Reshard(%0#0) : (!array0) -> !array1 + %2, %ctrl_2 = ifrt.Reshard(%0#1) : (!array0) -> !array1 + return %1, %2 : !array1, !array1 +} + +// CHECK-LABEL: @merge_reshards_of_func_args +func.func @merge_reshards_of_func_args(%arg0: !array0, %arg1: !array0) + -> (!array1, !array1) attributes {ifrt.function} { +// CHECK-NEXT: %[[MERGED:.*]]:2, %{{.*}} = ifrt.Reshard(%arg0, %arg1) +// CHECK-NEXT: return %[[MERGED]]#0, %[[MERGED]]#1 + %1, %ctrl_1 = ifrt.Reshard(%arg0) : (!array0) -> !array1 + %2, %ctrl_2 = ifrt.Reshard(%arg1) : (!array0) -> !array1 + return %1, %2 : !array1, !array1 +} + +// CHECK-LABEL: @merge_reshards_for_same_devices_only +func.func @merge_reshards_for_same_devices_only( + %arg0: !array0, %arg1: !array0, %arg2: !array0, %arg3: !array0) + -> (!array1, !array1, !array0, !array0) attributes {ifrt.function} { +// CHECK-NEXT: %[[MERGED1:.*]]:2, %{{.*}} = ifrt.Reshard(%arg0, %arg1) +// CHECK-NEXT: %[[MERGED2:.*]]:2, %{{.*}} = ifrt.Reshard(%arg2, %arg3) +// CHECK-NEXT: return %[[MERGED1]]#0, %[[MERGED1]]#1, %[[MERGED2]]#0, %[[MERGED2]]#1 + %1, %ctrl_1 = ifrt.Reshard(%arg0) : (!array0) -> !array1 + %2, %ctrl_2 = ifrt.Reshard(%arg1) : (!array0) -> !array1 + %3, %ctrl_3 = ifrt.Reshard(%arg2) : (!array0) -> !array0 + %4, %ctrl_4 = ifrt.Reshard(%arg3) : (!array0) -> !array0 + return %1, %2, %3, %4 : !array1, !array1, !array0, !array0 +} + +// CHECK-LABEL: @merge_reshards_for_same_donated_only +func.func @merge_reshards_for_same_donated_only( + %arg0: !array0, %arg1: !array0, %arg2: !array0, %arg3: !array0) + -> (!array1, !array1, !array1, !array1) attributes {ifrt.function} { +// CHECK-NEXT: %[[MERGED1:.*]]:2, %{{.*}} = ifrt.Reshard(%arg0, %arg1) {donated = true} : +// CHECK-NEXT: %[[MERGED2:.*]]:2, %{{.*}} = ifrt.Reshard(%arg2, %arg3) +// CHECK-NOT: {donated = true} +// CHECK-NEXT: return %[[MERGED1]]#0, %[[MERGED1]]#1, %[[MERGED2]]#0, %[[MERGED2]]#1 + %1, %ctrl_1 = ifrt.Reshard(%arg0) {donated = true} : (!array0) -> !array1 + %2, %ctrl_2 = ifrt.Reshard(%arg1) {donated = true} : (!array0) -> !array1 + %3, %ctrl_3 = ifrt.Reshard(%arg2) {donated = false} : (!array0) -> !array1 + %4, %ctrl_4 = ifrt.Reshard(%arg3) : (!array0) -> !array1 + return %1, %2, %3, %4 : !array1, !array1, !array1, !array1 +} + +// CHECK-LABEL: @dont_merge_if_any_control_dependencies +func.func @dont_merge_if_any_control_dependencies( + %arg0: !array0, %arg1: !array0) + -> (!array1, !array1) attributes {ifrt.function} { +// CHECK-NEXT: %[[CALL:.*]]:2, %[[CTRL:.*]] = ifrt.Call @identity(%arg0, %arg1) +// CHECK-NEXT: %[[R1:.*]], %{{.*}} = ifrt.Reshard(%[[CALL]]#0) after %[[CTRL]] +// CHECK-NEXT: %[[R2:.*]], %{{.*}} = ifrt.Reshard(%[[CALL]]#1) +// CHECK-NEXT: return %[[R1]], %[[R2]] + %0:2, %ctrl_0 = ifrt.Call @identity(%arg0, %arg1) on devices [0,1] + : (!array0, !array0) -> (!array0, !array0) + %1, %ctrl_1 = ifrt.Reshard(%0#0) after %ctrl_0 : (!array0) -> !array1 + %2, %ctrl_2 = ifrt.Reshard(%0#1) : (!array0) -> !array1 + return %1, %2 : !array1, !array1 +} + +func.func private @identity(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) + -> (tensor<2xi32>, tensor<2xi32>) { + return %arg0, %arg1 : tensor<2xi32>, tensor<2xi32> +} diff --git a/xla/python/ifrt/ir/transforms/BUILD b/xla/python/ifrt/ir/transforms/BUILD index 53c70df7f6b56f..ccd1919e3ccf5d 100644 --- a/xla/python/ifrt/ir/transforms/BUILD +++ b/xla/python/ifrt/ir/transforms/BUILD @@ -30,6 +30,7 @@ cc_library( name = "passes", srcs = [ "ifrt_duplicated_callee_elimination_pass.cc", + "ifrt_merge_reshards_pass.cc", "ifrt_verify_donation_pass.cc", "ifrt_verify_sharding_specified_pass.cc", "spmd_expandable_interface_verification_pass.cc", diff --git a/xla/python/ifrt/ir/transforms/ifrt_merge_reshards_pass.cc b/xla/python/ifrt/ir/transforms/ifrt_merge_reshards_pass.cc new file mode 100644 index 00000000000000..a3a8a4952dd50f --- /dev/null +++ b/xla/python/ifrt/ir/transforms/ifrt_merge_reshards_pass.cc @@ -0,0 +1,221 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "absl/log/check.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "xla/python/ifrt/ir/constants.h" +#include "xla/python/ifrt/ir/ifrt_dialect.h" +#include "xla/python/ifrt/ir/ifrt_ops.h" + +namespace xla { +namespace ifrt { + +namespace { + +#define GEN_PASS_DEF_IFRTMERGERESHARDSPASS +#include "xla/python/ifrt/ir/transforms/passes.h.inc" + +class IfrtMergeReshardsPass + : public impl::IfrtMergeReshardsPassBase { + public: + void runOnOperation() override; +}; + +// Merges reshards on `source_values` which flow into the same +// destination. We merge only if the reshard: +// - has only one input and output. I.e. it isn't already merged. +// - has no input control dependencies. +// - has the same `donation` setting. +// +// `source_values` are expected to be of type IfrtArrayType on the same devices, +// and be OpResults from the same ops, or BlockArgs in the same block. +// +// We defer erasing the op until the end of the pass, to avoid invalidating the +// iterator. +void MergeReshardsIgnoringControlDependencies( + mlir::ValueRange source_values, std::vector& ops_to_erase, + mlir::RewriterBase& rewriter) { + // We group reshards by {first_user, devices, donated, src_memory_kind, + // dst_memory_kind}. We need to group by: + // - first_user because we need to pick some destination. + // - devices as well, because certain users can have multiple devices, e.g. + // func.return. + // - donated because the donation is all-or-nothing. + // - src and dst memory kind because we can't merge reshards that change + // memory kind. + llvm::DenseMap, + llvm::SmallVector> + user_device_donate_tuple_to_reshards; + + // Group reshards by their first user. + for (mlir::Value value : source_values) { + CHECK(mlir::isa(value.getType())); + + for (mlir::Operation* user : value.getUsers()) { + auto reshard_op = mlir::dyn_cast(user); + if (!reshard_op || reshard_op.getOutputs().size() != 1 || + reshard_op->use_empty() || !reshard_op.getControlInputs().empty()) { + continue; + } + + // This could potentially be very expensive as `isBeforeInBlock` is + // average O(1) but worst case O(n). + mlir::Operation* first_reshard_user = *llvm::min_element( + reshard_op->getUsers(), [](mlir::Operation* a, mlir::Operation* b) { + return a->isBeforeInBlock(b); + }); + + auto output_type = + mlir::cast(reshard_op.getOutputs().front().getType()); + user_device_donate_tuple_to_reshards + [{first_reshard_user, output_type.getDevicesAttr(), + // We can't hash by the bool itself, and `donated` is a optional + // attr, so false can be represented by nullptr or BoolAttr(false). + // So we explicitly convert to BoolAttr. + rewriter.getBoolAttr(reshard_op.getDonated()), + mlir::cast(reshard_op.getInputs().front().getType()) + .getMemoryKindAttr(), + output_type.getMemoryKindAttr()}] + .push_back(reshard_op); + } + } + + // Rewrite each group of reshards. + for (auto& [_, reshards] : user_device_donate_tuple_to_reshards) { + // Create a new reshard op that takes all the inputs of the reshards. + llvm::SmallVector inputs; + llvm::SmallVector output_types; + llvm::SmallVector locs; + inputs.reserve(reshards.size()); + output_types.reserve(reshards.size()); + locs.reserve(reshards.size()); + + for (ReshardOp reshard : reshards) { + CHECK_EQ(reshard.getInputs().size(), 1); + CHECK_EQ(reshard.getOutputs().size(), 1); + inputs.push_back(reshard.getInputs()[0]); + output_types.push_back(reshard.getOutputs()[0].getType()); + locs.push_back(reshard.getLoc()); + } + + // Insert the new reshard op just before one of the reshards, to + // minimize reordering reshards. + rewriter.setInsertionPoint(reshards.front()); + auto merged_reshard = + rewriter.create(rewriter.getFusedLoc(locs), + /*outputs=*/output_types, + /*control_output=*/ + IfrtControlType::get(rewriter.getContext()), + /*inputs=*/inputs, + /*donated=*/reshards.front().getDonated(), + /*control_inputs=*/mlir::ValueRange()); + + // Replace the original reshards with the new merged reshard. + for (auto [index, reshard] : llvm::enumerate(reshards)) { + rewriter.replaceAllUsesWith(reshard.getOutputs()[0], + merged_reshard.getOutputs()[index]); + rewriter.replaceAllUsesWith(reshard.getControlOutput(), + merged_reshard.getControlOutput()); + ops_to_erase.push_back(reshard); + } + } +} + +template +bool MergeReshardsIgnoringControlDependencies( + mlir::Operation* op, std::vector& ops_to_erase, + mlir::RewriterBase& rewriter) { + if (auto casted = mlir::dyn_cast(op)) { + MergeReshardsIgnoringControlDependencies(casted.getOutputs(), ops_to_erase, + rewriter); + return true; + } + return false; +} + +template +bool MergeReshardsIgnoringControlDependencies( + mlir::Operation* op, std::vector& ops_to_erase, + mlir::RewriterBase& rewriter) { + return MergeReshardsIgnoringControlDependencies(op, ops_to_erase, + rewriter) || + MergeReshardsIgnoringControlDependencies( + op, ops_to_erase, rewriter); +} + +void IfrtMergeReshardsPass::runOnOperation() { + mlir::func::FuncOp func_op = getOperation(); + mlir::IRRewriter rewriter(func_op->getContext()); + std::vector ops_to_erase; + + // We only need to run this pass on IFRT functions. + if (!func_op->hasAttr(kIfrtFunctionAttrName) && + !func_op->hasAttr(kIfrtReshardFunctionAttrName)) { + return; + } + + // Handle func block args. + { + llvm::DenseMap> + devices_to_args; + for (mlir::Value arg : func_op.getArguments()) { + if (auto array_type = mlir::dyn_cast(arg.getType())) { + devices_to_args[array_type.getDevicesAttr()].push_back(arg); + } + } + + for (auto& [_, args] : devices_to_args) { + MergeReshardsIgnoringControlDependencies(args, ops_to_erase, rewriter); + } + } + + // Handle ops in the IFRT function body. + func_op.getFunctionBody().walk([&](mlir::Operation* op) { + MergeReshardsIgnoringControlDependencies< + CallOp, CallLoadedExecutableOp, ReshardOp, CopyArraysOp, RemapArraysOp>( + op, ops_to_erase, rewriter); + }); + + for (mlir::Operation* op : ops_to_erase) { + rewriter.eraseOp(op); + } +} + +} // namespace + +std::unique_ptr> +CreateIfrtMergeReshardsPass() { + return std::make_unique(); +} + +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt/ir/transforms/passes.h b/xla/python/ifrt/ir/transforms/passes.h index 190735960cb595..a2cd1748a6c3b0 100644 --- a/xla/python/ifrt/ir/transforms/passes.h +++ b/xla/python/ifrt/ir/transforms/passes.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Operation.h" #include "mlir/Pass/Pass.h" @@ -39,6 +40,9 @@ std::unique_ptr> CreateSpmdExpansionPass(); std::unique_ptr> CreateIfrtDuplicatedCalleeEliminationPass(); +std::unique_ptr> +CreateIfrtMergeReshardsPass(); + std::unique_ptr> CreateIfrtVerifyDonationPass(); diff --git a/xla/python/ifrt/ir/transforms/passes.td b/xla/python/ifrt/ir/transforms/passes.td index 839cb50e21aff9..d299c6b4786425 100644 --- a/xla/python/ifrt/ir/transforms/passes.td +++ b/xla/python/ifrt/ir/transforms/passes.td @@ -107,6 +107,38 @@ them. The duplicated callee `FuncOp` will not be removed. let constructor = "CreateIfrtDuplicatedCalleeEliminationPass()"; } +def IfrtMergeReshardsPass : + Pass<"ifrt-merge-reshards", "mlir::func::FuncOp"> { + let summary = "Merge reshards for each (src, dst) pair into a single reshard."; + let description = [{ +Merges reshards from a source op which flow into the same destination op, +ignoring control dependencies. + +E.g. + +```mlir +%c:4, %ctrl_0 = ifrt.Call ... +%r0, %ctrl_1 = ifrt.Reshard %c#0 ... +%r1, %ctrl_2 = ifrt.Reshard %c#1 ... +%d:4, %ctrl_3 = ifrt.Call (%r0, %r1) +``` + +will be replaced by: + +```mlir +%c:4, %ctrl_0 = ifrt.Call ... +%r:2, %ctrl_1 = ifrt.Reshard (%c#0, %c#1) +%d:4, %ctrl_2 = ifrt.Call (%r#0, %r#1) +``` + +Currently this handles the case where the source is the Func BlockArgs or +the outputs of a ifrt.Call, ifrt.CallLoadedExecutable, ifrt.Reshard, +ifrt.CopyArrays, and ifrt.RemapArrays. + }]; + + let constructor = "CreateIfrtMergeReshardsPass()"; +} + def IfrtVerifyDonationPass : Pass<"ifrt-verify-donation", "mlir::ModuleOp"> { let summary = "Verify that `!ifrt.array` are not donated more than once."; let description = [{ From 5ef10e53fdfc0a7069fb4c5532051df0086e0fd3 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Tue, 23 Jul 2024 06:41:44 -0700 Subject: [PATCH 088/376] PR #15222: [GPU] Fix cuDNN workspace test condition. Imported from GitHub PR https://github.com/openxla/xla/pull/15222 Copybara import of the project: -- e57258f605e6df6bf61ac55632e939b5badd6c22 by Ilia Sergachev : [GPU] Fix cuDNN workspace test condition. Merging this change closes #15222 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/15222 from openxla:fix_cudnn_test e57258f605e6df6bf61ac55632e939b5badd6c22 PiperOrigin-RevId: 655145324 --- xla/service/gpu/fusions/cudnn_test.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/xla/service/gpu/fusions/cudnn_test.cc b/xla/service/gpu/fusions/cudnn_test.cc index a8d5471784c178..9e9e1ce7560500 100644 --- a/xla/service/gpu/fusions/cudnn_test.cc +++ b/xla/service/gpu/fusions/cudnn_test.cc @@ -203,6 +203,9 @@ ENTRY e { TEST_F(CuDnnFusionExecutionTest, CuDnnFusionCompilerDoesNotFailOnDependentFusions) { + if (!IsAtLeastCuDnn91()) { + GTEST_SKIP() << "This test case requests a workspace only with cuDNN 9.1+."; + } TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(R"( c1 { From a02433bdc5e6dbe68446d6c95c4959b5e1715cf9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 23 Jul 2024 06:46:59 -0700 Subject: [PATCH 089/376] Integrate LLVM at llvm/llvm-project@acc159aea1e6 Updates LLVM usage to match [acc159aea1e6](https://github.com/llvm/llvm-project/commit/acc159aea1e6) PiperOrigin-RevId: 655146366 --- third_party/llvm/generated.patch | 11 ++++ third_party/llvm/workspace.bzl | 4 +- third_party/shardy/workspace.bzl | 4 +- third_party/stablehlo/temporary.patch | 57 +++++++++++++++++++ .../triton/llvm_integration/series.bzl | 1 + .../tsl/third_party/llvm/generated.patch | 11 ++++ .../tsl/third_party/llvm/workspace.bzl | 4 +- xla/mlir_hlo/mhlo/utils/type_conversion.cc | 20 +++++-- 8 files changed, 100 insertions(+), 12 deletions(-) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index ed3d58f027f90b..7af6db90fd2b4f 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1343,3 +1343,14 @@ diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/utils/MPFR "//libc:__support_macros_properties_types", "//libc:hdr_math_macros", "//libc/test/UnitTest:LibcUnitTest", +diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel ++++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +@@ -2900,6 +2900,7 @@ + ":IR", + ":LoopLikeInterface", + ":SCFDialect", ++ ":SCFToControlFlow", + ":SCFTransformOpsIncGen", + ":SCFTransforms", + ":SCFUtils", diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 6c8da928bb4d1a..a108e965dd0086 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "dd7d81ea49bf39e1d69bbb84bd3f31bd95519369" - LLVM_SHA256 = "fbd43ef20f4209b0619e209e48c431f76008917714a8c5336063e1ff51d8d084" + LLVM_COMMIT = "acc159aea1e641e3694ab8fe5faa231788077011" + LLVM_SHA256 = "ff2d0c2d9dd22eb39b3d135bcf0cf91008b395de797f543e32790df372945d13" tf_http_archive( name = name, diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index 55e16fed709356..ef740f479ad0f4 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "58fc775e94b0e7b0f127848e151f5c0dc4c64435" - SHARDY_SHA256 = "446714f551b9df42b99c6afc913b7078b69c6985f300116f9bcf2279ab4cb623" + SHARDY_COMMIT = "05a83632728cbdf172bb92e3fd644487b74275f6" + SHARDY_SHA256 = "d89ae97cdfdbc5a192b90e7028a3b06873d2a8db5ffb092c2cd0bd4e30b29806" tf_http_archive( name = "shardy", diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 8b137891791fe9..3cae7dc292dc85 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -1 +1,58 @@ +diff --ruN a/stablehlo/stablehlo/conversions/linalg/transforms/TypeConversion.cpp b/stablehlo/stablehlo/conversions/linalg/transforms/TypeConversion.cpp +--- stablehlo/stablehlo/conversions/linalg/transforms/TypeConversion.cpp ++++ stablehlo/stablehlo/conversions/linalg/transforms/TypeConversion.cpp +@@ -66,17 +66,25 @@ + ->getResult(0); + } + +-std::optional scalarToTensor(OpBuilder &builder, Type /*type*/, ++std::optional scalarToTensor(OpBuilder& builder, Type type, + ValueRange inputs, Location loc) { + assert(inputs.size() == 1); +- if (llvm::isa(inputs.front().getType())) { ++ if (mlir::isa(inputs.front().getType())) { + return std::nullopt; + } +- return builder +- .create( +- loc, RankedTensorType::get({}, inputs.front().getType()), +- inputs.front()) +- .getResult(); ++ Value result = ++ builder ++ .create( ++ loc, RankedTensorType::get({}, inputs.front().getType()), ++ inputs.front()) ++ .getResult(); ++ // Convert to a signed integer if necessary. ++ Type elementType = mlir::getElementTypeOrSelf(type); ++ if (elementType.isInteger() && !elementType.isSignlessInteger()) { ++ result = builder.create(loc, type, result) ++ ->getResult(0); ++ } ++ return result; + } + + } // namespace +diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeQuantToInt.cpp b/stablehlo/stablehlo/transforms/StablehloLegalizeQuantToInt.cpp +--- stablehlo/stablehlo/transforms/StablehloLegalizeQuantToInt.cpp ++++ stablehlo/stablehlo/transforms/StablehloLegalizeQuantToInt.cpp +@@ -1270,12 +1270,13 @@ + OperationState state(op->getLoc(), op->getName().getStringRef(), operands, + newResultTypes, op->getAttrs(), op->getSuccessors()); + for (Region ®ion : op->getRegions()) { +- Region &newRegion = *state.addRegion(); +- rewriter.inlineRegionBefore(region, newRegion, newRegion.begin()); +- if (failed( +- rewriter.convertRegionTypes(&newRegion, *getTypeConverter()))) { ++ auto newRegion = std::make_unique(op); ++ rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin()); ++ if (failed(rewriter.convertRegionTypes(newRegion.get(), ++ *getTypeConverter()))) { + return failure(); + } ++ state.addRegion(std::move(newRegion)); + } + Operation *newOp = rewriter.create(state); + rewriter.replaceOp(op, newOp); diff --git a/third_party/triton/llvm_integration/series.bzl b/third_party/triton/llvm_integration/series.bzl index 656b9c894904d8..7b438990166a30 100644 --- a/third_party/triton/llvm_integration/series.bzl +++ b/third_party/triton/llvm_integration/series.bzl @@ -8,5 +8,6 @@ LLVM nor MLIR integrator, please do not add any patches to this list. """ llvm_patch_list = [ + "//third_party/triton/llvm_integration:cl654795065.patch", # Add new patches just above this line ] diff --git a/third_party/tsl/third_party/llvm/generated.patch b/third_party/tsl/third_party/llvm/generated.patch index ed3d58f027f90b..7af6db90fd2b4f 100644 --- a/third_party/tsl/third_party/llvm/generated.patch +++ b/third_party/tsl/third_party/llvm/generated.patch @@ -1343,3 +1343,14 @@ diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/utils/MPFR "//libc:__support_macros_properties_types", "//libc:hdr_math_macros", "//libc/test/UnitTest:LibcUnitTest", +diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel ++++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +@@ -2900,6 +2900,7 @@ + ":IR", + ":LoopLikeInterface", + ":SCFDialect", ++ ":SCFToControlFlow", + ":SCFTransformOpsIncGen", + ":SCFTransforms", + ":SCFUtils", diff --git a/third_party/tsl/third_party/llvm/workspace.bzl b/third_party/tsl/third_party/llvm/workspace.bzl index 6c8da928bb4d1a..a108e965dd0086 100644 --- a/third_party/tsl/third_party/llvm/workspace.bzl +++ b/third_party/tsl/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "dd7d81ea49bf39e1d69bbb84bd3f31bd95519369" - LLVM_SHA256 = "fbd43ef20f4209b0619e209e48c431f76008917714a8c5336063e1ff51d8d084" + LLVM_COMMIT = "acc159aea1e641e3694ab8fe5faa231788077011" + LLVM_SHA256 = "ff2d0c2d9dd22eb39b3d135bcf0cf91008b395de797f543e32790df372945d13" tf_http_archive( name = name, diff --git a/xla/mlir_hlo/mhlo/utils/type_conversion.cc b/xla/mlir_hlo/mhlo/utils/type_conversion.cc index 50a8d01313ec6f..4ff1bf56fde53f 100644 --- a/xla/mlir_hlo/mhlo/utils/type_conversion.cc +++ b/xla/mlir_hlo/mhlo/utils/type_conversion.cc @@ -74,17 +74,25 @@ std::optional materializeCastToIllegal(OpBuilder& builder, Type type, ->getResult(0); } -std::optional scalarToTensor(OpBuilder& builder, Type /*type*/, +std::optional scalarToTensor(OpBuilder& builder, Type type, ValueRange inputs, Location loc) { assert(inputs.size() == 1); if (mlir::isa(inputs.front().getType())) { return std::nullopt; } - return builder - .create( - loc, RankedTensorType::get({}, inputs.front().getType()), - inputs.front()) - .getResult(); + Value result = + builder + .create( + loc, RankedTensorType::get({}, inputs.front().getType()), + inputs.front()) + .getResult(); + // Convert to a signed integer if necessary. + Type elementType = mlir::getElementTypeOrSelf(type); + if (elementType.isInteger() && !elementType.isSignlessInteger()) { + result = builder.create(loc, type, result) + ->getResult(0); + } + return result; } } // namespace From fe155d3cb0159eb79aee5555896d9077bb68565c Mon Sep 17 00:00:00 2001 From: Goran Flegar Date: Tue, 23 Jul 2024 11:08:33 -0700 Subject: [PATCH 090/376] Add missing patch file from LLVM integration PiperOrigin-RevId: 655231244 --- .../triton/llvm_integration/cl654795065.patch | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 third_party/triton/llvm_integration/cl654795065.patch diff --git a/third_party/triton/llvm_integration/cl654795065.patch b/third_party/triton/llvm_integration/cl654795065.patch new file mode 100644 index 00000000000000..19ac00d2cdb637 --- /dev/null +++ b/third_party/triton/llvm_integration/cl654795065.patch @@ -0,0 +1,15 @@ +diff --git a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp +--- a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp ++++ b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp +@@ -57,8 +57,9 @@ TritonGPUTypeConverter::TritonGPUTypeCon + addArgumentMaterialization([&](OpBuilder &builder, + RankedTensorType tensorType, ValueRange inputs, + Location loc) -> std::optional { +- llvm_unreachable("Argument rematerialization should not happen in Triton " +- "-> TritonGPU conversion"); ++ // TODO(b/354860562): reenable or remove. ++ // llvm_unreachable("Argument rematerialization should not happen in Triton " ++ // "-> TritonGPU conversion"); + return std::nullopt; + }); + From 7e197e711789188180552c4e367be165176f1b1d Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 23 Jul 2024 11:53:48 -0700 Subject: [PATCH 091/376] [xla:cpu] Support for up to 16 sorted inputs + enable more jax/lax tests for XLA CPU thunks PiperOrigin-RevId: 655249641 --- xla/service/cpu/runtime/sort_thunk.cc | 69 ++++++++++++++++++++++++--- 1 file changed, 63 insertions(+), 6 deletions(-) diff --git a/xla/service/cpu/runtime/sort_thunk.cc b/xla/service/cpu/runtime/sort_thunk.cc index 959b096b8ba18c..c8adf958f2a6e7 100644 --- a/xla/service/cpu/runtime/sort_thunk.cc +++ b/xla/service/cpu/runtime/sort_thunk.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/algorithm/container.h" @@ -49,6 +50,7 @@ limitations under the License. #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" +#include "tsl/profiler/lib/traceme.h" namespace xla::cpu { @@ -403,12 +405,65 @@ static absl::Status SortInplace(absl::Span data, int64_t inner_idx = i % sort_dims.inner_dim_size; int64_t offset = inner_idx + (i - inner_idx) * sort_dims.sort_dim_size; - if (data.size() == 1) { - SortInplace<1>(sort_dims, offset, data, shapes, is_stable, less_than); - } else if (data.size() == 2) { - SortInplace<2>(sort_dims, offset, data, shapes, is_stable, less_than); - } else { - return Internal("Unsupported number of sorted inputs: %d", data.size()); + auto sort = [&](auto num_inputs) { + SortInplace(sort_dims, offset, data, shapes, + is_stable, less_than); + }; + + // TODO(ezhulenev): We can replace statically known number of sorted inputs + // with a dynamic value, however statically known number of inputs allows + // compiler to generate better code. Benchmark if it really matters. + switch (data.size()) { + case 1: + sort(std::integral_constant{}); + break; + case 2: + sort(std::integral_constant{}); + break; + case 3: + sort(std::integral_constant{}); + break; + case 4: + sort(std::integral_constant{}); + break; + case 5: + sort(std::integral_constant{}); + break; + case 6: + sort(std::integral_constant{}); + break; + case 7: + sort(std::integral_constant{}); + break; + case 8: + sort(std::integral_constant{}); + break; + case 9: + sort(std::integral_constant{}); + break; + case 10: + sort(std::integral_constant{}); + break; + case 11: + sort(std::integral_constant{}); + break; + case 12: + sort(std::integral_constant{}); + break; + case 13: + sort(std::integral_constant{}); + break; + case 14: + sort(std::integral_constant{}); + break; + case 15: + sort(std::integral_constant{}); + break; + case 16: + sort(std::integral_constant{}); + break; + default: + return Internal("Unsupported number of sorted inputs: %d", data.size()); } } @@ -417,6 +472,8 @@ static absl::Status SortInplace(absl::Span data, tsl::AsyncValueRef SortThunk::Execute( const ExecuteParams& params) { + tsl::profiler::TraceMe trace([&] { return TraceMeEncode(); }); + VLOG(3) << absl::StreamFormat( "Sort %d inputs along dimension %d (is_stable=%v)", inputs_.size(), dimension_, is_stable_); From 0f6b072ac177cf6db0bcc9fa63bfa5258e8363ab Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 23 Jul 2024 11:54:14 -0700 Subject: [PATCH 092/376] [xla:cpu] Make cpu_external_constants_test more robust to IR names PiperOrigin-RevId: 655249801 --- .../cpu/tests/cpu_external_constants_test.cc | 24 +++++++------------ 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/xla/service/cpu/tests/cpu_external_constants_test.cc b/xla/service/cpu/tests/cpu_external_constants_test.cc index 20bcc3f973c2c8..b2147b67ed3cf2 100644 --- a/xla/service/cpu/tests/cpu_external_constants_test.cc +++ b/xla/service/cpu/tests/cpu_external_constants_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include @@ -27,9 +28,9 @@ limitations under the License. #include "xla/xla_data.pb.h" #include "tsl/platform/test.h" -namespace xla { -namespace cpu { +namespace xla::cpu { namespace { + class CpuExternalConstantsTest : public CpuCodegenTest { public: void TestWithArray(int64_t rows, int64_t cols, @@ -57,21 +58,12 @@ class CpuExternalConstantsTest : public CpuCodegenTest { } }; -TEST_F(CpuExternalConstantsTest, Basic) { - TestWithArray(/*rows=*/1024, /*cols=*/1024, R"( -CHECK-NOT: @constant_global_0 = external unnamed_addr constant [1024 x [1024 x float]], align 16 -CHECK: @constant = private unnamed_addr constant [4194304 x i8] {{.*}}, align 16 -)"); -} - -TEST_F(CpuExternalConstantsTest, BasicNegative) { - // The constant array in this test case is small enough that there is no need - // to externalize it. +TEST_F(CpuExternalConstantsTest, DoNotExternalizeConstants) { TestWithArray(/*rows=*/4, /*cols=*/4, R"( -CHECK-NOT: @constant_global_0 = external unnamed_addr constant [16 x float] -CHECK: @constant = private unnamed_addr constant [64 x i8] {{.*}}, align 16 +CHECK-NOT: external unnamed_addr constant [16 x float] +CHECK: @[[CST:.+]] = private unnamed_addr constant [64 x i8] {{.*}}, align 16 )"); } + } // namespace -} // namespace cpu -} // namespace xla +} // namespace xla::cpu From 929f0b71fa7bdc8861939e177ec47114812031f2 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 23 Jul 2024 11:57:46 -0700 Subject: [PATCH 093/376] [xla:cpu] Rename thunks execute event to avoid confusion We already have another execute_event in a scope. PiperOrigin-RevId: 655250959 --- xla/pjrt/cpu/cpu_client.cc | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/xla/pjrt/cpu/cpu_client.cc b/xla/pjrt/cpu/cpu_client.cc index ad9f26eed760b3..d7c7575ec06cbc 100644 --- a/xla/pjrt/cpu/cpu_client.cc +++ b/xla/pjrt/cpu/cpu_client.cc @@ -1611,12 +1611,14 @@ absl::StatusOr TfrtCpuExecutable::ExecuteHelper( &collective_params, &custom_call_execute_params}; - auto execute_event = cpu_executable->thunks().Execute(execute_params); + auto thunks_execute_event = + cpu_executable->thunks().Execute(execute_params); tsl::profiler::TraceMe trace( "ThunkExecutor::Execute (wait for completion)"); - tsl::BlockUntilReady(execute_event); - if (execute_event.IsError()) return execute_event.GetError(); + tsl::BlockUntilReady(thunks_execute_event); + if (thunks_execute_event.IsError()) + return thunks_execute_event.GetError(); } else { return Internal("CpuExecutable has no compute function or thunks."); @@ -1748,14 +1750,15 @@ absl::StatusOr TfrtCpuExecutable::ExecuteHelper( &*collective_params, &*custom_call_params}; - auto execute_event = + auto thunks_execute_event = cpu_executable->thunks().Execute(execute_params); tsl::profiler::TraceMe trace( "ThunkExecutor::Execute (wait for completion)"); - tsl::BlockUntilReady(execute_event); - status = execute_event.IsError() ? execute_event.GetError() - : absl::OkStatus(); + tsl::BlockUntilReady(thunks_execute_event); + status = thunks_execute_event.IsError() + ? thunks_execute_event.GetError() + : absl::OkStatus(); } else { status = collective_params.status(); } From 5c4dd0d5db11e916d24f8e9bb74085a9afa512c7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 23 Jul 2024 12:08:10 -0700 Subject: [PATCH 094/376] [XLA] Created new AlgebraicSimplifierOptions field for separately controlling rewrites previously all controlled by use_associative_reordering. use_associative_reordering is responsible for: - dot(dot(a, b), c) to dot(a, dot(b, c)) - dot(pad(a), b) to dot(a, slice(b)) and similar for broadcast/reduce and reverse/reverse raise_slice_and_reduce_through_dot is responsible for: - slice(dot(a, b)) to dot(slice(a), slice(b)) - reduce(dot(a, b)) to dot(reduce(a), reduce(b)) PiperOrigin-RevId: 655255133 --- xla/service/algebraic_simplifier.cc | 6 +++--- xla/service/algebraic_simplifier.h | 25 ++++++++++++++++++++++-- xla/service/algebraic_simplifier_test.cc | 10 ++++------ 3 files changed, 30 insertions(+), 11 deletions(-) diff --git a/xla/service/algebraic_simplifier.cc b/xla/service/algebraic_simplifier.cc index 2ef7b352a6afe7..473a5c948b2514 100644 --- a/xla/service/algebraic_simplifier.cc +++ b/xla/service/algebraic_simplifier.cc @@ -6475,7 +6475,7 @@ absl::Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { // Try to reorder slice of dot to the operand it comes from if (!options_.is_layout_sensitive() && - options_.use_associative_reordering() && + options_.raise_slice_and_reduce_through_dot() && slice->operand(0)->opcode() == HloOpcode::kDot) { // Unpack the dot operands HloDotInstruction* dot = Cast(slice->mutable_operand(0)); @@ -7436,7 +7436,7 @@ absl::Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { } // Try to reorder reduce(dot(A, B)) to dot(A, reduce(B)) - if (options_.use_associative_reordering()) { + if (options_.raise_slice_and_reduce_through_dot()) { HloInstruction *a, *b; // Reordering does not seem possible if the dot has batch dimensions. We // also need the reduction operation to be add, and the reduce to have an @@ -7530,7 +7530,7 @@ absl::Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { // Only reorder if it would result in sufficiently fewer flops if (old_flops / static_cast(new_flops) > - options_.associative_reordering_threshold()) { + options_.raise_slice_and_reduce_through_dot_threshold()) { VLOG(10) << "Reordering reduce into dot operands"; return ReplaceInstruction(reduce, new_dot); } diff --git a/xla/service/algebraic_simplifier.h b/xla/service/algebraic_simplifier.h index 3340d248ed43cf..7261872d40ceda 100644 --- a/xla/service/algebraic_simplifier.h +++ b/xla/service/algebraic_simplifier.h @@ -117,6 +117,25 @@ class AlgebraicSimplifierOptions { return use_convert_constant_folding_; } + void set_raise_slice_and_reduce_through_dot( + bool raise_slice_and_reduce_through_dot) { + raise_slice_and_reduce_through_dot_ = raise_slice_and_reduce_through_dot; + } + + bool raise_slice_and_reduce_through_dot() const { + return raise_slice_and_reduce_through_dot_; + } + + void set_raise_slice_and_reduce_through_dot_threshold( + double raise_slice_and_reduce_through_dot_threshold) { + raise_slice_and_reduce_through_dot_threshold_ = + raise_slice_and_reduce_through_dot_threshold; + } + + double raise_slice_and_reduce_through_dot_threshold() const { + return raise_slice_and_reduce_through_dot_threshold_; + } + // Enable dot simplification on platforms where it is profitable. void set_enable_dot_strength_reduction(bool enable_dot_strength_reduction) { enable_dot_strength_reduction_ = enable_dot_strength_reduction; @@ -300,10 +319,12 @@ class AlgebraicSimplifierOptions { int64_t very_small_gather_size_{4}; bool minmax_propagate_nan_{true}; bool enable_unconditional_reduce_of_concat_replacement_{true}; - bool use_associative_reordering_{false}; - bool use_convert_constant_folding_{false}; bool executing_on_cpu_{false}; + bool use_associative_reordering_{false}; double associative_reordering_threshold_{2.0}; + bool raise_slice_and_reduce_through_dot_{false}; + double raise_slice_and_reduce_through_dot_threshold_{2.0}; + bool use_convert_constant_folding_{false}; Metadata metadata_; }; diff --git a/xla/service/algebraic_simplifier_test.cc b/xla/service/algebraic_simplifier_test.cc index c3dadde7fe93a6..2880c77ca0a775 100644 --- a/xla/service/algebraic_simplifier_test.cc +++ b/xla/service/algebraic_simplifier_test.cc @@ -6594,8 +6594,7 @@ TEST_F(AlgebraicSimplifierTest, ReduceDotReorder) { ParseAndReturnVerifiedModule(hlo_string)); AlgebraicSimplifierOptions options; - options.set_use_associative_reordering(true); - options.set_associative_reordering_threshold(0); + options.set_raise_slice_and_reduce_through_dot(true); AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).value()); ASSERT_THAT( @@ -6621,8 +6620,7 @@ TEST_F(AlgebraicSimplifierTest, SliceDotReorder) { ParseAndReturnVerifiedModule(hlo_string)); AlgebraicSimplifierOptions options; - options.set_use_associative_reordering(true); - options.set_associative_reordering_threshold(0); + options.set_raise_slice_and_reduce_through_dot(true); AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).value()); ASSERT_THAT(module->entry_computation()->root_instruction(), @@ -6646,7 +6644,7 @@ TEST_F(AlgebraicSimplifierTest, SliceDotReorderWithStrides) { ParseAndReturnVerifiedModule(hlo_string)); AlgebraicSimplifierOptions options; - options.set_use_associative_reordering(true); + options.set_raise_slice_and_reduce_through_dot(true); EXPECT_TRUE(AlgebraicSimplifier(options).Run(module.get()).value()); ASSERT_THAT( module->entry_computation()->root_instruction(), @@ -11341,7 +11339,7 @@ TEST_F(AlgebraicSimplifierTest, SparseDotMoveSliceToOperands) { )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo)); AlgebraicSimplifierOptions options; - options.set_use_associative_reordering(true); + options.set_raise_slice_and_reduce_through_dot(true); ASSERT_TRUE(AlgebraicSimplifier(options).Run(module.get()).value()); HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, GmockMatch(SparseDotMatcher(m::Slice(m::Parameter(0)), From b6d5a87950fd6f0ff039c495edfc1e6e359d8018 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Tue, 23 Jul 2024 12:49:55 -0700 Subject: [PATCH 095/376] internal copybara refactor change PiperOrigin-RevId: 655269008 --- xla/stream_executor/platform/default/initialize.h | 2 +- xla/stream_executor/platform/port.h | 2 +- xla/tests/hlo_test_base.h | 2 +- xla/tsl/framework/contraction/BUILD | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/xla/stream_executor/platform/default/initialize.h b/xla/stream_executor/platform/default/initialize.h index cb951ed8b0611c..78b24977ac7ae3 100644 --- a/xla/stream_executor/platform/default/initialize.h +++ b/xla/stream_executor/platform/default/initialize.h @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// IWYU pragma: private, include "third_party/tensorflow/compiler/xla/stream_executor/platform/initialize.h" +// IWYU pragma: private, include "xla/stream_executor/platform/initialize.h" #ifndef XLA_STREAM_EXECUTOR_PLATFORM_DEFAULT_INITIALIZE_H_ #define XLA_STREAM_EXECUTOR_PLATFORM_DEFAULT_INITIALIZE_H_ diff --git a/xla/stream_executor/platform/port.h b/xla/stream_executor/platform/port.h index 6cd6654061501d..9561d7bf20cb15 100644 --- a/xla/stream_executor/platform/port.h +++ b/xla/stream_executor/platform/port.h @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// IWYU pragma: private, include "third_party/tensorflow/compiler/xla/stream_executor/stream_executor.h" +// IWYU pragma: private, include "xla/stream_executor/stream_executor.h" #ifndef XLA_STREAM_EXECUTOR_PLATFORM_PORT_H_ #define XLA_STREAM_EXECUTOR_PLATFORM_PORT_H_ diff --git a/xla/tests/hlo_test_base.h b/xla/tests/hlo_test_base.h index c9f88f237e9b00..9e90eac54cb576 100644 --- a/xla/tests/hlo_test_base.h +++ b/xla/tests/hlo_test_base.h @@ -65,7 +65,7 @@ namespace xla { // "gpu", // ], // deps = [ -// "//tensorflow/compiler/xla/tests:hlo_test_base", +// "//xla/tests:hlo_test_base", // ... // ], // ) diff --git a/xla/tsl/framework/contraction/BUILD b/xla/tsl/framework/contraction/BUILD index 354dfcf8716c34..4f5ae247e0b69c 100644 --- a/xla/tsl/framework/contraction/BUILD +++ b/xla/tsl/framework/contraction/BUILD @@ -61,7 +61,7 @@ config_setting( # to get the benefit of custom contraction kernel: # # #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) -# #include "third_party/tensorflow/compiler/xla/tsl/framework/contraction/eigen_contraction_kernel.h" +# #include "xla/tsl/framework/contraction/eigen_contraction_kernel.h" # #endif # # We define a two-level target because if we just add From e56c9c1c51cef21cda50eddc45f2a115836a9358 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Tue, 23 Jul 2024 13:01:14 -0700 Subject: [PATCH 096/376] Remove unused `aws_support` and `hdfs_support` configs PiperOrigin-RevId: 655272587 --- .bazelrc | 8 -------- third_party/tsl/.bazelrc | 8 -------- 2 files changed, 16 deletions(-) diff --git a/.bazelrc b/.bazelrc index da4305b626e001..b94693e05efab8 100644 --- a/.bazelrc +++ b/.bazelrc @@ -42,9 +42,7 @@ # rocm: Build with AMD GPU support (rocm) # mkl: Enable full mkl support. # tensorrt: Enable Tensorrt support. -# noaws: Disable AWS S3 storage support # nogcp: Disable GCS support. -# nohdfs: Disable hadoop hdfs support. # nonccl: Disable nccl support. # # @@ -117,10 +115,6 @@ build --config=short_logs # TODO(mihaimaruseac): Document this option or remove if no longer needed build --config=v2 -# Disable AWS/HDFS support by default -build --define=no_aws_support=true -build --define=no_hdfs_support=true - # TF now has `cc_shared_library` targets, so it needs the experimental flag # TODO(rostam): Remove when `cc_shared_library` is enabled by default build --experimental_cc_shared_library @@ -296,9 +290,7 @@ build:sycl --define=tensorflow_mkldnn_contraction_kernel=0 build:sycl --repo_env TF_NEED_SYCL=1 # Options to disable default on features -build:noaws --define=no_aws_support=true build:nogcp --define=no_gcp_support=true -build:nohdfs --define=no_hdfs_support=true build:nonccl --define=no_nccl_support=true # Modular TF build options diff --git a/third_party/tsl/.bazelrc b/third_party/tsl/.bazelrc index da4305b626e001..b94693e05efab8 100644 --- a/third_party/tsl/.bazelrc +++ b/third_party/tsl/.bazelrc @@ -42,9 +42,7 @@ # rocm: Build with AMD GPU support (rocm) # mkl: Enable full mkl support. # tensorrt: Enable Tensorrt support. -# noaws: Disable AWS S3 storage support # nogcp: Disable GCS support. -# nohdfs: Disable hadoop hdfs support. # nonccl: Disable nccl support. # # @@ -117,10 +115,6 @@ build --config=short_logs # TODO(mihaimaruseac): Document this option or remove if no longer needed build --config=v2 -# Disable AWS/HDFS support by default -build --define=no_aws_support=true -build --define=no_hdfs_support=true - # TF now has `cc_shared_library` targets, so it needs the experimental flag # TODO(rostam): Remove when `cc_shared_library` is enabled by default build --experimental_cc_shared_library @@ -296,9 +290,7 @@ build:sycl --define=tensorflow_mkldnn_contraction_kernel=0 build:sycl --repo_env TF_NEED_SYCL=1 # Options to disable default on features -build:noaws --define=no_aws_support=true build:nogcp --define=no_gcp_support=true -build:nohdfs --define=no_hdfs_support=true build:nonccl --define=no_nccl_support=true # Modular TF build options From d17181b49de71b0fb0ff6236745d43d630c39401 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Tue, 23 Jul 2024 13:58:24 -0700 Subject: [PATCH 097/376] Move delay kernel handling for GpuTimer into cuda_executor.cc, as it's only supported for CUDA, not ROCm. PiperOrigin-RevId: 655293982 --- xla/stream_executor/cuda/BUILD | 33 ++++++++++++- xla/stream_executor/cuda/cuda_executor.cc | 49 +++++++++++++++++-- .../delay_kernel.h} | 10 ++-- .../delay_kernel_cuda.cu.cc} | 16 +----- xla/stream_executor/gpu/BUILD | 44 +---------------- xla/stream_executor/gpu/gpu_executor.h | 3 ++ xla/stream_executor/gpu/gpu_stream.h | 1 - xla/stream_executor/gpu/gpu_timer.cc | 39 +-------------- xla/stream_executor/gpu/gpu_timer.h | 11 ++--- .../gpu/gpu_timer_kernel_rocm.cc | 30 ------------ xla/stream_executor/rocm/rocm_executor.cc | 13 +++-- 11 files changed, 97 insertions(+), 152 deletions(-) rename xla/stream_executor/{gpu/gpu_timer_kernel.h => cuda/delay_kernel.h} (74%) rename xla/stream_executor/{gpu/gpu_timer_kernel_cuda.cu.cc => cuda/delay_kernel_cuda.cu.cc} (84%) delete mode 100644 xla/stream_executor/gpu/gpu_timer_kernel_rocm.cc diff --git a/xla/stream_executor/cuda/BUILD b/xla/stream_executor/cuda/BUILD index 38475a360a04bc..a33da00c418d93 100644 --- a/xla/stream_executor/cuda/BUILD +++ b/xla/stream_executor/cuda/BUILD @@ -15,6 +15,10 @@ load( "//xla:xla.bzl", "xla_cc_test", ) +load( + "//xla/service/gpu:build_defs.bzl", + "gpu_kernel_library", +) load( "//xla/stream_executor:build_defs.bzl", "cuda_only_cc_library", @@ -391,6 +395,27 @@ cc_library( ], ) +gpu_kernel_library( + name = "delay_kernel_cuda", + srcs = [ + "delay_kernel.h", + "delay_kernel_cuda.cu.cc", + ], + tags = ["manual"], + visibility = internal_visibility([ + "//xla/stream_executor:__subpackages__", + ]), + deps = [ + "//xla/stream_executor", + "//xla/stream_executor:typed_kernel_factory", + "//xla/stream_executor/gpu:gpu_driver_header", + "//xla/stream_executor/gpu:gpu_executor_header", + "//xla/stream_executor/gpu:gpu_semaphore", + "//xla/stream_executor/gpu:gpu_stream", + "@com_google_absl//absl/status:statusor", + ], +) + cuda_only_cc_library( name = "cudnn_plugin", srcs = ["cuda_dnn.cc"], @@ -731,7 +756,10 @@ cuda_only_cc_library( cuda_only_cc_library( name = "cuda_executor", - srcs = ["cuda_executor.cc"], + srcs = [ + "cuda_executor.cc", + "delay_kernel.h", + ], deps = [ ":cuda_collectives", # buildcleaner: keep ":cuda_diagnostics", @@ -758,6 +786,7 @@ cuda_only_cc_library( "//xla/stream_executor/gpu:gpu_event_header", "//xla/stream_executor/gpu:gpu_kernel_header", "//xla/stream_executor/gpu:gpu_runtime_header", + "//xla/stream_executor/gpu:gpu_semaphore", "//xla/stream_executor/gpu:gpu_stream_header", "//xla/stream_executor/gpu:gpu_timer", "//xla/stream_executor/gpu:gpu_types_header", @@ -778,7 +807,7 @@ cuda_only_cc_library( "@tsl//tsl/platform:fingerprint", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:statusor", - ], + ] + if_cuda_is_configured([":delay_kernel_cuda"]), alwayslink = True, ) diff --git a/xla/stream_executor/cuda/cuda_executor.cc b/xla/stream_executor/cuda/cuda_executor.cc index d0fb598d9fe8ed..63df37d3c037d9 100644 --- a/xla/stream_executor/cuda/cuda_executor.cc +++ b/xla/stream_executor/cuda/cuda_executor.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -63,12 +64,14 @@ limitations under the License. #include "xla/stream_executor/cuda/cuda_driver.h" #include "xla/stream_executor/cuda/cuda_event.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" +#include "xla/stream_executor/cuda/delay_kernel.h" #include "xla/stream_executor/gpu/gpu_collectives.h" #include "xla/stream_executor/gpu/gpu_command_buffer.h" #include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_kernel.h" #include "xla/stream_executor/gpu/gpu_runtime.h" +#include "xla/stream_executor/gpu/gpu_semaphore.h" #include "xla/stream_executor/gpu/gpu_stream.h" #include "xla/stream_executor/gpu/gpu_timer.h" #include "xla/stream_executor/gpu/gpu_types.h" @@ -105,6 +108,18 @@ bool FLAGS_prefer_cubin_to_ptx = true; namespace stream_executor { namespace gpu { +namespace { + +bool ShouldLaunchDelayKernel() { + // Only launch the delay kernel if CUDA_LAUNCH_BLOCKING is not set to 1. + static bool value = [] { + const char* blocking = std::getenv("CUDA_LAUNCH_BLOCKING"); + return !blocking || std::string_view{blocking} != "1"; + }(); + return value; +} + +} // namespace static GpuEvent* AsGpuEvent(Event* event) { DCHECK(event != nullptr); return static_cast(event); @@ -148,6 +163,16 @@ absl::Status GpuExecutor::Init() { return absl::OkStatus(); } +absl::StatusOr GpuExecutor::DelayKernelIsSupported(GpuStream* stream) { + // Check the assumption that this device supports unified addressing, + // otherwise skip the delay kernel + TF_ASSIGN_OR_RETURN(int status, + GpuDriver::GetDeviceAttribute( + CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING, device_)); + + return static_cast(status); +} + absl::Status GpuExecutor::LoadModuleFromCuBin(const char* cubin, CUmodule* module) { uint64_t module_refcount; @@ -268,12 +293,28 @@ absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, absl::StatusOr> GpuExecutor::CreateEventBasedTimer(GpuStream* stream, bool use_delay_kernel) { - // TODO(b/301020144) Move this all to the appropriate Executor class. + GpuSemaphore semaphore{}; + if (!use_delay_kernel) { + LOG(WARNING) + << "Skipping the delay kernel, measurement accuracy will be reduced"; + } + + if (use_delay_kernel && ShouldLaunchDelayKernel()) { + TF_ASSIGN_OR_RETURN(bool is_supported, DelayKernelIsSupported(stream)); + + if (is_supported) { + TF_ASSIGN_OR_RETURN(semaphore, LaunchDelayKernel(stream)); + } else { + LOG(WARNING) << "Skipping the delay kernel as it's not supported, " + "measurement accuracy will be reduced."; + } + } TF_ASSIGN_OR_RETURN(auto start_event, CreateGpuEvent(/*allow_timing=*/true)); TF_ASSIGN_OR_RETURN(auto stop_event, CreateGpuEvent(/*allow_timing=*/true)); - return GpuTimer::CreateEventBasedTimer( - stream, gpu_context(), use_delay_kernel, std::move(start_event), - std::move(stop_event)); + TF_RETURN_IF_ERROR(start_event->Record(stream->gpu_stream())); + return std::make_unique(gpu_context(), std::move(start_event), + std::move(stop_event), stream, + std::move(semaphore)); } bool GpuExecutor::UnloadGpuBinary(const void* gpu_binary) { diff --git a/xla/stream_executor/gpu/gpu_timer_kernel.h b/xla/stream_executor/cuda/delay_kernel.h similarity index 74% rename from xla/stream_executor/gpu/gpu_timer_kernel.h rename to xla/stream_executor/cuda/delay_kernel.h index cb0c5d1a3ccff3..09aad2f6e85a67 100644 --- a/xla/stream_executor/gpu/gpu_timer_kernel.h +++ b/xla/stream_executor/cuda/delay_kernel.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_STREAM_EXECUTOR_GPU_GPU_TIMER_KERNEL_H_ -#define XLA_STREAM_EXECUTOR_GPU_GPU_TIMER_KERNEL_H_ +#ifndef XLA_STREAM_EXECUTOR_CUDA_DELAY_KERNEL_H_ +#define XLA_STREAM_EXECUTOR_CUDA_DELAY_KERNEL_H_ #include "absl/status/statusor.h" #include "xla/stream_executor/gpu/gpu_semaphore.h" @@ -22,10 +22,6 @@ limitations under the License. #include "xla/stream_executor/stream.h" namespace stream_executor::gpu { -// Returns true if the current backend and GPU supports the delay kernel for -// time measurement. It might return an error if checking for the support at -// runtime failed. -absl::StatusOr DelayKernelIsSupported(GpuStream* stream); // Launches the delay kernel on the given stream. The caller is responsible for // keeping the returned semaphore alive until the kernel finished executing. @@ -33,4 +29,4 @@ absl::StatusOr DelayKernelIsSupported(GpuStream* stream); absl::StatusOr LaunchDelayKernel(Stream* stream); } // namespace stream_executor::gpu -#endif // XLA_STREAM_EXECUTOR_GPU_GPU_TIMER_KERNEL_H_ +#endif // XLA_STREAM_EXECUTOR_CUDA_DELAY_KERNEL_H_ diff --git a/xla/stream_executor/gpu/gpu_timer_kernel_cuda.cu.cc b/xla/stream_executor/cuda/delay_kernel_cuda.cu.cc similarity index 84% rename from xla/stream_executor/gpu/gpu_timer_kernel_cuda.cu.cc rename to xla/stream_executor/cuda/delay_kernel_cuda.cu.cc index 8a320c583547c6..29035d049f2c31 100644 --- a/xla/stream_executor/gpu/gpu_timer_kernel_cuda.cu.cc +++ b/xla/stream_executor/cuda/delay_kernel_cuda.cu.cc @@ -15,10 +15,10 @@ limitations under the License. #include +#include "xla/stream_executor/cuda/delay_kernel.h" #include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_semaphore.h" -#include "xla/stream_executor/gpu/gpu_timer_kernel.h" #include "xla/stream_executor/typed_kernel_factory.h" namespace stream_executor::gpu { @@ -75,20 +75,6 @@ absl::StatusOr LaunchDelayKernel(Stream* stream) { return semaphore; } -absl::StatusOr DelayKernelIsSupported(GpuStream* stream) { - // Check the assumption that this device supports unified addressing, - // otherwise skip the delay kernel - TF_ASSIGN_OR_RETURN(int status, GpuDriver::GetDeviceAttribute( - CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING, - stream->parent()->device())); - if (!status) { - LOG(WARNING) << "Skipping the delay kernel because the device does not " - "support unified addressing"; - } - - return static_cast(status); -} - namespace delay_kernel { void* kernel() { return reinterpret_cast(DelayKernel); } } // namespace delay_kernel diff --git a/xla/stream_executor/gpu/BUILD b/xla/stream_executor/gpu/BUILD index c7ff3caf62904a..c65157f3117123 100644 --- a/xla/stream_executor/gpu/BUILD +++ b/xla/stream_executor/gpu/BUILD @@ -361,48 +361,10 @@ gpu_only_cc_library( ], ) -gpu_kernel_library( - name = "gpu_timer_kernel_cuda", - srcs = [ - "gpu_timer_kernel.h", - "gpu_timer_kernel_cuda.cu.cc", - ], - tags = ["manual"], - visibility = internal_visibility([ - "//xla/stream_executor:__subpackages__", - ]), - deps = [ - ":gpu_driver_header", - ":gpu_executor_header", - ":gpu_semaphore", - ":gpu_stream", - "//xla/stream_executor", - "//xla/stream_executor:typed_kernel_factory", - "@com_google_absl//absl/status:statusor", - ], -) - -cc_library( - name = "gpu_timer_kernel_rocm", - srcs = [ - "gpu_timer_kernel.h", - "gpu_timer_kernel_rocm.cc", - ], - tags = ["manual"], - deps = [ - ":gpu_semaphore", - ":gpu_stream", - "//xla/stream_executor", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - ], -) - gpu_only_cc_library( name = "gpu_timer", srcs = [ "gpu_timer.cc", - "gpu_timer_kernel.h", ], hdrs = [ "gpu_timer.h", @@ -426,11 +388,7 @@ gpu_only_cc_library( "@com_google_absl//absl/utility", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:statusor", - ] + if_cuda_is_configured([ - ":gpu_timer_kernel_cuda", - ]) + if_rocm_is_configured([ - ":gpu_timer_kernel_rocm", - ]), + ], ) gpu_only_cc_library( diff --git a/xla/stream_executor/gpu/gpu_executor.h b/xla/stream_executor/gpu/gpu_executor.h index ad4a7b7b3103c9..c19fa1cceeba0c 100644 --- a/xla/stream_executor/gpu/gpu_executor.h +++ b/xla/stream_executor/gpu/gpu_executor.h @@ -335,6 +335,9 @@ class GpuExecutor : public StreamExecutorCommon { // Creates a GpuEvent for the given stream. absl::StatusOr> CreateGpuEvent(bool allow_timing); + // Returns true if a delay kernel is supported for the given stream. + absl::StatusOr DelayKernelIsSupported(GpuStream* stream); + // Guards the on-disk-module mapping. absl::Mutex disk_modules_mu_; diff --git a/xla/stream_executor/gpu/gpu_stream.h b/xla/stream_executor/gpu/gpu_stream.h index 0f18c4e10b98cd..18b77fb888481b 100644 --- a/xla/stream_executor/gpu/gpu_stream.h +++ b/xla/stream_executor/gpu/gpu_stream.h @@ -97,7 +97,6 @@ class GpuStream : public StreamCommon { return gpu_stream_; } - GpuExecutor* parent() const { return parent_; } absl::Status WaitFor(Stream* other) override; absl::Status WaitFor(Event* event) override; absl::Status RecordEvent(Event* event) override; diff --git a/xla/stream_executor/gpu/gpu_timer.cc b/xla/stream_executor/gpu/gpu_timer.cc index 9bc8f73ee38dca..1a2475810c9626 100644 --- a/xla/stream_executor/gpu/gpu_timer.cc +++ b/xla/stream_executor/gpu/gpu_timer.cc @@ -35,7 +35,6 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_semaphore.h" #include "xla/stream_executor/gpu/gpu_stream.h" -#include "xla/stream_executor/gpu/gpu_timer_kernel.h" #include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/stream.h" #include "tsl/platform/errors.h" @@ -56,45 +55,9 @@ absl::Duration RandomDuration() { return absl::Microseconds(distribution(rng)); } -bool ShouldLaunchDelayKernel() { - // Only launch the delay kernel if CUDA_LAUNCH_BLOCKING is not set to 1. - static bool value = [] { - const char* blocking = std::getenv("CUDA_LAUNCH_BLOCKING"); - return !blocking || std::string_view{blocking} != "1"; - }(); - return value; -} - } // namespace -absl::StatusOr> -GpuTimer::CreateEventBasedTimer(GpuStream* stream, GpuContext* context, - bool use_delay_kernel, - std::unique_ptr start_event, - std::unique_ptr stop_event) { - GpuSemaphore semaphore{}; - if (!use_delay_kernel) { - LOG(WARNING) - << "Skipping the delay kernel, measurement accuracy will be reduced"; - } - - if (use_delay_kernel && ShouldLaunchDelayKernel()) { - TF_ASSIGN_OR_RETURN(bool is_supported, DelayKernelIsSupported(stream)); - - if (is_supported) { - TF_ASSIGN_OR_RETURN(semaphore, LaunchDelayKernel(stream)); - } - } - - // The start event goes after the delay kernel in the stream - TF_RETURN_IF_ERROR(start_event->Record(stream->gpu_stream())); - - return std::make_unique(context, std::move(start_event), - std::move(stop_event), stream, - std::move(semaphore)); -} - -/*static*/ void GpuTimer::ReturnRandomDurationsForTesting() { +void GpuTimer::ReturnRandomDurationsForTesting() { return_random_durations = true; } diff --git a/xla/stream_executor/gpu/gpu_timer.h b/xla/stream_executor/gpu/gpu_timer.h index f8ce6587e39afd..be0f9a54a2af98 100644 --- a/xla/stream_executor/gpu/gpu_timer.h +++ b/xla/stream_executor/gpu/gpu_timer.h @@ -48,14 +48,9 @@ class GpuStream; // to be measured more accurately. class GpuTimer : public EventBasedTimer { public: - static absl::StatusOr> CreateEventBasedTimer( - GpuStream* stream, GpuContext* context, bool use_delay_kernel, - std::unique_ptr start_event, - std::unique_ptr stop_event); - - explicit GpuTimer(GpuContext* context, std::unique_ptr start_event, - std::unique_ptr stop_event, GpuStream* stream, - GpuSemaphore semaphore = {}) + GpuTimer(GpuContext* context, std::unique_ptr start_event, + std::unique_ptr stop_event, GpuStream* stream, + GpuSemaphore semaphore = {}) : context_(context), start_event_(std::move(start_event)), stop_event_(std::move(stop_event)), diff --git a/xla/stream_executor/gpu/gpu_timer_kernel_rocm.cc b/xla/stream_executor/gpu/gpu_timer_kernel_rocm.cc deleted file mode 100644 index 2ee3680fa3f757..00000000000000 --- a/xla/stream_executor/gpu/gpu_timer_kernel_rocm.cc +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "xla/stream_executor/gpu/gpu_semaphore.h" -#include "xla/stream_executor/gpu/gpu_stream.h" -#include "xla/stream_executor/stream.h" - -namespace stream_executor::gpu { - -absl::StatusOr DelayKernelIsSupported(GpuStream*) { return false; } - -absl::StatusOr LaunchDelayKernel(Stream* stream) { - return absl::UnimplementedError("Not implemented"); -} - -} // namespace stream_executor::gpu diff --git a/xla/stream_executor/rocm/rocm_executor.cc b/xla/stream_executor/rocm/rocm_executor.cc index 6e46312828140c..3ebe531e1ba556 100644 --- a/xla/stream_executor/rocm/rocm_executor.cc +++ b/xla/stream_executor/rocm/rocm_executor.cc @@ -186,12 +186,12 @@ GpuExecutor::CreateOrShareConstant(Stream* stream, absl::StatusOr> GpuExecutor::CreateEventBasedTimer(GpuStream* stream, bool use_delay_kernel) { - // TODO(b/301020144) Move this all to the appropriate Executor class. TF_ASSIGN_OR_RETURN(auto start_event, CreateGpuEvent(/*allow_timing=*/true)); TF_ASSIGN_OR_RETURN(auto stop_event, CreateGpuEvent(/*allow_timing=*/true)); - return GpuTimer::CreateEventBasedTimer( - stream, gpu_context(), use_delay_kernel, std::move(start_event), - std::move(stop_event)); + TF_RETURN_IF_ERROR(start_event->Record(stream->gpu_stream())); + return std::make_unique(gpu_context(), std::move(start_event), + std::move(stop_event), stream, + std::move(semaphore)); } bool GpuExecutor::UnloadGpuBinary(const void* gpu_binary) { @@ -251,6 +251,11 @@ absl::Status GpuExecutor::Init() { return GpuDriver::GetGpuISAVersion(&version_, device_); } +absl::StatusOr GpuExecutor::DelayKernelIsSupported(GpuStream* stream) { + // Delay kernel is not supported on ROCm. + return false; +} + absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, Kernel* kernel) { GpuKernel* rocm_kernel = AsGpuKernel(kernel); From 3b1d733864313a2d7058ac0ad33678cd2893c376 Mon Sep 17 00:00:00 2001 From: Victor Stone Date: Tue, 23 Jul 2024 15:12:54 -0700 Subject: [PATCH 098/376] Refactor HostOffloadLegalize by extracting copy movement and layout update into their own functions. Also, improve comments and debugging logging. PiperOrigin-RevId: 655319917 --- xla/service/host_offload_legalize.cc | 388 +++++++++++++++------------ 1 file changed, 209 insertions(+), 179 deletions(-) diff --git a/xla/service/host_offload_legalize.cc b/xla/service/host_offload_legalize.cc index 7a2d673c0c2387..8887e3d3cd8a51 100644 --- a/xla/service/host_offload_legalize.cc +++ b/xla/service/host_offload_legalize.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -27,11 +28,14 @@ limitations under the License. #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/layout.h" #include "xla/service/call_graph.h" #include "xla/service/hlo_value.h" #include "xla/service/host_memory_offload_annotations.h" @@ -45,6 +49,11 @@ namespace xla { namespace { +bool IsEntryComputationParameter(HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kParameter && + instruction->parent()->IsEntryComputation(); +} + constexpr std::array kUsersOpcodes = {HloOpcode::kSlice, HloOpcode::kDynamicSlice}; @@ -342,37 +351,207 @@ absl::StatusOr> WalkDownMemoryOffload( return results; } +void UpdateInstructionLayout(const InstructionAndIndex& instruction_and_index, + const Layout& new_layout) { + HloInstruction* instruction = instruction_and_index.instruction; + const int index = instruction_and_index.index; + VLOG(2) << " Updating " << instruction->name() << "'s layout " + << instruction->shape().ToString(true) << " at index " << index + << " to " << new_layout.ToString(); + // Update shape. Tuple shape vs array shape. + if (index != -1) { + *instruction->mutable_shape() + ->mutable_tuple_shapes(index) + ->mutable_layout() = new_layout; + } else { + VLOG(5) << " Instruction: " << instruction->ToString(); + VLOG(5) << " New layout: " << new_layout.ToString(); + *instruction->mutable_shape()->mutable_layout() = new_layout; + } + VLOG(3) << " Shape is now: " << instruction->shape().ToString(true); + + if (instruction->opcode() == HloOpcode::kWhile) { + // Fix up while body's root instruction shape and condition's + // parameter shape for while loops. + *instruction->while_body() + ->root_instruction() + ->mutable_shape() + ->mutable_tuple_shapes(index) + ->mutable_layout() = new_layout; + *instruction->while_condition() + ->parameter_instruction(0) + ->mutable_shape() + ->mutable_tuple_shapes(index) + ->mutable_layout() = new_layout; + } +} + +absl::Status MoveCopy( + const InstructionAndIndex& copy_to_move_instruction_and_index, + const CallGraph* call_graph, + absl::flat_hash_set& processed_annotations, + absl::flat_hash_set& to_remove) { + HloInstruction* copy_to_move = copy_to_move_instruction_and_index.instruction; + VLOG(5) << "Moving copy: " << copy_to_move->ToString(); + std::vector stack = {copy_to_move_instruction_and_index}; + while (!stack.empty()) { + InstructionAndIndex current_instruction_and_index = stack.back(); + stack.pop_back(); + VLOG(5) << "Current value before down: " + << current_instruction_and_index.instruction->ToString() << " " + << current_instruction_and_index.index; + absl::StatusOr> current_value_down = + WalkDownMemoryOffload(current_instruction_and_index, *call_graph); + if (!current_value_down.ok()) { + VLOG(5) << "Current value down failed: " << current_value_down.status(); + break; + } + for (InstructionAndIndex& instruction_and_index : + current_value_down.value()) { + HloInstruction* instruction = instruction_and_index.instruction; + const int index = instruction_and_index.index; + UpdateInstructionLayout(instruction_and_index, + copy_to_move->operand(0)->shape().layout()); + if (instruction->opcode() == HloOpcode::kParameter) { + std::vector callers = + call_graph->GetComputationCallers(instruction->parent()); + if (callers.size() != 1) { + return absl::InvalidArgumentError( + "Expected to be called only by one caller"); + } + HloInstruction* caller = callers[0]; + UpdateInstructionLayout(InstructionAndIndex(caller, index), + copy_to_move->operand(0)->shape().layout()); + } + } + for (InstructionAndIndex& instruction_and_index : + current_value_down.value()) { + HloInstruction* instruction = instruction_and_index.instruction; + VLOG(5) << "Current value last down: " << instruction->ToString(); + CHECK_NE(instruction->opcode(), HloOpcode::kCopy) + << "Copies should be processed in order"; + if (absl::c_linear_search(kUsersOpcodes, instruction->opcode()) || + instruction->IsCustomCall( + host_memory_offload_annotations::kMoveToDeviceCustomCallTarget)) { + HloInstruction* annotation = + FindToDeviceAnnotationToUpdate(instruction); + CHECK_NE(annotation, nullptr) + << "We already verified we could find an annotation here. " + "Something went wrong."; + HloInstruction* new_annotation = nullptr; + if (instruction->opcode() == HloOpcode::kCustomCall) { + new_annotation = annotation; + } else { + new_annotation = + instruction->AddInstruction(annotation->CloneWithNewOperands( + instruction->shape(), {instruction})); + } + UpdateInstructionLayout(InstructionAndIndex(new_annotation, -1), + copy_to_move->operand(0)->shape().layout()); + Shape new_copy_shape = new_annotation->shape(); + *new_copy_shape.mutable_layout() = copy_to_move->shape().layout(); + HloInstruction* new_copy = + instruction->AddInstruction(copy_to_move->CloneWithNewOperands( + new_copy_shape, {new_annotation})); + std::vector users = instruction->users(); + for (HloInstruction* use : users) { + if (use == new_copy || use == new_annotation) { + continue; + } + TF_RETURN_IF_ERROR( + instruction->ReplaceUseWithDifferentShape(use, new_copy)); + } + // Move the copy here. + if (new_annotation != annotation) { + TF_RETURN_IF_ERROR(annotation->ReplaceAllUsesWithDifferentShape( + annotation->mutable_operand(0))); + to_remove.insert(annotation); + } + continue; + } + // Move the annotation first just before dynamic-update-slice to avoid + // shape changes. + if (instruction->opcode() == HloOpcode::kDynamicUpdateSlice) { + HloInstruction* annotation = + FindToHostAnnotationToUpdate(instruction->mutable_operand(1)); + if (annotation == nullptr) { + return absl::InternalError("Annotation not found."); + } + CHECK(annotation->opcode() == HloOpcode::kCustomCall); + HloInstruction* new_annotation = + instruction->AddInstruction(annotation->CloneWithNewOperands( + instruction->operand(1)->shape(), + {instruction->mutable_operand(1)})); + TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(1, new_annotation)); + TF_RETURN_IF_ERROR( + annotation->ReplaceAllUsesWith(annotation->mutable_operand(0))); + processed_annotations.insert(annotation); + processed_annotations.insert(new_annotation); + to_remove.insert(annotation); + + // Need to make DUS and its update slice's layout consistent by adding + // a copy on the operand side, which is on device. + if (instruction->shape().layout().minor_to_major() != + instruction->operand(1)->shape().layout().minor_to_major()) { + HloInstruction* update_slice = instruction->mutable_operand(1); + CHECK(update_slice->IsCustomCall( + host_memory_offload_annotations::kMoveToHostCustomCallTarget)); + *update_slice->mutable_shape()->mutable_layout() = + instruction->shape().layout(); + HloInstruction* new_copy = + update_slice->AddInstruction(HloInstruction::CreateUnary( + update_slice->shape(), HloOpcode::kCopy, + update_slice->mutable_operand(0))); + TF_RETURN_IF_ERROR(update_slice->ReplaceOperandWith(0, new_copy)); + } + } + stack.push_back(instruction_and_index); + } + } + VLOG(5) << "MOVED: " << copy_to_move->ToString(); + TF_RETURN_IF_ERROR(copy_to_move->ReplaceAllUsesWithDifferentShape( + copy_to_move->mutable_operand(0))); + TF_RETURN_IF_ERROR(copy_to_move->parent()->RemoveInstruction(copy_to_move)); + return absl::OkStatus(); +} + absl::StatusOr ProcessAnnotationForCopyMovement( HloInstruction* instruction, const CallGraph* call_graph, absl::flat_hash_set& processed_annotations, - std::vector& to_remove) { - auto is_entry_computation_parameter = [](HloInstruction* instruction) { - return instruction->opcode() == HloOpcode::kParameter && - instruction->parent()->IsEntryComputation(); - }; - + absl::flat_hash_set& to_remove) { + VLOG(2) << "Walking down graph starting at instruction " + << instruction->name(); if (instruction->IsRoot()) { return false; } if (instruction->user_count() == 0) { return false; } + // Look for a DynamicUpdateSlice following this instruction. HloInstruction* starting_instr = FindDUSFromAnnotation(instruction->users().at(0)); // If it's the pure copy case reset instruction. if (starting_instr->opcode() != HloOpcode::kDynamicUpdateSlice) { starting_instr = instruction; } - VLOG(3) << "Dus or Annotation: " << starting_instr->ToString(); + if (!(starting_instr->IsCustomCall( + host_memory_offload_annotations::kMoveToHostCustomCallTarget) || + IsEntryComputationParameter(starting_instr) || + starting_instr->opcode() == HloOpcode::kDynamicUpdateSlice)) { + return absl::InternalError( + "Starting instruction must be a move-to-host annotation, entry " + "computation parameter, or dynamic-update-slice."); + } + VLOG(2) << "Effective starting instruction: " << starting_instr->name(); + InstructionAndIndex current_value(starting_instr, -1); - // Found a copy that would block offloading. Walk up to find all annotations - // to update (required in case there are multiple insertions in the buffer). + // Walk up to find all annotations to update (required in case there are + // multiple insertions in the buffer). processed_annotations.insert(current_value.instruction); - if (!current_value.instruction->IsCustomCall( - host_memory_offload_annotations::kMoveToHostCustomCallTarget) && - !is_entry_computation_parameter(current_value.instruction)) { - CHECK_EQ(current_value.instruction->opcode(), - HloOpcode::kDynamicUpdateSlice); + + if (current_value.instruction->opcode() == HloOpcode::kDynamicUpdateSlice) { + // Walk up the graph and find the broadcast which this dynamic-update-slice + // is updating. while (true) { VLOG(10) << "Current value before: " << current_value.instruction->ToString(); @@ -402,9 +581,9 @@ absl::StatusOr ProcessAnnotationForCopyMovement( } } } - // Do a final walkdown from the top to collect all the instructions that need - // their shape updated. + // Do a final walkdown from the top to find all the copies which we need to + // move. std::vector copies_to_move; std::vector stack = {current_value}; while (!stack.empty()) { @@ -475,173 +654,19 @@ absl::StatusOr ProcessAnnotationForCopyMovement( VLOG(5) << "Current value last down: " << stack.back().instruction->ToString(); if (instruction_and_index.instruction->opcode() == HloOpcode::kCopy) { + VLOG(1) << absl::StreamFormat( + " Found a copy to move: \"%s\"", + instruction_and_index.instruction->name()); copies_to_move.push_back(instruction_and_index); } } } - auto update_shape_layout = - [&](const InstructionAndIndex& instruction_and_index, - HloInstruction* copy_to_move) { - HloInstruction* instruction = instruction_and_index.instruction; - const int index = instruction_and_index.index; - VLOG(5) << "Update shape layout: " << instruction->ToString() << " " - << index; - // Update shape. Tuple shape vs array shape. - if (index != -1) { - *instruction->mutable_shape() - ->mutable_tuple_shapes(index) - ->mutable_layout() = copy_to_move->operand(0)->shape().layout(); - } else { - *instruction->mutable_shape()->mutable_layout() = - copy_to_move->operand(0)->shape().layout(); - } - - if (instruction->opcode() == HloOpcode::kWhile) { - // Fix up while body's root instruction shape and condition's - // parameter shape for while loops. - Shape new_shape = copy_to_move->operand(0)->shape(); - *instruction->while_body() - ->root_instruction() - ->mutable_shape() - ->mutable_tuple_shapes(index) - ->mutable_layout() = new_shape.layout(); - *instruction->while_condition() - ->parameter_instruction(0) - ->mutable_shape() - ->mutable_tuple_shapes(index) - ->mutable_layout() = new_shape.layout(); - } - }; - // Process all copies one at a time from the last to the first and push it to // its specific user. - while (!copies_to_move.empty()) { - InstructionAndIndex& copy_to_move_instruction_and_index = - copies_to_move.back(); - HloInstruction* copy_to_move = - copy_to_move_instruction_and_index.instruction; - VLOG(5) << "Copy to move: " << copy_to_move->ToString(); - stack.clear(); - stack.push_back(copy_to_move_instruction_and_index); - while (!stack.empty()) { - VLOG(5) << "Current value before down: " - << stack.back().instruction->ToString() << " " - << stack.back().index; - absl::StatusOr> current_value_down = - WalkDownMemoryOffload(stack.back(), *call_graph); - if (!current_value_down.ok()) { - VLOG(5) << "Current value down failed: " << current_value_down.status(); - break; - } - for (InstructionAndIndex& instruction_and_index : - current_value_down.value()) { - HloInstruction* instruction = instruction_and_index.instruction; - const int index = instruction_and_index.index; - update_shape_layout(instruction_and_index, copy_to_move); - if (instruction->opcode() == HloOpcode::kParameter) { - std::vector callers = - call_graph->GetComputationCallers(instruction->parent()); - if (callers.size() != 1) { - return absl::InvalidArgumentError( - "Expected to be called only by one caller"); - } - HloInstruction* caller = callers[0]; - update_shape_layout(InstructionAndIndex(caller, index), copy_to_move); - } - } - stack.pop_back(); - for (InstructionAndIndex& instruction_and_index : - current_value_down.value()) { - HloInstruction* instruction = instruction_and_index.instruction; - VLOG(5) << "Current value last down: " << instruction->ToString(); - CHECK_NE(instruction->opcode(), HloOpcode::kCopy) - << "Copies should be processed in order"; - if (absl::c_linear_search(kUsersOpcodes, instruction->opcode()) || - instruction->IsCustomCall(host_memory_offload_annotations:: - kMoveToDeviceCustomCallTarget)) { - HloInstruction* annotation = - FindToDeviceAnnotationToUpdate(instruction); - CHECK_NE(annotation, nullptr) - << "We already verified we could find an annotation here. " - "Something went wrong."; - HloInstruction* new_annotation = nullptr; - if (instruction->opcode() == HloOpcode::kCustomCall) { - new_annotation = annotation; - } else { - new_annotation = - instruction->AddInstruction(annotation->CloneWithNewOperands( - instruction->shape(), {instruction})); - } - update_shape_layout(InstructionAndIndex(new_annotation, -1), - copy_to_move); - Shape new_copy_shape = new_annotation->shape(); - *new_copy_shape.mutable_layout() = copy_to_move->shape().layout(); - HloInstruction* new_copy = - instruction->AddInstruction(copy_to_move->CloneWithNewOperands( - new_copy_shape, {new_annotation})); - std::vector users = instruction->users(); - for (HloInstruction* use : users) { - if (use == new_copy || use == new_annotation) { - continue; - } - TF_RETURN_IF_ERROR( - instruction->ReplaceUseWithDifferentShape(use, new_copy)); - } - // Move the copy here. - if (new_annotation != annotation) { - TF_RETURN_IF_ERROR(annotation->ReplaceAllUsesWithDifferentShape( - annotation->mutable_operand(0))); - to_remove.push_back(annotation); - } - continue; - } - // Move the annotation first just before dynamic-update-slice to avoid - // shape changes. - if (instruction->opcode() == HloOpcode::kDynamicUpdateSlice) { - HloInstruction* annotation = - FindToHostAnnotationToUpdate(instruction->mutable_operand(1)); - if (annotation == nullptr) { - CHECK(false); - return false; - } - CHECK(annotation->opcode() == HloOpcode::kCustomCall); - HloInstruction* new_annotation = - instruction->AddInstruction(annotation->CloneWithNewOperands( - instruction->operand(1)->shape(), - {instruction->mutable_operand(1)})); - TF_RETURN_IF_ERROR( - instruction->ReplaceOperandWith(1, new_annotation)); - TF_RETURN_IF_ERROR( - annotation->ReplaceAllUsesWith(annotation->mutable_operand(0))); - processed_annotations.insert(annotation); - processed_annotations.insert(new_annotation); - to_remove.push_back(annotation); - - // Need to make DUS and its update slice's layout consistent by adding - // a copy on the operand side, which is on device. - if (instruction->shape().layout().minor_to_major() != - instruction->operand(1)->shape().layout().minor_to_major()) { - HloInstruction* update_slice = instruction->mutable_operand(1); - CHECK(update_slice->IsCustomCall( - host_memory_offload_annotations::kMoveToHostCustomCallTarget)); - *update_slice->mutable_shape()->mutable_layout() = - instruction->shape().layout(); - HloInstruction* new_copy = - update_slice->AddInstruction(HloInstruction::CreateUnary( - update_slice->shape(), HloOpcode::kCopy, - update_slice->mutable_operand(0))); - TF_RETURN_IF_ERROR(update_slice->ReplaceOperandWith(0, new_copy)); - } - } - stack.push_back(instruction_and_index); - } - } - VLOG(5) << "MOVED: " << copy_to_move->ToString(); - TF_RETURN_IF_ERROR(copy_to_move->ReplaceAllUsesWithDifferentShape( - copy_to_move->mutable_operand(0))); - TF_RETURN_IF_ERROR(copy_to_move->parent()->RemoveInstruction(copy_to_move)); - copies_to_move.pop_back(); + for (auto it = copies_to_move.rbegin(); it != copies_to_move.rend(); ++it) { + TF_RETURN_IF_ERROR( + MoveCopy(*it, call_graph, processed_annotations, to_remove)); } return true; } @@ -651,7 +676,7 @@ absl::StatusOr FixupInterveningCopies( const std::vector& starting_instructions, const CallGraph* call_graph) { absl::flat_hash_set processed_annotations; - std::vector annotations_to_remove; + absl::flat_hash_set annotations_to_remove; bool changed = false; for (HloInstruction* instruction : starting_instructions) { if (processed_annotations.contains(instruction)) { @@ -680,8 +705,7 @@ HostOffloadLegalize::FindStartingInstructionsOfHostMemoryOffload( for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kParameter && - instruction->parent()->IsEntryComputation()) { + if (IsEntryComputationParameter(instruction)) { Shape param_shape = module->entry_computation_layout() .parameter_layout(instruction->parameter_number()) @@ -725,6 +749,12 @@ absl::StatusOr HostOffloadLegalize::Run( // any are found, move them outside of the offload section. std::vector starting_instructions = FindStartingInstructionsOfHostMemoryOffload(module, execution_threads); + VLOG(1) << absl::StreamFormat( + "Starting instructions for host memory offload: [%s]", + absl::StrJoin(starting_instructions, ", ", + [](std::string* out, HloInstruction* instruction) { + return absl::StrAppend(out, instruction->name()); + })); std::unique_ptr call_graph = CallGraph::Build(module); TF_ASSIGN_OR_RETURN( bool changed_intervening_copies, From e9b0329f305ba71865d186cc197b4cea0a2a5998 Mon Sep 17 00:00:00 2001 From: Victor Stone Date: Tue, 23 Jul 2024 15:51:45 -0700 Subject: [PATCH 099/376] Improve log statements in HostOffloadLegalize. Combine two loops over the same indices. PiperOrigin-RevId: 655332090 --- xla/service/host_offload_legalize.cc | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/xla/service/host_offload_legalize.cc b/xla/service/host_offload_legalize.cc index 8887e3d3cd8a51..1199112adeb9f7 100644 --- a/xla/service/host_offload_legalize.cc +++ b/xla/service/host_offload_legalize.cc @@ -248,9 +248,8 @@ absl::StatusOr> WalkDownMemoryOffload( const InstructionAndIndex& current_value, const CallGraph& call_graph) { // TODO(maggioni): Verify that set of instructions supported in chain by // legalization is in sync with host_offloader. - VLOG(5) << "Current value in progress: " - << current_value.instruction->ToString() - << " idx: " << current_value.index; + VLOG(6) << "Getting users of: \"" << current_value.instruction->ToString() + << "\" at index " << current_value.index; std::vector results; auto add_gte_for_idx = [&results](HloInstruction* instr, int idx) -> absl::Status { @@ -397,18 +396,20 @@ absl::Status MoveCopy( while (!stack.empty()) { InstructionAndIndex current_instruction_and_index = stack.back(); stack.pop_back(); - VLOG(5) << "Current value before down: " + VLOG(5) << "Current top of stack: " << current_instruction_and_index.instruction->ToString() << " " << current_instruction_and_index.index; absl::StatusOr> current_value_down = WalkDownMemoryOffload(current_instruction_and_index, *call_graph); if (!current_value_down.ok()) { - VLOG(5) << "Current value down failed: " << current_value_down.status(); + VLOG(5) << "WalkDownMemoryOffload failed: " + << current_value_down.status(); break; } for (InstructionAndIndex& instruction_and_index : current_value_down.value()) { HloInstruction* instruction = instruction_and_index.instruction; + VLOG(5) << "Evaluating successor: " << instruction->ToString(); const int index = instruction_and_index.index; UpdateInstructionLayout(instruction_and_index, copy_to_move->operand(0)->shape().layout()); @@ -423,11 +424,7 @@ absl::Status MoveCopy( UpdateInstructionLayout(InstructionAndIndex(caller, index), copy_to_move->operand(0)->shape().layout()); } - } - for (InstructionAndIndex& instruction_and_index : - current_value_down.value()) { - HloInstruction* instruction = instruction_and_index.instruction; - VLOG(5) << "Current value last down: " << instruction->ToString(); + CHECK_NE(instruction->opcode(), HloOpcode::kCopy) << "Copies should be processed in order"; if (absl::c_linear_search(kUsersOpcodes, instruction->opcode()) || @@ -453,6 +450,8 @@ absl::Status MoveCopy( HloInstruction* new_copy = instruction->AddInstruction(copy_to_move->CloneWithNewOperands( new_copy_shape, {new_annotation})); + VLOG(2) << absl::StreamFormat("Inserting copy \"%s\" after \"%s\"", + new_copy->name(), instruction->name()); std::vector users = instruction->users(); for (HloInstruction* use : users) { if (use == new_copy || use == new_annotation) { @@ -508,7 +507,8 @@ absl::Status MoveCopy( stack.push_back(instruction_and_index); } } - VLOG(5) << "MOVED: " << copy_to_move->ToString(); + VLOG(2) << absl::StreamFormat("Removing copy \"%s\"", + copy_to_move->ToString()); TF_RETURN_IF_ERROR(copy_to_move->ReplaceAllUsesWithDifferentShape( copy_to_move->mutable_operand(0))); TF_RETURN_IF_ERROR(copy_to_move->parent()->RemoveInstruction(copy_to_move)); From cbe9cb011b9a834cd953e5a8a5544cfe0e91e815 Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Tue, 23 Jul 2024 16:24:11 -0700 Subject: [PATCH 100/376] [XLA:GPU] Move pipeline parallelism tests to separate test file. PiperOrigin-RevId: 655342571 --- xla/tests/BUILD | 42 ++++++ xla/tests/collective_ops_test.cc | 71 ---------- .../collective_pipeline_parallelism_test.cc | 131 ++++++++++++++++++ 3 files changed, 173 insertions(+), 71 deletions(-) create mode 100644 xla/tests/collective_pipeline_parallelism_test.cc diff --git a/xla/tests/BUILD b/xla/tests/BUILD index 701ed1fca2abb0..6463beff342d19 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -2248,6 +2248,48 @@ xla_test( ], ) +xla_test( + name = "collective_pipeline_parallelism_test", + srcs = ["collective_pipeline_parallelism_test.cc"], + args = ["--xla_force_host_platform_device_count=4"], + backend_tags = { + # This test is tagged "manual" because it requires multiple GPUs, and Forge only supports + # single-GPU tests. Guitar skips "manual" tests unless they're also tagged "guitar". + "gpu": [ + "guitar", + "manual", + "multi_gpu", + "no_oss", + "notap", + ], + "cpu": [ + "notsan", + ], + }, + backends = [ + "gpu", + "cpu", + ], + deps = [ + ":hlo_test_base", + ":literal_test_util", + ":test_macros_header", + ":test_utils", + ":verified_hlo_module", + ":xla_internal_test_main", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:executable", + "//xla/service:hlo_module_config", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@tsl//tsl/platform:statusor", + ], +) + xla_test( name = "collective_ops_e2e_test", srcs = ["collective_ops_e2e_test.cc"], diff --git a/xla/tests/collective_ops_test.cc b/xla/tests/collective_ops_test.cc index ab7e7a40898bf1..924a4b8c1775cc 100644 --- a/xla/tests/collective_ops_test.cc +++ b/xla/tests/collective_ops_test.cc @@ -733,77 +733,6 @@ XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Simple) { results[3])); } -XLA_TEST_F(CollectiveOpsTest, - CollectivePermute_CircularPipelinePreOptimization) { - const absl::string_view kModuleStr = R"( - HloModule test - - while_cond { - param = (u32[], f32[2,2], f32[2,2]) parameter(0) - iter = u32[] get-tuple-element(param), index=0 - max_iter = u32[] constant(3) - ROOT cmp = pred[] compare(iter, max_iter), direction=LT - } - - while_body { - param = (u32[], f32[2,2], f32[2,2]) parameter(0) - iter = u32[] get-tuple-element(param), index=0 - data = f32[2,2] get-tuple-element(param), index=1 - weights = f32[2,2] get-tuple-element(param), index=2 - matmul = f32[2,2] dot(weights, data), lhs_contracting_dims={1}, rhs_contracting_dims={0} - cp = f32[2,2] collective-permute(matmul), source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}} - iter_increment = u32[] constant(1) - next_iter = u32[] add(iter, iter_increment) - ROOT result = (u32[], f32[2,2], f32[2,2]) tuple(next_iter, cp, weights) - } - - ENTRY test_computation { - iter = u32[] constant(0) - data = f32[2,2] parameter(0) - weights = f32[2,2] parameter(1) - input = (u32[], f32[2,2], f32[2,2]) tuple(iter, data, weights) - while_res = (u32[], f32[2,2], f32[2,2]) while(input), condition=while_cond, body=while_body - ROOT data_out = f32[2,2] get-tuple-element(while_res), index=1 - } - )"; - const int64_t kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) - - HloModuleConfig config = - GetModuleConfigForTest(/*replica_count=*/kNumReplicas); - std::unique_ptr module; - TF_ASSERT_OK_AND_ASSIGN(module, - ParseAndReturnVerifiedModule(kModuleStr, config)); - - // input for replica i is - // {{i, i}, - // {i, i}} - std::vector replica_inputs; - for (float i = 1; i < kNumReplicas + 1; ++i) { - replica_inputs.push_back({LiteralUtil::CreateR2({{i, i}, {i, i}})}); - replica_inputs.push_back(LiteralUtil::CreateR2({{0, 0}, {0, 1}})); - } - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr executable, - test_runner_.CreateExecutable( - std::unique_ptr(std::move(module)), - /*run_hlo_passes=*/true)); - TF_ASSERT_OK_AND_ASSIGN( - std::vector results, - ExecuteReplicated( - /*executable_provider=*/[&](int64_t) { return executable.get(); }, - /*argument_count_provider=*/[](int64_t) { return 2; }, - /*argument_provider=*/ - [&](int64_t replica, int64_t index) -> const Literal* { - return &replica_inputs[replica * 2 + index]; - }, - kNumReplicas, /*run_hlo_passes=*/true, - /*device_assignment=*/nullptr)); - LiteralTestUtil::ExpectR2Equal({{0, 0}, {2, 2}}, results[0]); - LiteralTestUtil::ExpectR2Equal({{0, 0}, {3, 3}}, results[1]); - LiteralTestUtil::ExpectR2Equal({{0, 0}, {4, 4}}, results[2]); - LiteralTestUtil::ExpectR2Equal({{0, 0}, {1, 1}}, results[3]); -} - XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Degenerate) { const char* const kModuleStr = R"( HloModule test diff --git a/xla/tests/collective_pipeline_parallelism_test.cc b/xla/tests/collective_pipeline_parallelism_test.cc new file mode 100644 index 00000000000000..c5fe51401e8945 --- /dev/null +++ b/xla/tests/collective_pipeline_parallelism_test.cc @@ -0,0 +1,131 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include +#include "absl/log/log.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/service/executable.h" +#include "xla/service/hlo_module_config.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tests/literal_test_util.h" +#include "xla/tests/test_macros.h" +#include "xla/tests/verified_hlo_module.h" +#include "tsl/platform/statusor.h" + +// Tests cross-GPU operations. +// +// Several tests requires at least four GPUs. For instructions on running this +// within Google, see go/multi-gpu-unit-test. + +// TODO: Move this to hlo_test_base.h +#define SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(x) \ + if (num_devices_ < x) { \ + GTEST_SKIP() << "Test requires at least " << x << " devices"; \ + } + +namespace xla { +namespace { + +class CollectivePipelineParallelismTest : public HloTestBase { + public: + CollectivePipelineParallelismTest() : num_devices_(backend().device_count()) { + VLOG(1) << "Running with " << num_devices_ << " devices"; + } + + protected: + const int64_t num_devices_; +}; + +XLA_TEST_F(CollectivePipelineParallelismTest, + CollectivePermute_CircularPipelinePreOptimization) { + const absl::string_view kModuleStr = R"( + HloModule test + + while_cond { + param = (u32[], f32[2,2], f32[2,2]) parameter(0) + iter = u32[] get-tuple-element(param), index=0 + max_iter = u32[] constant(3) + ROOT cmp = pred[] compare(iter, max_iter), direction=LT + } + + while_body { + param = (u32[], f32[2,2], f32[2,2]) parameter(0) + iter = u32[] get-tuple-element(param), index=0 + data = f32[2,2] get-tuple-element(param), index=1 + weights = f32[2,2] get-tuple-element(param), index=2 + matmul = f32[2,2] dot(weights, data), lhs_contracting_dims={1}, rhs_contracting_dims={0} + cp = f32[2,2] collective-permute(matmul), source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}} + iter_increment = u32[] constant(1) + next_iter = u32[] add(iter, iter_increment) + ROOT result = (u32[], f32[2,2], f32[2,2]) tuple(next_iter, cp, weights) + } + + ENTRY test_computation { + iter = u32[] constant(0) + data = f32[2,2] parameter(0) + weights = f32[2,2] parameter(1) + input = (u32[], f32[2,2], f32[2,2]) tuple(iter, data, weights) + while_res = (u32[], f32[2,2], f32[2,2]) while(input), condition=while_cond, body=while_body + ROOT data_out = f32[2,2] get-tuple-element(while_res), index=1 + } + )"; + const int64_t kNumReplicas = 4; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + std::unique_ptr module; + TF_ASSERT_OK_AND_ASSIGN(module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + + // Input for replica i is + // {{i, i}, + // {i, i}}. + std::vector replica_inputs; + for (float i = 1; i < kNumReplicas + 1; ++i) { + replica_inputs.push_back({LiteralUtil::CreateR2({{i, i}, {i, i}})}); + replica_inputs.push_back(LiteralUtil::CreateR2({{0, 0}, {0, 1}})); + } + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr executable, + test_runner_.CreateExecutable( + std::unique_ptr(std::move(module)), + /*run_hlo_passes=*/true)); + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated( + /*executable_provider=*/[&](int64_t) { return executable.get(); }, + /*argument_count_provider=*/[](int64_t) { return 2; }, + /*argument_provider=*/ + [&](int64_t replica, int64_t index) -> const Literal* { + return &replica_inputs[replica * 2 + index]; + }, + kNumReplicas, /*run_hlo_passes=*/true, + /*device_assignment=*/nullptr)); + LiteralTestUtil::ExpectR2Equal({{0, 0}, {2, 2}}, results[0]); + LiteralTestUtil::ExpectR2Equal({{0, 0}, {3, 3}}, results[1]); + LiteralTestUtil::ExpectR2Equal({{0, 0}, {4, 4}}, results[2]); + LiteralTestUtil::ExpectR2Equal({{0, 0}, {1, 1}}, results[3]); +} + +} // namespace +} // namespace xla From b9bd3b17daa307a5ab8d2df3baa7816f7a6d3f7b Mon Sep 17 00:00:00 2001 From: Seher Ellis Date: Tue, 23 Jul 2024 16:25:49 -0700 Subject: [PATCH 101/376] [XLA:LHS] Improve logging for kShareable resource occupiers. PiperOrigin-RevId: 655343018 --- xla/service/latency_hiding_scheduler.cc | 54 ++++++++++++++----------- 1 file changed, 31 insertions(+), 23 deletions(-) diff --git a/xla/service/latency_hiding_scheduler.cc b/xla/service/latency_hiding_scheduler.cc index c7bdae2f620db4..a17bc63b6f8804 100644 --- a/xla/service/latency_hiding_scheduler.cc +++ b/xla/service/latency_hiding_scheduler.cc @@ -1292,10 +1292,10 @@ void DefaultSchedulerCore::LogInstruction(const HloInstruction* instr) const { void PrintOccupierList( std::vector>& occupiers) { - VLOG(2) << "Occupier list:"; for (int64_t i = 0; i < occupiers.size(); i++) { - VLOG(2) << "\tOccupier at index: " << i - << " with projected finish time: " << occupiers[i].second + VLOG(2) << "\tOccupier " << i << ": " + << occupiers[i].first->Target().GetInstr().name() + << ", projected finish time: " << occupiers[i].second << " original latency: " << occupiers[i].first->OriginalLatency() << " latency: " << occupiers[i].first->Latency(); } @@ -1346,9 +1346,6 @@ bool DefaultSchedulerCore::DeleteOccupierFromResource( for (; it != occupiers.end(); it++) { it->second -= accumulated_saved_time; } - if (VLOG_IS_ON(2)) { - PrintOccupierList(occupiers); - } return true; } @@ -1400,9 +1397,6 @@ bool DefaultSchedulerCore::AddOccupierToResource( for (; it != occupiers.end(); it++) { it->second += accumulated_delay; } - if (VLOG_IS_ON(2)) { - PrintOccupierList(occupiers); - } return true; } @@ -1465,13 +1459,13 @@ absl::StatusOr DefaultSchedulerCore::ScheduleNode( auto occupiers = sched_state->shareable_resource_occupiers[resource]; for (auto [occupier_edge, edge_pft] : occupiers) { if (occupier_edge == &pred) { - VLOG(10) << "Ready time of scheduled node " << n->GetInstr().name() - << " before update with pft: " << edge_pft - << ", ready_time: " << schedule_time; + VLOG(2) << "Ready time of scheduled node " << n->GetInstr().name() + << " before update with pft: " << edge_pft + << ", ready_time: " << schedule_time; schedule_time = std::max(schedule_time, edge_pft); - VLOG(10) << "Ready time of scheduled node " << n->GetInstr().name() - << " after update with pft: " << edge_pft - << ", ready_time: " << schedule_time; + VLOG(2) << "Ready time of scheduled node " << n->GetInstr().name() + << " after update with pft: " << edge_pft + << ", ready_time: " << schedule_time; } } } @@ -1489,6 +1483,13 @@ absl::StatusOr DefaultSchedulerCore::ScheduleNode( CHECK(DeleteOccupierFromResource( schedule_time, edge, sched_state->shareable_resource_occupiers[resource])); + if (VLOG_IS_ON(2)) { + VLOG(2) << "Occupier list for " + << sched_state->async_tracker->GetResourceName(resource) + << ": "; + PrintOccupierList( + sched_state->shareable_resource_occupiers[resource]); + } } } // If a shareable resource is occupied by scheduling this node, insert the @@ -1502,6 +1503,13 @@ absl::StatusOr DefaultSchedulerCore::ScheduleNode( CHECK(AddOccupierToResource( current_time, inverse_edge, sched_state->shareable_resource_occupiers[resource])); + if (VLOG_IS_ON(2)) { + VLOG(2) << "Occupier list for " + << sched_state->async_tracker->GetResourceName(resource) + << ": "; + PrintOccupierList( + sched_state->shareable_resource_occupiers[resource]); + } } break; } @@ -1551,15 +1559,15 @@ absl::StatusOr DefaultSchedulerCore::ScheduleNode( auto occupiers = sched_state->shareable_resource_occupiers[resource]; for (auto [occupier_edge, edge_pft] : occupiers) { if (occupier_edge == &pred) { - VLOG(10) << "Ready time of predecessor " - << edge.Target().GetInstr().name() - << " before update with pft: " << edge_pft - << ", ready_time: " << ready_time; + VLOG(2) << "Ready time of predecessor " + << edge.Target().GetInstr().name() + << " before update with pft: " << edge_pft + << ", ready_time: " << ready_time; ready_time = std::max(ready_time, edge_pft); - VLOG(10) << "Ready time of predecessor " - << edge.Target().GetInstr().name() - << " after update with pft: " << edge_pft - << ", ready_time: " << ready_time; + VLOG(2) << "Ready time of predecessor " + << edge.Target().GetInstr().name() + << " after update with pft: " << edge_pft + << ", ready_time: " << ready_time; } } } From cff1aa7e372d79a4c9e161818ac93e98e4f791f3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 23 Jul 2024 17:47:49 -0700 Subject: [PATCH 102/376] [XLA:MSA] Implement an auxiliary function (SimulateAsyncCopyDone) to simulate the overhead of processing copy-done instruction. This CL implements a function to simulate the overhead of copy-done instructions. Specifically, there are two directions which share the bandwidth: default-read and default-write. Two directions will share the bandwidth equally. For example, when we process a default-read request, if there are also outstanding default-write process, we can only use half of the full bandwidth to process requests in each direction. PiperOrigin-RevId: 655365388 --- xla/service/memory_space_assignment/BUILD | 4 + .../memory_space_assignment.cc | 3 +- .../memory_space_assignment/simulator.cc | 124 +++++++++++ .../memory_space_assignment/simulator.h | 77 ++++++- .../memory_space_assignment/simulator_test.cc | 197 ++++++++++++++++-- 5 files changed, 387 insertions(+), 18 deletions(-) diff --git a/xla/service/memory_space_assignment/BUILD b/xla/service/memory_space_assignment/BUILD index ca31c55bf92b8f..b886ce775abf9e 100644 --- a/xla/service/memory_space_assignment/BUILD +++ b/xla/service/memory_space_assignment/BUILD @@ -332,7 +332,10 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_live_range", "//xla/service:hlo_value", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", ], ) @@ -349,6 +352,7 @@ xla_cc_test( "//xla/service:hlo_alias_analysis", "//xla/service:hlo_cost_analysis", "//xla/tests:hlo_test_base", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:string_view", diff --git a/xla/service/memory_space_assignment/memory_space_assignment.cc b/xla/service/memory_space_assignment/memory_space_assignment.cc index fc4b570f665bab..62d434983bd65a 100644 --- a/xla/service/memory_space_assignment/memory_space_assignment.cc +++ b/xla/service/memory_space_assignment/memory_space_assignment.cc @@ -354,7 +354,8 @@ MemorySpaceAssignment::RunMemorySpaceAssignment( TF_RETURN_IF_ERROR(FindAllocationSequence(hlo_live_range, alias_analysis)); if (options_.cost_analysis) { - RuntimeSimulator runtime_simulator(options_.cost_analysis); + RuntimeSimulator runtime_simulator(options_.cost_analysis, + options_.alternate_memory_space); float estimated_time = runtime_simulator.ComputeEstimatedElapsedTime( hlo_live_range, allocations_); VLOG(1) << "Estimated elapsed time (sec): " << estimated_time; diff --git a/xla/service/memory_space_assignment/simulator.cc b/xla/service/memory_space_assignment/simulator.cc index f061f12c5a4491..761dae2983366f 100644 --- a/xla/service/memory_space_assignment/simulator.cc +++ b/xla/service/memory_space_assignment/simulator.cc @@ -15,15 +15,22 @@ limitations under the License. #include "xla/service/memory_space_assignment/simulator.h" +#include #include +#include #include +#include #include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_live_range.h" +#include "xla/layout.h" #include "xla/service/hlo_value.h" #include "xla/service/memory_space_assignment/allocation.h" #include "xla/shape_util.h" @@ -86,5 +93,122 @@ float RuntimeSimulator::ComputeEstimatedElapsedTime( } return total_elapsed; } + +MemoryTransferDirection GetAsyncCopyDirection( + const HloInstruction* async_copy_start, int64_t alternate_memory_space) { + CHECK_EQ(async_copy_start->opcode(), HloOpcode::kCopyStart); + + int64_t operand_memory_space = + async_copy_start->operand(0)->shape().layout().memory_space(); + + // Get all users + std::optional output_memory_space; + for (const HloInstruction* user : async_copy_start->users()) { + if (user->opcode() == HloOpcode::kCopyDone) { + output_memory_space.emplace(user->shape().layout().memory_space()); + break; + } + } + if (!output_memory_space.has_value()) { + return MemoryTransferDirection::kUnsupported; + } + + if (operand_memory_space == xla::Layout::kDefaultMemorySpace && + output_memory_space == alternate_memory_space) { + return MemoryTransferDirection::kDefaultToAlternate; + } + if (operand_memory_space == alternate_memory_space && + output_memory_space == xla::Layout::kDefaultMemorySpace) { + return MemoryTransferDirection::kAlternateToDefault; + } + return MemoryTransferDirection::kUnsupported; +} + +const std::list& +RuntimeSimulator::GetOutstandingReadDefaultQueue() const { + return outstanding_read_default_queue_; +} + +const std::list& +RuntimeSimulator::GetOutstandingWriteDefaultQueue() const { + return outstanding_write_default_queue_; +} + +const HloInstruction* RuntimeSimulator::RemoveBytesFromQueueIfNotEmpty( + std::list& async_copy_queue, float processed_bytes) { + if (async_copy_queue.empty()) return nullptr; + CHECK_GE(async_copy_queue.front().remaining_bytes_to_transfer, + processed_bytes); + async_copy_queue.front().remaining_bytes_to_transfer -= processed_bytes; + if (async_copy_queue.front().remaining_bytes_to_transfer == 0.0) { + const HloInstruction* retired_instruction = + async_copy_queue.front().copy_start_inst; + async_copy_queue.pop_front(); + return retired_instruction; + } + return nullptr; +} + +float RuntimeSimulator::SimulateAsyncCopyDone( + const HloInstruction* copy_done_instruction) { + const HloInstruction* copy_start_instruction = + copy_done_instruction->operand(0); + MemoryTransferDirection direction = + GetAsyncCopyDirection(copy_start_instruction, alternate_memory_space_); + if (direction == MemoryTransferDirection::kUnsupported) { + // The memory access is not a default <-> alternate memory copy. + LOG(WARNING) << "Unsupported memory transfer direction for copy-done: " + << copy_done_instruction->ToString(); + return 0.0; + } + std::list& same_direction_queue = + direction == MemoryTransferDirection::kDefaultToAlternate + ? outstanding_read_default_queue_ + : outstanding_write_default_queue_; + std::list& opposite_direction_queue = + direction == MemoryTransferDirection::kDefaultToAlternate + ? outstanding_write_default_queue_ + : outstanding_read_default_queue_; + + if (absl::c_find_if( + same_direction_queue, [&](const OutstandingAsyncCopy& async_copy) { + return async_copy.copy_start_inst == copy_start_instruction; + }) == same_direction_queue.end()) { + // The copy has already finished; thus, the copy-done takes no time. + return 0.0; + } + + // Each iteration of the while loop simulates transferring a number of + // bytes from each queue that is equal to the smaller of the two elements + // at the front of each queue. If that causes us to finish a copy in the + // same_direction_queue, and that copy is the copy_done_instruction, we + // break the loop. + float elapsed_time = 0.0; + const HloInstruction* retired_instruction_in_same_direction_queue = nullptr; + // Loop until we process the copy start instruction that the copy-done + // instruction is waiting for. + do { + float bytes_to_process = + same_direction_queue.front().remaining_bytes_to_transfer; + float available_bandwidth = cost_analysis_->base_costs().BytesPerSecond(); + + if (!opposite_direction_queue.empty()) { + // Need to share the bandwidth with the opposite direction queue. + available_bandwidth *= 0.5; + bytes_to_process = std::min( + bytes_to_process, + opposite_direction_queue.front().remaining_bytes_to_transfer); + } + + elapsed_time += bytes_to_process / available_bandwidth; + + RemoveBytesFromQueueIfNotEmpty(opposite_direction_queue, bytes_to_process); + retired_instruction_in_same_direction_queue = + RemoveBytesFromQueueIfNotEmpty(same_direction_queue, bytes_to_process); + } while (retired_instruction_in_same_direction_queue != + copy_start_instruction); + return elapsed_time; +}; + } // namespace memory_space_assignment } // namespace xla diff --git a/xla/service/memory_space_assignment/simulator.h b/xla/service/memory_space_assignment/simulator.h index cd969615fd7bfa..146d569fbe66f0 100644 --- a/xla/service/memory_space_assignment/simulator.h +++ b/xla/service/memory_space_assignment/simulator.h @@ -16,6 +16,10 @@ limitations under the License. #ifndef XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_SIMULATOR_H_ #define XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_SIMULATOR_H_ +#include +#include + +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/service/memory_space_assignment/allocation.h" #include "xla/service/memory_space_assignment/cost_analysis.h" @@ -23,12 +27,49 @@ limitations under the License. namespace xla { namespace memory_space_assignment { +enum class MemoryTransferDirection { + kUnsupported, + kDefaultToAlternate, + kAlternateToDefault, +}; + +// REQUIRES: +// * async_copy must be an async copy-start instruction. +MemoryTransferDirection GetAsyncCopyDirection(const HloInstruction* async_copy, + int64_t alternate_memory_space); + +// This struct is used to track the outstanding async copy instructions and +// the remaining bytes required to be accessed. +struct OutstandingAsyncCopy { + const HloInstruction* copy_start_inst; + float remaining_bytes_to_transfer; + bool operator==(const OutstandingAsyncCopy& other) const { + return copy_start_inst == other.copy_start_inst && + remaining_bytes_to_transfer == other.remaining_bytes_to_transfer; + } +}; + // A wrapper class around runtime simulator. class RuntimeSimulator { public: - explicit RuntimeSimulator(CostAnalysis* cost_analysis) - : cost_analysis_(cost_analysis) {} - virtual ~RuntimeSimulator() = default; + explicit RuntimeSimulator(CostAnalysis* cost_analysis, + int64_t alternate_memory_space) + : cost_analysis_(cost_analysis), + alternate_memory_space_(alternate_memory_space) {} + + // This constructor is used to inject the outstanding async copy queues for + // testing purpose. + explicit RuntimeSimulator( + CostAnalysis* cost_analysis, int64_t alternate_memory_space, + const std::list& outstanding_read_default_queue, + const std::list& outstanding_write_default_queue) + : cost_analysis_(cost_analysis), + alternate_memory_space_(alternate_memory_space), + outstanding_read_default_queue_(outstanding_read_default_queue), + outstanding_write_default_queue_(outstanding_write_default_queue) {} + + ~RuntimeSimulator() = default; + // This function is used to predict the effectiveness of the memory space // assignment solution. Specifically, it returns the estimated execution time // (in seconds) of the HLO module for the given memory space assignment (i.e., @@ -36,9 +77,39 @@ class RuntimeSimulator { float ComputeEstimatedElapsedTime(const HloLiveRange& hlo_live_range, const AllocationSequence& allocations); + // This is an auxiliary function for simulating the execution + // time for executing a copy-done instruction. It returns the + // elapsed time (in seconds) for executing the copy-done instruction. + // + // This function also updates the passed in queues as we complete async copies + // during the simulation. + // + // We simulate the shared bandwidth for default-alternate memory access. + // For example, if the copy-done instruction is a default-write memory + // process, and there are outstanding default-read memory processes in the + // outstanding_read_default_queue, then we use half of the bandwidth to + // process both requests in parallel. Otherwise, we use the full bandwidth to + // process the default-write request. + float SimulateAsyncCopyDone(const HloInstruction* copy_done_instruction); + + const std::list& GetOutstandingReadDefaultQueue() const; + + const std::list& GetOutstandingWriteDefaultQueue() + const; + private: const CostAnalysis* cost_analysis_; CostAnalysis::Cache cost_analysis_cache_; + // Members used for memory model simulation + int64_t alternate_memory_space_; + std::list outstanding_read_default_queue_; + std::list outstanding_write_default_queue_; + // This function updates the queue by updating the front request with the + // processed bytes. If the request is completed (no remaining bytes to + // process), the function returns the instruction and pop it from the queue. + // Otherwise, it returns nullptr. + const HloInstruction* RemoveBytesFromQueueIfNotEmpty( + std::list& async_copy_queue, float processed_bytes); }; } // namespace memory_space_assignment } // namespace xla diff --git a/xla/service/memory_space_assignment/simulator_test.cc b/xla/service/memory_space_assignment/simulator_test.cc index 7207ebbbdf670a..57f152af12a105 100644 --- a/xla/service/memory_space_assignment/simulator_test.cc +++ b/xla/service/memory_space_assignment/simulator_test.cc @@ -16,13 +16,18 @@ limitations under the License. #include "xla/service/memory_space_assignment/simulator.h" #include +#include +#include #include +#include +#include #include #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_cost_analysis.h" @@ -42,7 +47,11 @@ using memory_space_assignment::CostAnalysis; using memory_space_assignment::CostAnalysisOptions; using memory_space_assignment::RuntimeSimulator; +using ::testing::ElementsAreArray; +using ::testing::IsEmpty; + constexpr int64_t kPointerSize = 8; +constexpr int64_t kAlternateMemorySpace = 1; int64_t ShapeSize(const Shape& shape) { return ShapeUtil::ByteSizeOf(shape, kPointerSize); @@ -50,24 +59,26 @@ int64_t ShapeSize(const Shape& shape) { class MemorySpaceAssignmentSimulatorTest : public HloTestBase { protected: - absl::Status Initialize(const HloModule* module) { + absl::Status Initialize(absl::string_view hlo_string) { + TF_ASSIGN_OR_RETURN(module_, ParseAndReturnVerifiedModule(hlo_string)); HloCostAnalysis::Options tpu_device_options; tpu_device_options.shape_size = ShapeSize; // Assume 1 FLOP per second for testing. tpu_device_options.set_flops_per_second(1); + // Assume 1 byte per second for testing. + tpu_device_options.set_bytes_per_second(1); hlo_cost_analysis_ = std::make_unique(tpu_device_options); TF_RETURN_IF_ERROR( - module->entry_computation()->Accept(hlo_cost_analysis_.get())); + module_->entry_computation()->Accept(hlo_cost_analysis_.get())); hlo_cost_analysis_costs_ = std::make_unique( *hlo_cost_analysis_); CostAnalysisOptions _options; TF_ASSIGN_OR_RETURN( cost_analysis_, - CostAnalysis::Create(*hlo_cost_analysis_costs_, _options, *module)); - runtime_simulator_ = - std::make_unique( - cost_analysis_.get()); + CostAnalysis::Create(*hlo_cost_analysis_costs_, _options, *module_)); + runtime_simulator_ = std::make_unique( + cost_analysis_.get(), kAlternateMemorySpace); return absl::OkStatus(); } std::unique_ptr hlo_cost_analysis_; @@ -75,6 +86,7 @@ class MemorySpaceAssignmentSimulatorTest : public HloTestBase { hlo_cost_analysis_costs_; std::unique_ptr cost_analysis_; std::unique_ptr runtime_simulator_; + std::unique_ptr module_; }; TEST_F(MemorySpaceAssignmentSimulatorTest, SingleLayerNestedLoop) { @@ -103,15 +115,13 @@ TEST_F(MemorySpaceAssignmentSimulatorTest, SingleLayerNestedLoop) { } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - TF_ASSERT_OK(Initialize(module.get())); - + TF_ASSERT_OK(Initialize(hlo_string)); TF_ASSERT_OK_AND_ASSIGN(auto alias_analysis, - HloAliasAnalysis::Run(module.get())); - TF_ASSERT_OK_AND_ASSIGN(auto hlo_live_range, - HloLiveRange::Run(module->schedule(), *alias_analysis, - module->entry_computation())); + HloAliasAnalysis::Run(module_.get())); + TF_ASSERT_OK_AND_ASSIGN( + auto hlo_live_range, + HloLiveRange::Run(module_->schedule(), *alias_analysis, + module_->entry_computation())); // Since the HLO does not contain memory access, pass an empty allocation // sequence for test. @@ -124,5 +134,164 @@ TEST_F(MemorySpaceAssignmentSimulatorTest, SingleLayerNestedLoop) { expected_elapsed_time); } +class SimulateAsyncCopyDoneTest : public MemorySpaceAssignmentSimulatorTest { + protected: + absl::Status Initialize(absl::string_view hlo_string) { + TF_RETURN_IF_ERROR( + MemorySpaceAssignmentSimulatorTest::Initialize(hlo_string)); + for (const HloInstruction* inst : + module_->entry_computation()->instructions()) { + instruction_map_[inst->name()] = inst; + if (inst->name() == "copy-start.1") { + outstanding_read_default_queue_.push_back( + memory_space_assignment::OutstandingAsyncCopy{inst, 512}); + } else if (inst->name() == "copy-start.2") { + outstanding_write_default_queue_.push_back( + memory_space_assignment::OutstandingAsyncCopy{inst, 128}); + } + } + runtime_simulator_ = std::make_unique( + cost_analysis_.get(), kAlternateMemorySpace, + outstanding_read_default_queue_, outstanding_write_default_queue_); + return absl::OkStatus(); + } + std::map instruction_map_; + std::list + outstanding_read_default_queue_; + std::list + outstanding_write_default_queue_; +}; + +TEST_F(SimulateAsyncCopyDoneTest, AsyncCopyAlreadyCompleted) { + absl::string_view hlo_string = + R"(HloModule module, is_scheduled=true + ENTRY Entry { + param_0 = f32[128] parameter(0) + copy-start.1 = (f32[128]{0:S(1)}, f32[128], u32[]) copy-start(param_0) + ROOT copy-done.1 = f32[128]{0:S(1)} copy-done(copy-start.1) + } + )"; + + TF_ASSERT_OK(Initialize(hlo_string)); + + const HloInstruction* copy_done_inst = instruction_map_["copy-done.1"]; + // Process the copy-start.1 + runtime_simulator_->SimulateAsyncCopyDone(copy_done_inst); + + // There should be no request in the read/write queues. + EXPECT_THAT(runtime_simulator_->GetOutstandingReadDefaultQueue(), IsEmpty()); + EXPECT_THAT(runtime_simulator_->GetOutstandingWriteDefaultQueue(), IsEmpty()); + // The function should return 0 for requests that are already completed. + float elapsed_time_for_completed_copy = + runtime_simulator_->SimulateAsyncCopyDone(copy_done_inst); + EXPECT_EQ(elapsed_time_for_completed_copy, 0); + // There should be no request in the read/write queues. + EXPECT_THAT(runtime_simulator_->GetOutstandingReadDefaultQueue(), IsEmpty()); + EXPECT_THAT(runtime_simulator_->GetOutstandingWriteDefaultQueue(), IsEmpty()); +} + +TEST_F(SimulateAsyncCopyDoneTest, AsyncCopyFullBandwidth) { + absl::string_view hlo_string = + R"(HloModule module, is_scheduled=true + ENTRY Entry { + param_0 = f32[128] parameter(0) + copy-start.1 = (f32[128]{0:S(1)}, f32[128], u32[]) copy-start(param_0) + ROOT copy-done.1 = f32[128]{0:S(1)} copy-done(copy-start.1) + } + )"; + + TF_ASSERT_OK(Initialize(hlo_string)); + const HloInstruction* copy_done_inst = instruction_map_["copy-done.1"]; + + // The elapsed time for copy-done.1 is 128 * 4 / 1 = 512. + float copy_done_elapsed_time = + runtime_simulator_->SimulateAsyncCopyDone(copy_done_inst); + EXPECT_EQ(copy_done_elapsed_time, 512); + + // There should be no request in the read/write queues. + EXPECT_THAT(runtime_simulator_->GetOutstandingReadDefaultQueue(), IsEmpty()); + EXPECT_THAT(runtime_simulator_->GetOutstandingWriteDefaultQueue(), IsEmpty()); +} + +TEST_F(SimulateAsyncCopyDoneTest, AsyncCopySharedBandwidth) { + absl::string_view hlo_string = + R"(HloModule module, is_scheduled=true + ENTRY Entry { + param_0 = f32[128] parameter(0) + param_1 = f32[32]{0:S(1)} parameter(1) + copy-start.1 = (f32[128]{0:S(1)}, f32[128], u32[]) copy-start(param_0) + copy-start.2 = (f32[32], f32[32]{0:S(1)}, u32[]) copy-start(param_1) + copy-done.2 = f32[32] copy-done(copy-start.2) + ROOT copy-done.1 = f32[128]{0:S(1)} copy-done(copy-start.1) + } + )"; + + TF_ASSERT_OK(Initialize(hlo_string)); + + const HloInstruction* copy_start_1_inst = instruction_map_["copy-start.1"]; + const HloInstruction* copy_done_2_inst = instruction_map_["copy-done.2"]; + + // The copy-start.2 needs to share bandwidth with copy-start.1. Thus, it can + // only use half bandwidth to access default memory. Thus, the elapsed time is + // 32 * 4 / 0.5 = 256 + float copy_done_2_elapsed_time = + runtime_simulator_->SimulateAsyncCopyDone(copy_done_2_inst); + EXPECT_EQ(copy_done_2_elapsed_time, 256); + + // The only write request (copy-start.2) should be completed. + EXPECT_THAT(runtime_simulator_->GetOutstandingWriteDefaultQueue(), IsEmpty()); + + // The read request has (128-32)*4 bytes left to process. + EXPECT_THAT(runtime_simulator_->GetOutstandingReadDefaultQueue(), + ElementsAreArray({memory_space_assignment::OutstandingAsyncCopy{ + copy_start_1_inst, 384}})); +} + +TEST_F(SimulateAsyncCopyDoneTest, AsyncCopyTransferPartialProcess) { + absl::string_view hlo_string = + R"(HloModule module, is_scheduled=true + ENTRY Entry { + param_0 = f32[128] parameter(0) + param_1 = f32[32]{0:S(1)} parameter(1) + copy-start.1 = (f32[128]{0:S(1)}, f32[128], u32[]) copy-start(param_0) + copy-start.2 = (f32[32], f32[32]{0:S(1)}, u32[]) copy-start(param_1) + copy-done.2 = f32[32] copy-done(copy-start.2) + ROOT copy-done.1 = f32[128]{0:S(1)} copy-done(copy-start.1) + } + )"; + + TF_ASSERT_OK(Initialize(hlo_string)); + + const HloInstruction* copy_start_1_inst = instruction_map_["copy-start.1"]; + const HloInstruction* copy_done_1_inst = instruction_map_["copy-done.1"]; + const HloInstruction* copy_done_2_inst = instruction_map_["copy-done.2"]; + + // Execute copy-done.2. + float copy_done_2_elapsed_time = + runtime_simulator_->SimulateAsyncCopyDone(copy_done_2_inst); + // For copy-done.2, it requires to transfer 32*4 bytes + // default-write request. At the same time, there is a 128*4 bytes + // default-read request in the queue for copy-start.1. So the + // elapsed time for copy-done.2 is 32*4 / (0.5*1) = 256. + EXPECT_EQ(copy_done_2_elapsed_time, 256); + // In parallel with copy-done.2, copy-start.1 is also being processed. + // So the remaining bytes should be 128*4 - 32*4 = 384. + EXPECT_THAT(runtime_simulator_->GetOutstandingReadDefaultQueue(), + ElementsAreArray({memory_space_assignment::OutstandingAsyncCopy{ + copy_start_1_inst, 384}})); + EXPECT_THAT(runtime_simulator_->GetOutstandingWriteDefaultQueue(), IsEmpty()); + + // Execute copy-done.1. + float copy_done_1_elapsed_time = + runtime_simulator_->SimulateAsyncCopyDone(copy_done_1_inst); + // The copy-done.1 is the only request in the read-queue, and there is no + // request in the write-queue. Thus, it can use the full bandwidth. The + // elapsed time is 384 / 1 = 384. + EXPECT_EQ(copy_done_1_elapsed_time, 384); + // No request should be in the queue. + EXPECT_THAT(runtime_simulator_->GetOutstandingReadDefaultQueue(), IsEmpty()); + EXPECT_THAT(runtime_simulator_->GetOutstandingWriteDefaultQueue(), IsEmpty()); +} + } // namespace } // namespace xla From 2f136860ae191334ac811a34b82cf6168338fe4a Mon Sep 17 00:00:00 2001 From: Thomas Joerg Date: Wed, 24 Jul 2024 00:51:40 -0700 Subject: [PATCH 103/376] Add support for Two-GPU tests in `xla_test`s. PiperOrigin-RevId: 655463442 --- xla/tests/build_defs.bzl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/xla/tests/build_defs.bzl b/xla/tests/build_defs.bzl index b793340666506c..f4219e7cbf40f8 100644 --- a/xla/tests/build_defs.bzl +++ b/xla/tests/build_defs.bzl @@ -66,7 +66,14 @@ def prepare_nvidia_gpu_backend_data(backends, disabled_backends, backend_tags, b all_tags = new_backend_tags[gpu_backend] requires_gpu = [t for t in all_tags if t.startswith("requires-gpu-")] requires_sm, only = None, False + num_gpus = None for tag in requires_gpu: + if ":" in tag: # Multi-GPU tests are suffixed with colon and number of GPUs. + tag, suffix = tag.split(":") # Remove the suffix from the tag for further parsing. + parsed_num_gpus = int(suffix) + if num_gpus and num_gpus != parsed_num_gpus: + fail("Inconsistent number of GPUs: %d vs %d" % (num_gpus, parsed_num_gpus)) + num_gpus = parsed_num_gpus if tag.startswith("requires-gpu-sm"): version = tag.split("-")[2][2:] sm = (int(version[:-1]), int(version[-1])) @@ -84,6 +91,8 @@ def prepare_nvidia_gpu_backend_data(backends, disabled_backends, backend_tags, b else: sm_major, sm_minor = sm_requirements[gpu_backend] sm_tag = "requires-gpu-nvidia" if sm_major == 0 else "requires-gpu-sm%s%s-only" % (sm_major, sm_minor) + if num_gpus: + sm_tag += ":%d" % num_gpus new_backend_tags[gpu_backend] = [t for t in all_tags if t not in requires_gpu] new_backend_tags[gpu_backend].append(sm_tag) From 7a89946c49bd3e2f64912cd4cc53ef19da1e206d Mon Sep 17 00:00:00 2001 From: Greg Olechwierowicz Date: Wed, 24 Jul 2024 02:03:32 -0700 Subject: [PATCH 104/376] [XLA:GPU] Skip GemmFusionIsNoOpWhenGemmFusionAutotunerFallsBackToCublas test. PiperOrigin-RevId: 655481099 --- xla/service/gpu/gpu_compiler_test.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/xla/service/gpu/gpu_compiler_test.cc b/xla/service/gpu/gpu_compiler_test.cc index ff766823c8fe2c..d95b411c7ac700 100644 --- a/xla/service/gpu/gpu_compiler_test.cc +++ b/xla/service/gpu/gpu_compiler_test.cc @@ -391,6 +391,7 @@ ENTRY main { TEST_F(GpuCompilerTest, GemmFusionIsNoOpWhenGemmFusionAutotunerFallsBackToCublas) { + GTEST_SKIP() << "TODO(b/354864068): Test fails in OSS stack on A100-80."; auto cc = backend() .default_stream_executor() ->GetDeviceDescription() From 504d909ca8e15116b8076e23b92228cb1774f7ef Mon Sep 17 00:00:00 2001 From: Leo Heinsaar Date: Wed, 24 Jul 2024 02:05:50 -0700 Subject: [PATCH 105/376] [XLA:CPU] Add runtime check for whether `batch-norm-grad` is rewritten In the current runtime, emitting for op `batch-norm-grad` is not supported by design. The op is expected to be rewritten by another HLO pass before ever reaching the emit phase. This CL adds a runtime check for whether this op was actually rewritten and returns an explicit message if it wasn't. Also includes a new unit test covering existing and new functionality: batch_norm_grad_test.cc. PiperOrigin-RevId: 655481789 --- xla/service/cpu/ir_emitter.cc | 4 ++ xla/service/cpu/ir_emitter.h | 1 + xla/service/cpu/thunk_emitter.cc | 8 +++ xla/service/cpu/thunk_emitter.h | 3 ++ xla/tests/BUILD | 16 ++++++ xla/tests/batch_norm_grad_test.cc | 83 +++++++++++++++++++++++++++++++ 6 files changed, 115 insertions(+) create mode 100644 xla/tests/batch_norm_grad_test.cc diff --git a/xla/service/cpu/ir_emitter.cc b/xla/service/cpu/ir_emitter.cc index 183479ba0fa991..6a87d52b747346 100644 --- a/xla/service/cpu/ir_emitter.cc +++ b/xla/service/cpu/ir_emitter.cc @@ -3676,6 +3676,10 @@ absl::Status IrEmitter::HandleAfterAll(HloInstruction* after_all) { return absl::OkStatus(); } +absl::Status IrEmitter::HandleBatchNormGrad(HloInstruction* batch_norm_grad) { + return Unimplemented("BatchNormGrad should be rewritten for CPU."); +} + absl::Status IrEmitter::HandleGetDimensionSize(HloInstruction* get_size) { return Unimplemented("GetDimensionSize should be rewritten for CPU."); } diff --git a/xla/service/cpu/ir_emitter.h b/xla/service/cpu/ir_emitter.h index 37da729c88be4e..38aa81916db2a5 100644 --- a/xla/service/cpu/ir_emitter.h +++ b/xla/service/cpu/ir_emitter.h @@ -253,6 +253,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, absl::Status HandleRng(HloInstruction* rng) override; absl::Status HandleRngBitGenerator(HloInstruction* rng) override; absl::Status HandleRngGetAndUpdateState(HloInstruction* rng_state) override; + absl::Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; absl::Status FinishVisit(HloInstruction* root) override; absl::Status Preprocess(HloInstruction* hlo) override; diff --git a/xla/service/cpu/thunk_emitter.cc b/xla/service/cpu/thunk_emitter.cc index 4ae71bff91f88a..4d01a776bb06d7 100644 --- a/xla/service/cpu/thunk_emitter.cc +++ b/xla/service/cpu/thunk_emitter.cc @@ -181,6 +181,9 @@ absl::StatusOr ThunkEmitter::EmitHloInstruction( case HloOpcode::kSetDimensionSize: return EmitSetDimensionSizeThunk(instruction); + case HloOpcode::kBatchNormGrad: + return EmitBatchNormGradThunk(instruction); + // Simple HLO instructions lowered to elemental host kernels (plain loops // behind the HostKernel API). case HloOpcode::kAbs: @@ -523,6 +526,11 @@ absl::StatusOr ThunkEmitter::EmitSetDimensionSizeThunk( return Unimplemented("SetDimensionSize should be rewritten for CPU."); } +absl::StatusOr ThunkEmitter::EmitBatchNormGradThunk( + const HloInstruction* instruction) { + return Unimplemented("BatchNormGrad should be rewritten for CPU."); +} + absl::StatusOr ThunkEmitter::EmitConvolutionThunk( const HloInstruction* instruction) { // NOTE: The following code (along with TODOs and comments) partially diff --git a/xla/service/cpu/thunk_emitter.h b/xla/service/cpu/thunk_emitter.h index a829b4c159328e..4de2f7f7a78e42 100644 --- a/xla/service/cpu/thunk_emitter.h +++ b/xla/service/cpu/thunk_emitter.h @@ -87,6 +87,9 @@ class ThunkEmitter { absl::StatusOr EmitSetDimensionSizeThunk( const HloInstruction* instruction); + absl::StatusOr EmitBatchNormGradThunk( + const HloInstruction* instruction); + absl::StatusOr EmitConvolutionThunk( const HloInstruction* instruction); diff --git a/xla/tests/BUILD b/xla/tests/BUILD index 6463beff342d19..3d5d9692e2014f 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -3215,6 +3215,22 @@ xla_test( ], ) +xla_test( + name = "batch_norm_grad_test", + srcs = ["batch_norm_grad_test.cc"], + tags = ["test_xla_cpu_thunks"], + deps = [ + ":hlo_test_base", + ":xla_internal_test_main", # fixdeps: keep + "//xla:literal", + "//xla:literal_util", + "//xla:test", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/status", + "@tsl//tsl/platform:statusor", + ], +) + bzl_library( name = "plugin_bzl", srcs = ["plugin.bzl"], diff --git a/xla/tests/batch_norm_grad_test.cc b/xla/tests/batch_norm_grad_test.cc new file mode 100644 index 00000000000000..0bff1da41b90fd --- /dev/null +++ b/xla/tests/batch_norm_grad_test.cc @@ -0,0 +1,83 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "absl/status/status.h" +#include "xla/literal_util.h" +#include "xla/test.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tests/test_macros.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace { + +const char* const kModuleStr = R"( + HloModule BatchNormGrad + ENTRY BatchNormGrad.v6 { + input = f32[2,2] parameter(0) + scale = f32[2] parameter(1) + mean = f32[2] parameter(2) + variance = f32[2] parameter(3) + grad_output = f32[2,2] parameter(4) + ROOT batch-norm-grad = (f32[2,2]{1,0}, f32[2]{0}, f32[2]{0}) + batch-norm-grad(input, scale, mean, variance, grad_output), epsilon=0, feature_index=1 + } + )"; + +class BatchNormGradTest : public HloTestBase {}; + +TEST_F(BatchNormGradTest, CorrectComputation) { + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + // Define input parameters + auto input = LiteralUtil::CreateR2({{1.2, 2.1}, {1.3, 2.4}}); + auto scale = LiteralUtil::CreateR1({1.0, 1.0}); + auto mean = LiteralUtil::CreateR1({0.0, 0.0}); + auto variance = LiteralUtil::CreateR1({1.0, 1.0}); + auto grad_output = LiteralUtil::CreateR2({{1.0, 1.0}, {1.0, 1.0}}); + + TF_ASSERT_OK_AND_ASSIGN( + auto result, Execute(std::move(module), + {&input, &scale, &mean, &variance, &grad_output})); + + auto expected_input_grad = + LiteralUtil::CreateR2({{-1.5, -4.725}, {-1.625, -5.4}}); + auto expected_scale_grad = LiteralUtil::CreateR1({2.5, 4.5}); + auto expected_mean_grad = LiteralUtil::CreateR1({2.0, 2.0}); + + EXPECT_EQ(result, + LiteralUtil::MakeTuple({&expected_input_grad, &expected_scale_grad, + &expected_mean_grad})); +} + +TEST_F(BatchNormGradTest, + DISABLED_ON_INTERPRETER(DISABLED_ON_GPU( + DISABLED_ON_TPU(ReturnsErrorWhenHloPassesDisabled)))) { + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + auto status_or_result = + Execute(std::move(module), {}, /*run_hlo_passes=*/false); + EXPECT_EQ(status_or_result.status().code(), absl::StatusCode::kUnimplemented); + EXPECT_THAT( + status_or_result.status().message(), + ::testing::HasSubstr("BatchNormGrad should be rewritten for CPU")); +} + +} // namespace +} // namespace xla From 3117de327b696a6c6431b38e4c733908051bfe43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Bana=C5=9B?= Date: Wed, 24 Jul 2024 03:12:37 -0700 Subject: [PATCH 106/376] [XLA:CPU] Add runtime check if `stochastic-convert` was decomposed. In the current runtime, `stochastic-convert` op emitting is not supported by design, normally it is rewritten by `StochasticConvertDecomposer` HLO pass and never reaches the emit phase. This CL adds a runtime check if this op was actually rewritten and returns an explicit message if it wasn't (instead of cryptic message that `kStochasticConvert` opcode was not handled in elemental IR emitter). Also added fundamental tests for stochastic convert op for both runtimes - current and thunks. PiperOrigin-RevId: 655497237 --- xla/service/cpu/ir_emitter.cc | 6 ++- xla/service/cpu/ir_emitter.h | 3 +- xla/service/cpu/thunk_emitter.cc | 8 ++++ xla/service/cpu/thunk_emitter.h | 3 ++ xla/tests/BUILD | 19 ++++++++ xla/tests/stochastic_convert_test.cc | 68 ++++++++++++++++++++++++++++ 6 files changed, 105 insertions(+), 2 deletions(-) create mode 100644 xla/tests/stochastic_convert_test.cc diff --git a/xla/service/cpu/ir_emitter.cc b/xla/service/cpu/ir_emitter.cc index 6a87d52b747346..30a2e5d1ebf26f 100644 --- a/xla/service/cpu/ir_emitter.cc +++ b/xla/service/cpu/ir_emitter.cc @@ -3684,7 +3684,7 @@ absl::Status IrEmitter::HandleGetDimensionSize(HloInstruction* get_size) { return Unimplemented("GetDimensionSize should be rewritten for CPU."); } -absl::Status IrEmitter::HandleSetDimensionSize(HloInstruction* get_size) { +absl::Status IrEmitter::HandleSetDimensionSize(HloInstruction* set_size) { return Unimplemented("SetDimensionSize should be rewritten for CPU."); } @@ -3721,6 +3721,10 @@ absl::Status IrEmitter::HandleRngGetAndUpdateState(HloInstruction* rng_state) { return absl::OkStatus(); } +absl::Status IrEmitter::HandleStochasticConvert(HloInstruction* instruction) { + return Unimplemented("StochasticConvert should be decomposed for CPU."); +} + absl::Status IrEmitter::FinishVisit(HloInstruction* root) { // When this method is called, we should have already emitted an IR value for // the root (return) op. The IR value holds the address of the buffer holding diff --git a/xla/service/cpu/ir_emitter.h b/xla/service/cpu/ir_emitter.h index 38aa81916db2a5..03fbfc72d33ec9 100644 --- a/xla/service/cpu/ir_emitter.h +++ b/xla/service/cpu/ir_emitter.h @@ -246,7 +246,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, absl::Status HandleScatter(HloInstruction* scatter) override; absl::Status HandleAfterAll(HloInstruction* after_all) override; absl::Status HandleGetDimensionSize(HloInstruction* get_size) override; - absl::Status HandleSetDimensionSize(HloInstruction* get_size) override; + absl::Status HandleSetDimensionSize(HloInstruction* set_size) override; absl::Status HandleAddDependency(HloInstruction* add_dependency) override; absl::Status HandlePartitionId(HloInstruction* hlo) override; absl::Status HandleReplicaId(HloInstruction* hlo) override; @@ -254,6 +254,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, absl::Status HandleRngBitGenerator(HloInstruction* rng) override; absl::Status HandleRngGetAndUpdateState(HloInstruction* rng_state) override; absl::Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; + absl::Status HandleStochasticConvert(HloInstruction* instruction) override; absl::Status FinishVisit(HloInstruction* root) override; absl::Status Preprocess(HloInstruction* hlo) override; diff --git a/xla/service/cpu/thunk_emitter.cc b/xla/service/cpu/thunk_emitter.cc index 4d01a776bb06d7..34074a77252c37 100644 --- a/xla/service/cpu/thunk_emitter.cc +++ b/xla/service/cpu/thunk_emitter.cc @@ -292,6 +292,9 @@ absl::StatusOr ThunkEmitter::EmitHloInstruction( case HloOpcode::kRngGetAndUpdateState: return EmitRngGetAndUpdateStateThunk(instruction); + case HloOpcode::kStochasticConvert: + return EmitStochasticConvertThunk(instruction); + case HloOpcode::kInfeed: return EmitInfeedThunk(instruction); @@ -644,6 +647,11 @@ absl::StatusOr ThunkEmitter::EmitRngGetAndUpdateStateThunk( ThunkInfo(instruction), state_buffer, rng_state->delta()); } +absl::StatusOr ThunkEmitter::EmitStochasticConvertThunk( + const HloInstruction* instruction) { + return Unimplemented("StochasticConvert should be decomposed for CPU."); +} + absl::StatusOr ThunkEmitter::EmitInfeedThunk( const HloInstruction* instruction) { auto* infeed = Cast(instruction); diff --git a/xla/service/cpu/thunk_emitter.h b/xla/service/cpu/thunk_emitter.h index 4de2f7f7a78e42..dcc6eaf484ec1b 100644 --- a/xla/service/cpu/thunk_emitter.h +++ b/xla/service/cpu/thunk_emitter.h @@ -115,6 +115,9 @@ class ThunkEmitter { absl::StatusOr EmitRngGetAndUpdateStateThunk( const HloInstruction* instruction); + absl::StatusOr EmitStochasticConvertThunk( + const HloInstruction* instruction); + absl::StatusOr EmitInfeedThunk( const HloInstruction* instruction); diff --git a/xla/tests/BUILD b/xla/tests/BUILD index 3d5d9692e2014f..d43d12989bb175 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -2110,6 +2110,25 @@ xla_test( ], ) +xla_test( + name = "stochastic_convert_test", + srcs = ["stochastic_convert_test.cc"], + backends = ["cpu"], + tags = ["test_xla_cpu_thunks"], + deps = [ + ":hlo_test_base", + ":test_utils", + "//xla:error_spec", + "//xla:literal", + "//xla:shape_util", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", + ], +) + xla_test( name = "vector_ops_simple_test", srcs = ["vector_ops_simple_test.cc"], diff --git a/xla/tests/stochastic_convert_test.cc b/xla/tests/stochastic_convert_test.cc new file mode 100644 index 00000000000000..9aa1f023850347 --- /dev/null +++ b/xla/tests/stochastic_convert_test.cc @@ -0,0 +1,68 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "xla/error_spec.h" +#include "xla/literal.h" +#include "xla/shape_util.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tests/test_macros.h" +#include "xla/tests/test_utils.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace { + +using StochasticConvertTest = HloTestBase; +const char* const kModuleStr = R"( + HloModule stochastic-convert + + ENTRY entry { + %arg_param.1 = f32[65536]{0} parameter(0) + %random_param.2 = u32[65536]{0} parameter(1) + ROOT %stochastic-convert.3 = s32[65536]{0} stochastic-convert( + f32[65536]{0} %arg_param.1, u32[65536]{0} %random_param.2) + } +)"; + +XLA_TEST_F(StochasticConvertTest, CorrectComputation) { + EXPECT_TRUE(RunAndCompare(kModuleStr, ErrorSpec{0.001})); +} + +TEST_F(StochasticConvertTest, ReturnsErrorWhenHloPassesDisabled) { + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + auto arg0_shape = ShapeUtil::MakeShape(F32, {65536}); + auto arg0 = MakeFakeLiteral(arg0_shape).value(); + + auto arg1_shape = ShapeUtil::MakeShape(U32, {65536}); + auto arg1 = MakeFakeLiteral(arg1_shape).value(); + + auto status_or_result = + Execute(std::move(module), {&arg0, &arg1}, /*run_hlo_passes=*/false); + EXPECT_EQ(status_or_result.status().code(), absl::StatusCode::kUnimplemented); + EXPECT_THAT( + status_or_result.status().message(), + ::testing::HasSubstr("StochasticConvert should be decomposed for CPU")); +} + +} // namespace +} // namespace xla From a1ef42d58c3472c8abf697f352136eec831ba709 Mon Sep 17 00:00:00 2001 From: Leo Heinsaar Date: Wed, 24 Jul 2024 05:04:19 -0700 Subject: [PATCH 107/376] [XLA:CPU] Add runtime check for whether `batch-norm-training` is rewritten In the current runtime, emitting for op `batch-norm-training` is not supported by design. The op is expected to be rewritten by another HLO pass before ever reaching the emit phase. This CL adds a runtime check for whether this op was actually rewritten and returns an explicit message if it wasn't. Also includes a new unit test covering existing and new functionality: batch_norm_training_test.cc. PiperOrigin-RevId: 655521839 --- xla/service/cpu/ir_emitter.cc | 5 ++ xla/service/cpu/ir_emitter.h | 2 + xla/service/cpu/thunk_emitter.cc | 7 ++ xla/service/cpu/thunk_emitter.h | 3 + xla/tests/BUILD | 16 +++++ xla/tests/batch_norm_training_test.cc | 96 +++++++++++++++++++++++++++ 6 files changed, 129 insertions(+) create mode 100644 xla/tests/batch_norm_training_test.cc diff --git a/xla/service/cpu/ir_emitter.cc b/xla/service/cpu/ir_emitter.cc index 30a2e5d1ebf26f..9929215b3c3a3f 100644 --- a/xla/service/cpu/ir_emitter.cc +++ b/xla/service/cpu/ir_emitter.cc @@ -3680,6 +3680,11 @@ absl::Status IrEmitter::HandleBatchNormGrad(HloInstruction* batch_norm_grad) { return Unimplemented("BatchNormGrad should be rewritten for CPU."); } +absl::Status IrEmitter::HandleBatchNormTraining( + HloInstruction* batch_norm_training) { + return Unimplemented("BatchNormTraining should be rewritten for CPU."); +} + absl::Status IrEmitter::HandleGetDimensionSize(HloInstruction* get_size) { return Unimplemented("GetDimensionSize should be rewritten for CPU."); } diff --git a/xla/service/cpu/ir_emitter.h b/xla/service/cpu/ir_emitter.h index 03fbfc72d33ec9..b4003b1ea227da 100644 --- a/xla/service/cpu/ir_emitter.h +++ b/xla/service/cpu/ir_emitter.h @@ -254,6 +254,8 @@ class IrEmitter : public DfsHloVisitorWithDefault, absl::Status HandleRngBitGenerator(HloInstruction* rng) override; absl::Status HandleRngGetAndUpdateState(HloInstruction* rng_state) override; absl::Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; + absl::Status HandleBatchNormTraining( + HloInstruction* batch_norm_training) override; absl::Status HandleStochasticConvert(HloInstruction* instruction) override; absl::Status FinishVisit(HloInstruction* root) override; diff --git a/xla/service/cpu/thunk_emitter.cc b/xla/service/cpu/thunk_emitter.cc index 34074a77252c37..b8194d82fd6f96 100644 --- a/xla/service/cpu/thunk_emitter.cc +++ b/xla/service/cpu/thunk_emitter.cc @@ -183,6 +183,8 @@ absl::StatusOr ThunkEmitter::EmitHloInstruction( case HloOpcode::kBatchNormGrad: return EmitBatchNormGradThunk(instruction); + case HloOpcode::kBatchNormTraining: + return EmitBatchNormTrainingThunk(instruction); // Simple HLO instructions lowered to elemental host kernels (plain loops // behind the HostKernel API). @@ -534,6 +536,11 @@ absl::StatusOr ThunkEmitter::EmitBatchNormGradThunk( return Unimplemented("BatchNormGrad should be rewritten for CPU."); } +absl::StatusOr ThunkEmitter::EmitBatchNormTrainingThunk( + const HloInstruction* instruction) { + return Unimplemented("BatchNormTraining should be rewritten for CPU."); +} + absl::StatusOr ThunkEmitter::EmitConvolutionThunk( const HloInstruction* instruction) { // NOTE: The following code (along with TODOs and comments) partially diff --git a/xla/service/cpu/thunk_emitter.h b/xla/service/cpu/thunk_emitter.h index dcc6eaf484ec1b..605b87578b6fad 100644 --- a/xla/service/cpu/thunk_emitter.h +++ b/xla/service/cpu/thunk_emitter.h @@ -90,6 +90,9 @@ class ThunkEmitter { absl::StatusOr EmitBatchNormGradThunk( const HloInstruction* instruction); + absl::StatusOr EmitBatchNormTrainingThunk( + const HloInstruction* instruction); + absl::StatusOr EmitConvolutionThunk( const HloInstruction* instruction); diff --git a/xla/tests/BUILD b/xla/tests/BUILD index d43d12989bb175..3345e5a92bc2e4 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -3250,6 +3250,22 @@ xla_test( ], ) +xla_test( + name = "batch_norm_training_test", + srcs = ["batch_norm_training_test.cc"], + tags = ["test_xla_cpu_thunks"], + deps = [ + ":hlo_test_base", + ":xla_internal_test_main", # fixdeps: keep + "//xla:literal", + "//xla:literal_util", + "//xla:test", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/status", + "@tsl//tsl/platform:statusor", + ], +) + bzl_library( name = "plugin_bzl", srcs = ["plugin.bzl"], diff --git a/xla/tests/batch_norm_training_test.cc b/xla/tests/batch_norm_training_test.cc new file mode 100644 index 00000000000000..581a47090cfa01 --- /dev/null +++ b/xla/tests/batch_norm_training_test.cc @@ -0,0 +1,96 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "absl/status/status.h" +#include "xla/literal_util.h" +#include "xla/test.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tests/test_macros.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace { + +const char* const kModuleStr = R"( +HloModule module +ENTRY entry { + %input = f32[2,1] parameter(0) + %scale = f32[1] parameter(1) + %offset = f32[1] parameter(2) + ROOT %batch-norm-training = (f32[2,1], f32[1], f32[1]) + batch-norm-training(f32[2,1] %input, f32[1] %scale, f32[1] %offset), + epsilon=0.001, feature_index=1 +} +)"; + +class BatchNormTrainingTest : public HloTestBase {}; + +TEST_F(BatchNormTrainingTest, CorrectComputation) { + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + auto input = LiteralUtil::CreateR2({{1.0}, {2.0}}); + auto scale = LiteralUtil::CreateR1({0.5}); + auto offset = LiteralUtil::CreateR1({0.1}); + + TF_ASSERT_OK_AND_ASSIGN( + auto result, Execute(std::move(module), {&input, &scale, &offset})); + + // Decompose result tuple + auto result_tuple = result.DecomposeTuple(); + + auto expected_output = + LiteralUtil::CreateR2({{-0.399003029}, {0.599003}}); + auto expected_scale = LiteralUtil::CreateR1({1.5}); + auto expected_mean = LiteralUtil::CreateR1({0.25}); + + const float tolerance = 1e-5; // for floating-point comparison + + // Compare each element using EXPECT_NEAR instead of EXPECT_EQ to avoid + // floating-point comparison issues, otherwise the test will be flaky. + for (int i = 0; i < expected_output.element_count(); ++i) { + EXPECT_NEAR(result_tuple[0].data()[i], + expected_output.data()[i], tolerance); + } + + for (int i = 0; i < expected_scale.element_count(); ++i) { + EXPECT_NEAR(result_tuple[1].data()[i], + expected_scale.data()[i], tolerance); + } + + for (int i = 0; i < expected_mean.element_count(); ++i) { + EXPECT_NEAR(result_tuple[2].data()[i], + expected_mean.data()[i], tolerance); + } +} + +TEST_F(BatchNormTrainingTest, + DISABLED_ON_INTERPRETER(DISABLED_ON_GPU( + DISABLED_ON_TPU(ReturnsErrorWhenHloPassesDisabled)))) { + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + auto status_or_result = + Execute(std::move(module), {}, /*run_hlo_passes=*/false); + EXPECT_EQ(status_or_result.status().code(), absl::StatusCode::kUnimplemented); + EXPECT_THAT( + status_or_result.status().message(), + ::testing::HasSubstr("BatchNormTraining should be rewritten for CPU")); +} + +} // namespace +} // namespace xla From 749b8645ffe9f456690a9b9b7713f60320611779 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Wed, 24 Jul 2024 05:23:24 -0700 Subject: [PATCH 108/376] Add an additional guard for ReduceOfBatchDot simplification. This simplification creates non-canonical dots (more than one contracting dimension). So guard it behind supports_non_canonical_dots option. PiperOrigin-RevId: 655526166 --- xla/service/algebraic_simplifier.cc | 3 ++- xla/service/algebraic_simplifier_test.cc | 29 ++++++++++++++++++++++ xla/service/gpu/nvptx_compiler.cc | 1 + xla/service/gpu/tests/gemm_rewrite_test.cc | 29 ++++++++++++++++++++++ 4 files changed, 61 insertions(+), 1 deletion(-) diff --git a/xla/service/algebraic_simplifier.cc b/xla/service/algebraic_simplifier.cc index 473a5c948b2514..a60e3576a67944 100644 --- a/xla/service/algebraic_simplifier.cc +++ b/xla/service/algebraic_simplifier.cc @@ -7742,7 +7742,8 @@ absl::Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { // Convert Reduce(Dot(X,Y)) to Dot(X,Y) if any of the dimensions reduced were // batch dimensions of the dot. The transformation supports reducing other // dimensions as well. - if (options_.enable_dot_strength_reduction() && + if (options_.supports_non_canonical_dots() && + options_.enable_dot_strength_reduction() && Match(arg, m::Dot(&dot, m::Op(&lhs), m::Op(&rhs)).WithOneUser()) && Match(reduce->to_apply()->root_instruction(), m::AddAnyOrder(m::Parameter(0), m::Parameter(1))) && diff --git a/xla/service/algebraic_simplifier_test.cc b/xla/service/algebraic_simplifier_test.cc index 2880c77ca0a775..03bde1910e4892 100644 --- a/xla/service/algebraic_simplifier_test.cc +++ b/xla/service/algebraic_simplifier_test.cc @@ -9435,6 +9435,35 @@ TEST_F(AlgebraicSimplifierTest, ReduceOfBatchDotToContractingDimension) { GmockMatch(m::Dot(m::Parameter(0), m::Parameter(1)))); } +// Same test as above, but with the option supports_non_canonical_dots set to +// false. +TEST_F(AlgebraicSimplifierTest, + ReduceOfBatchDotToContractingDimensionDisabled) { + const char* kModuleStr = R"( + HloModule m + a { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT r = f32[] add(p0, p1) + } + test { + p0 = f32[32,8,5,6] parameter(0) + p1 = f32[8,32,6,7] parameter(1) + d = f32[32,8,5,7] dot(p0, p1), + lhs_batch_dims={0,1}, + rhs_batch_dims={1,0}, + rhs_contracting_dims={2}, + lhs_contracting_dims={3} + c = f32[] constant(0) + ROOT r = f32[8,5,7] reduce(d,c), dimensions={0}, to_apply=a + } + )"; + AlgebraicSimplifierOptions options = default_options_; + options.set_supports_non_canonical_dots(false); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_FALSE(AlgebraicSimplifier(options).Run(m.get()).value()); +} + TEST_F(AlgebraicSimplifierTest, ReduceAddIsCommutative) { const char* kModuleStr = R"( HloModule m diff --git a/xla/service/gpu/nvptx_compiler.cc b/xla/service/gpu/nvptx_compiler.cc index a962da924d6a54..40eff87caf8c06 100644 --- a/xla/service/gpu/nvptx_compiler.cc +++ b/xla/service/gpu/nvptx_compiler.cc @@ -217,6 +217,7 @@ absl::Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization( AlgebraicSimplifierOptions algsimp_options = GetAlgebraicSimplifierOptions(hlo_module->config()); + algsimp_options.set_supports_non_canonical_dots(false); algsimp_options.set_enable_conv_operand_swap(false); algsimp_options.set_enable_unconditional_reduce_of_concat_replacement(false); pipeline.AddPass>(algsimp_options, diff --git a/xla/service/gpu/tests/gemm_rewrite_test.cc b/xla/service/gpu/tests/gemm_rewrite_test.cc index 2c851e2c98b32d..0b7bcb3f58227c 100644 --- a/xla/service/gpu/tests/gemm_rewrite_test.cc +++ b/xla/service/gpu/tests/gemm_rewrite_test.cc @@ -8177,6 +8177,35 @@ ENTRY main.10 { )"); } +TEST_F(GemmRewriteTest, ReduceOfBatchDot) { + absl::string_view hlo_string = + R"( +HloModule test + +region_5.50 { + Arg_0.51 = f32[] parameter(0) + Arg_1.52 = f32[] parameter(1) + ROOT add.53 = f32[] add(Arg_0.51, Arg_1.52) +} + +ENTRY main { + p0 = bf16[3,32,3,13]{3,2,1,0} parameter(0) + p1 = bf16[3,32,3,64]{3,2,1,0} parameter(1) + dot.95 = bf16[3,3,13,64]{3,2,1,0} dot(p0, p1), lhs_batch_dims={0,2}, lhs_contracting_dims={1}, rhs_batch_dims={0,2}, rhs_contracting_dims={1}, operand_precision={highest,highest} + transpose.96 = bf16[3,64,3,13]{1,3,2,0} transpose(dot.95), dimensions={0,3,1,2} + convert.101 = f32[3,64,3,13]{1,3,2,0} convert(transpose.96) + constant.66 = f32[] constant(0.0) + ROOT reduce.102 = f32[3,64,13]{2,1,0} reduce(convert.101, constant.66), dimensions={2}, to_apply=region_5.50 +} +)"; + // Make sure the dot is lowered to a custom call. There is an algebraic + // simplifier simplification which could turn the dot into a non-canonical dot + // late in the pipeline, which will make it unsupported by the GemmRewriter. + MatchOptimizedHlo(hlo_string, R"( + // CHECK: custom_call_target="__cublas$gemm" + )"); +} + class GemmRewriteAllocationTest : public GpuCodegenTest { public: void CheckNumberOfAllocations(const std::string& hlo, From 5b0e626589b08de741c14c2568cf1b942ace55ba Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Wed, 24 Jul 2024 05:24:03 -0700 Subject: [PATCH 109/376] [XLA:GPU] Make FusionInfoCache thread-safe and use in Priority Fusion. The cache helps to reduce compile time of big modules. PiperOrigin-RevId: 655526320 --- xla/service/gpu/BUILD | 2 +- xla/service/gpu/gpu_fusible.cc | 56 ++++++++++++++++++++---------- xla/service/gpu/gpu_fusible.h | 26 ++++++++------ xla/service/gpu/priority_fusion.cc | 11 ++++-- 4 files changed, 61 insertions(+), 34 deletions(-) diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 18eed308d08ef0..c485caac626c9e 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -2084,7 +2084,6 @@ cc_library( ":gpu_fusible", ":hlo_fusion_analysis", ":hlo_traversal", - ":triton_fusion_analysis", "//xla:debug_options_flags", "//xla:shape_util", "//xla:xla_data_proto_cc", @@ -4356,6 +4355,7 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/synchronization", ], ) diff --git a/xla/service/gpu/gpu_fusible.cc b/xla/service/gpu/gpu_fusible.cc index b445f35feb11e8..19ca91b00de663 100644 --- a/xla/service/gpu/gpu_fusible.cc +++ b/xla/service/gpu/gpu_fusible.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/synchronization/mutex.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -621,23 +622,31 @@ static int64_t SharedMemoryUsageNoCache(const HloInstruction& instr) { return 0; } -int64_t SharedMemoryUsage(const HloInstruction& instr, FusionInfoCache* cache) { - if (!cache) { - return SharedMemoryUsageNoCache(instr); +int64_t FusionInfoCache::GetSharedMemoryUsage(const HloInstruction& instr) { + { + absl::MutexLock lock(&mutex_); + auto it = shared_memory_usage_.find(&instr); + if (it != shared_memory_usage_.end()) { + return it->second; + } } // nb: Users are only expected to call cache.Invalidate() on top-level // instructions, not instructions inside fusion nodes. Therefore we can only // cache top-level instructions; it would not be valid to pass the cache to // SharedMemoryUsageNoCache and use the cache *within* the fusion. - auto it_and_inserted = cache->shared_memory_usage.emplace(&instr, -1); - auto it = it_and_inserted.first; - auto inserted = it_and_inserted.second; + int64_t shared_memory_usage = SharedMemoryUsageNoCache(instr); + + absl::MutexLock lock(&mutex_); + shared_memory_usage_.emplace(&instr, shared_memory_usage); + return shared_memory_usage; +} - if (inserted) { - it->second = SharedMemoryUsageNoCache(instr); +int64_t SharedMemoryUsage(const HloInstruction& instr, FusionInfoCache* cache) { + if (!cache) { + return SharedMemoryUsageNoCache(instr); } - return it->second; + return cache->GetSharedMemoryUsage(instr); } // Codegen'ing unnested reductions requires a lot of registers, so a MOF @@ -661,24 +670,33 @@ static int64_t NumUnnestedReductionsNoCache(const HloInstruction& instr) { return 0; } -static int64_t NumUnnestedReductions(const HloInstruction& instr, - FusionInfoCache* cache) { - if (!cache) { - return NumUnnestedReductionsNoCache(instr); +int64_t FusionInfoCache::GetNumUnnestedReductions(const HloInstruction& instr) { + { + absl::MutexLock lock(&mutex_); + auto it = num_unnested_reductions_.find(&instr); + if (it != num_unnested_reductions_.end()) { + return it->second; + } } // nb: Users are only expected to call cache.Invalidate() on top-level // instructions, not instructions inside fusion nodes. Therefore we can only // cache top-level instructions; it would not be valid to pass the cache to // NumUnnestedReductionsNoCache and use the cache *within* the fusion. - auto it_and_inserted = cache->num_unnested_reductions.emplace(&instr, -1); - auto it = it_and_inserted.first; - auto inserted = it_and_inserted.second; + int64_t num_unnested_reductions = NumUnnestedReductionsNoCache(instr); + + absl::MutexLock lock(&mutex_); + num_unnested_reductions_.emplace(&instr, num_unnested_reductions); + return num_unnested_reductions; +} - if (inserted) { - it->second = NumUnnestedReductionsNoCache(instr); +static int64_t NumUnnestedReductions(const HloInstruction& instr, + FusionInfoCache* cache) { + if (!cache) { + return NumUnnestedReductionsNoCache(instr); } - return it->second; + + return cache->GetNumUnnestedReductions(instr); } // This function limits the maximum number of operands to a fusion, and the diff --git a/xla/service/gpu/gpu_fusible.h b/xla/service/gpu/gpu_fusible.h index 7bc120b7574529..185c440603a6b2 100644 --- a/xla/service/gpu/gpu_fusible.h +++ b/xla/service/gpu/gpu_fusible.h @@ -24,6 +24,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/hlo_traversal.h" @@ -50,19 +51,26 @@ bool IsExpensiveToRepeat(const HloInstruction& instr); // those properties n^2 times. // // Invariant: After modifying or removing a fusion node, call Invalidate(node). -struct FusionInfoCache { +class FusionInfoCache { public: // Must be called after modifying or removing a fusion node (or other node // that's part of this cache). void Invalidate(const HloInstruction* instr) { - shared_memory_usage.erase(instr); - num_unnested_reductions.erase(instr); + shared_memory_usage_.erase(instr); + num_unnested_reductions_.erase(instr); } - // The rest of the members of this class are for internal use within - // gpu_fusible. You shouldn't need to use them yourself. - absl::flat_hash_map shared_memory_usage; - absl::flat_hash_map num_unnested_reductions; + // Returns expected shared memory usage of a given instruction in bytes. + int64_t GetSharedMemoryUsage(const HloInstruction& instr); + + // Returns the number of unnested reductions in the instruction output. + int64_t GetNumUnnestedReductions(const HloInstruction& instr); + + private: + absl::Mutex mutex_; + + absl::flat_hash_map shared_memory_usage_; + absl::flat_hash_map num_unnested_reductions_; }; // Returns the computations within `module` whose instructions can still be @@ -72,10 +80,6 @@ std::vector GetFusibleComputations( const HloModule& module, const absl::flat_hash_set& execution_threads); -// Returns projected shared memory usage of a given instruction in bytes. -int64_t SharedMemoryUsage(const HloInstruction& instr, - FusionInfoCache* cache = nullptr); - inline constexpr int64_t MaxOperandsAndOutputsPerFusion() { return 96; } // Whether the op transposes the physical data layout. Fusing such ops may lead diff --git a/xla/service/gpu/priority_fusion.cc b/xla/service/gpu/priority_fusion.cc index e01372e9e7439f..ec5ff230eef7bf 100644 --- a/xla/service/gpu/priority_fusion.cc +++ b/xla/service/gpu/priority_fusion.cc @@ -289,6 +289,7 @@ class GpuPriorityFusionQueue { gpu_performance_model_cache_.Invalidate(*instruction); fusion_analysis_cache_.Invalidate(*instruction); + fusion_info_cache_.Invalidate(instruction); } // Updates data for the new fusion instruction and its users and operands. @@ -512,9 +513,9 @@ class GpuPriorityFusionQueue { // Avoid cases where we'd create a fusion that hit limitations in ptxas. // Would be nice to model this with cost instead. - if (auto fits_budget = - FusionFitsInBudget(*consumer, *producer, *device_info_, - /*is_consumer_producer_fusion=*/true); + if (auto fits_budget = FusionFitsInBudget( + *consumer, *producer, *device_info_, + /*is_consumer_producer_fusion=*/true, &fusion_info_cache_); !fits_budget) { return fits_budget; } @@ -640,6 +641,10 @@ class GpuPriorityFusionQueue { GpuPerformanceModelCache gpu_performance_model_cache_; + // Cache for `FusionFitsInBudget` to avoid recomputing expensive properties + // like shared memory usage or number of unnested reductions of fusion nodes. + FusionInfoCache fusion_info_cache_; + bool triton_softmax_priority_fusion_enabled_; bool dump_fusion_visualization_; From 93c31eb48eb2228f889d4e848c5821a78b21035b Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Wed, 24 Jul 2024 06:09:13 -0700 Subject: [PATCH 110/376] PR #14310: [GPU] Cleanup handling of determinism settings. Imported from GitHub PR https://github.com/openxla/xla/pull/14310 TF_CUDNN_DETERMINISTIC is TF-specific and will be handled there. Merge all checks of determinism flags in RequireDeterminism(). Set NumericOptions.require_determinism using it. Copybara import of the project: -- 1351c0fe1784d538977590a1bc6ace4056a8bd0d by Ilia Sergachev : [GPU] Remove TF_CUDNN_DETERMINISTIC; cleanup handling of determinism settings. TF_CUDNN_DETERMINISTIC is TF-specific and will be handled there. Merge all checks of determinism flags in RequireDeterminism(). Set NumericOptions.require_determinism using it. Merging this change closes #14310 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/14310 from openxla:cleanup_determinism_handling 1351c0fe1784d538977590a1bc6ace4056a8bd0d PiperOrigin-RevId: 655537398 --- xla/service/gpu/BUILD | 1 + xla/service/gpu/conv_algorithm_picker.cc | 14 +++------ xla/service/gpu/cudnn_workspace_rewriter.cc | 7 ++--- xla/service/gpu/determinism_test.cc | 1 + xla/service/gpu/fusions/BUILD | 1 + xla/service/gpu/fusions/custom.cc | 6 ++-- xla/service/gpu/gemm_algorithm_picker.cc | 3 +- xla/service/gpu/gpu_compiler.cc | 3 +- xla/service/gpu/ir_emitter_unnested.cc | 16 ++++------ xla/service/gpu/stream_executor_util.cc | 12 ++------ xla/service/gpu/tests/gpu_fused_mha_test.cc | 2 +- xla/stream_executor/cuda/cuda_dnn.cc | 33 ++++++--------------- 12 files changed, 31 insertions(+), 68 deletions(-) diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index c485caac626c9e..6d1109cd14abd4 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -323,6 +323,7 @@ cc_library( ":launch_dimensions", ":matmul_utils", ":parallel_loop_emitter", + ":stream_executor_util", ":triton_call", "//xla:autotuning_proto_cc", "//xla:literal", diff --git a/xla/service/gpu/conv_algorithm_picker.cc b/xla/service/gpu/conv_algorithm_picker.cc index fc93de8b744698..40bbb7a8477680 100644 --- a/xla/service/gpu/conv_algorithm_picker.cc +++ b/xla/service/gpu/conv_algorithm_picker.cc @@ -819,9 +819,6 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( const bool cudnn_frontend_enabled = debug_options.xla_gpu_enable_cudnn_frontend(); - const bool deterministic_ops = - debug_options.xla_gpu_deterministic_ops() || - debug_options.xla_gpu_exclude_nondeterministic_ops(); bool allow_tf32 = true; // TODO(b/284371623): Properly set allow_tf32 even if instr==nullptr, which is // the case when running an AOT compiled executable with runtime autotuning. @@ -830,7 +827,8 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( instr->precision_config().operand_precision(), [](int precision) { return precision <= PrecisionConfig::HIGH; }); } - const se::NumericOptions numeric_options{deterministic_ops, allow_tf32}; + const se::NumericOptions numeric_options{ + RequireDeterminism(instr->GetModule()->config()), allow_tf32}; // Use the first algorithm that's supported as reference. There isn't a // particular reason to use it, as any algorithm suffices. It doesn't make @@ -929,15 +927,11 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm( XLA_SCOPED_LOGGING_TIMER(absl::StrCat( "GpuConvAlgorithmPicker::PickBestAlgorithmImpl for ", instr->ToString())); - const DebugOptions& debug_options = - instr->GetModule()->config().debug_options(); - const bool deterministic_ops = - debug_options.xla_gpu_deterministic_ops() || - debug_options.xla_gpu_exclude_nondeterministic_ops(); const bool allow_tf32 = absl::c_all_of( instr->precision_config().operand_precision(), [](int precision) { return precision <= PrecisionConfig::HIGH; }); - const se::NumericOptions numeric_options{deterministic_ops, allow_tf32}; + const se::NumericOptions numeric_options{ + RequireDeterminism(instr->GetModule()->config()), allow_tf32}; se::StreamExecutor* stream_exec = config_.GetExecutor(); const auto device_ordinal = stream_exec->device_ordinal(); diff --git a/xla/service/gpu/cudnn_workspace_rewriter.cc b/xla/service/gpu/cudnn_workspace_rewriter.cc index 265703efc41b21..387a3f4d3aec8e 100644 --- a/xla/service/gpu/cudnn_workspace_rewriter.cc +++ b/xla/service/gpu/cudnn_workspace_rewriter.cc @@ -174,11 +174,8 @@ absl::StatusOr HloCustomCallToCuDnnGraph( TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type, AsCudnnFmhaMaskKind(config.mask_type())); - const DebugOptions& debug_options = - custom_call->GetModule()->config().debug_options(); - bool force_deterministic = - debug_options.xla_gpu_deterministic_ops() || - debug_options.xla_gpu_exclude_nondeterministic_ops(); + const bool force_deterministic = + RequireDeterminism(custom_call->GetModule()->config()); // set the correct force_deterministic attribute here config.set_force_deterministic(force_deterministic); TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_config)); diff --git a/xla/service/gpu/determinism_test.cc b/xla/service/gpu/determinism_test.cc index 02f5ba86a64254..93c5b1591f110c 100644 --- a/xla/service/gpu/determinism_test.cc +++ b/xla/service/gpu/determinism_test.cc @@ -155,6 +155,7 @@ TEST_F(DeterminismTest, ExcludingNonDeterministicOpsDoesNotDisableAutotuning) { #endif // TENSORFLOW_USE_ROCM debug_options_.set_xla_gpu_cublas_fallback(false); + ASSERT_TRUE(debug_options_.xla_gpu_exclude_nondeterministic_ops()); ASSERT_FALSE(debug_options_.xla_gpu_deterministic_ops()); AutotunerUtil::ClearAutotuneResults(); // The default config is not used when autotuning is on. diff --git a/xla/service/gpu/fusions/BUILD b/xla/service/gpu/fusions/BUILD index 06eface0660a57..d72b156c033754 100644 --- a/xla/service/gpu/fusions/BUILD +++ b/xla/service/gpu/fusions/BUILD @@ -140,6 +140,7 @@ cc_library( "//xla/service/gpu:ir_emitter_context", "//xla/service/gpu:kernel_arguments", "//xla/service/gpu:matmul_utils", + "//xla/service/gpu:stream_executor_util", "//xla/service/gpu/kernels:custom_kernel", "//xla/service/gpu/kernels:custom_kernel_fusion", "//xla/service/gpu/runtime:custom_call_thunk", diff --git a/xla/service/gpu/fusions/custom.cc b/xla/service/gpu/fusions/custom.cc index 2b7c151db7db93..7e6ccb6743e5bf 100644 --- a/xla/service/gpu/fusions/custom.cc +++ b/xla/service/gpu/fusions/custom.cc @@ -60,6 +60,7 @@ limitations under the License. #include "xla/service/gpu/runtime/gemm_thunk.h" #include "xla/service/gpu/runtime/kernel_thunk.h" #include "xla/service/gpu/runtime/thunk.h" +#include "xla/service/gpu/stream_executor_util.h" #include "xla/service/hlo.pb.h" #include "xla/service/pattern_matcher.h" #include "xla/shape.h" @@ -411,9 +412,8 @@ absl::StatusOr EmitGemm( "operand/result"); } - bool deterministic_ops = - ir_emitter_context.debug_options().xla_gpu_deterministic_ops() || - ir_emitter_context.debug_options().xla_gpu_exclude_nondeterministic_ops(); + const bool deterministic_ops = + RequireDeterminism(fusion.GetModule()->config()); TF_ASSIGN_OR_RETURN( GemmConfig config, diff --git a/xla/service/gpu/gemm_algorithm_picker.cc b/xla/service/gpu/gemm_algorithm_picker.cc index aac4aa2f21eb20..a2de14cb2e4cf7 100644 --- a/xla/service/gpu/gemm_algorithm_picker.cc +++ b/xla/service/gpu/gemm_algorithm_picker.cc @@ -112,8 +112,7 @@ class GemmAutotuner { TF_ASSIGN_OR_RETURN(stream_, autotune_config_.GetStream()); const DebugOptions& debug_options = gemm->GetModule()->config().debug_options(); - deterministic_ops_ = debug_options.xla_gpu_deterministic_ops() || - debug_options.xla_gpu_exclude_nondeterministic_ops(); + deterministic_ops_ = RequireDeterminism(gemm->GetModule()->config()); solutions_limit_ = debug_options.xla_gpu_autotune_max_solutions(); TF_ASSIGN_OR_RETURN(auto gemm_config, GemmConfig::For(gemm)); diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index 3973e92b144a9b..7780e9b315a1b5 100644 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -665,8 +665,7 @@ absl::Status RunOptimizationPasses( // handle it. pipeline.AddPass(); - if (debug_options.xla_gpu_deterministic_ops() || - debug_options.xla_gpu_exclude_nondeterministic_ops()) { + if (RequireDeterminism(hlo_module->config())) { // Scatter can be indeterministic if indices are not unique or a non // associative combiner function is used. Eliminate these Scatter ops. pipeline.AddPass( diff --git a/xla/service/gpu/ir_emitter_unnested.cc b/xla/service/gpu/ir_emitter_unnested.cc index c102feaf76cdf4..9082eaca99b815 100644 --- a/xla/service/gpu/ir_emitter_unnested.cc +++ b/xla/service/gpu/ir_emitter_unnested.cc @@ -145,6 +145,7 @@ limitations under the License. #include "xla/service/gpu/runtime/thunk.h" #include "xla/service/gpu/runtime/wait_for_streams_thunk.h" #include "xla/service/gpu/runtime/while_thunk.h" +#include "xla/service/gpu/stream_executor_util.h" #include "xla/service/gpu/triton_call.h" #include "xla/service/llvm_ir/buffer_assignment_util.h" #include "xla/service/llvm_ir/ir_array.h" @@ -649,17 +650,13 @@ absl::Status IrEmitterUnnested::EmitGemmThunk( TF_ASSIGN_OR_RETURN(workspace, GetAllocationSliceForHlo(instr, {1})); } - bool deterministic_ops = - ir_emitter_context_->debug_options().xla_gpu_deterministic_ops() || - ir_emitter_context_->debug_options() - .xla_gpu_exclude_nondeterministic_ops(); - TF_ASSIGN_OR_RETURN( GemmConfig config, GemmConfig::For(static_cast(instr))); auto thunk = std::make_unique( Thunk::ThunkInfo::WithProfileAnnotation(instr), std::move(config), a, b, - c, workspace, deterministic_ops); + c, workspace, + RequireDeterminism(ir_emitter_context_->hlo_module().config())); AddThunkToThunkSequence(std::move(thunk)); return absl::OkStatus(); } @@ -1761,13 +1758,10 @@ absl::Status IrEmitterUnnested::EmitAsyncCustomCallStart( absl::Status IrEmitterUnnested::AssertNonDeterminismIsOkay( const std::string& op_name) { - if (ir_emitter_context_->debug_options().xla_gpu_deterministic_ops() || - ir_emitter_context_->debug_options() - .xla_gpu_exclude_nondeterministic_ops()) { + if (RequireDeterminism(ir_emitter_context_->hlo_module().config())) { return Unimplemented( "HLO instruction %s does not have a deterministic implementation, " - "but run-to-run determinism is required by --xla_gpu_deterministic_ops " - "or --xla_gpu_exclude_nondeterministic_ops.", + "but run-to-run determinism is required.", op_name); } return absl::OkStatus(); diff --git a/xla/service/gpu/stream_executor_util.cc b/xla/service/gpu/stream_executor_util.cc index 6508bf9e9cf874..cde9b554bd504d 100644 --- a/xla/service/gpu/stream_executor_util.cc +++ b/xla/service/gpu/stream_executor_util.cc @@ -621,16 +621,8 @@ absl::StatusOr GetDNNDataTypeFromPrimitiveType( } bool RequireDeterminism(const HloModuleConfig& config) { - static bool require_cudnn_determinism = [] { - // TODO(reedwm): Remove the TF_CUDNN_DETERMINISTIC env var. - bool cudnn_deterministic = false; - TF_CHECK_OK(tsl::ReadBoolFromEnvVar("TF_CUDNN_DETERMINISTIC", - /*default_val=*/false, - &cudnn_deterministic)); - return cudnn_deterministic; - }(); - return require_cudnn_determinism || - config.debug_options().xla_gpu_deterministic_ops(); + return config.debug_options().xla_gpu_deterministic_ops() || + config.debug_options().xla_gpu_exclude_nondeterministic_ops(); } namespace { diff --git a/xla/service/gpu/tests/gpu_fused_mha_test.cc b/xla/service/gpu/tests/gpu_fused_mha_test.cc index 03f3448f5f15ab..639cf511875f43 100644 --- a/xla/service/gpu/tests/gpu_fused_mha_test.cc +++ b/xla/service/gpu/tests/gpu_fused_mha_test.cc @@ -155,7 +155,7 @@ class MultiHeadedAttentionTest : public GpuCodegenTest { DebugOptions debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_enable_cudnn_fmha(true); if (force_deterministic) { - debug_options.set_xla_gpu_deterministic_ops(true); + debug_options.set_xla_gpu_exclude_nondeterministic_ops(true); } reference_module->mutable_config().set_debug_options(debug_options); const Literal first_run_result = diff --git a/xla/stream_executor/cuda/cuda_dnn.cc b/xla/stream_executor/cuda/cuda_dnn.cc index 3898d1e3ea002a..bbc6a6dc2cca79 100644 --- a/xla/stream_executor/cuda/cuda_dnn.cc +++ b/xla/stream_executor/cuda/cuda_dnn.cc @@ -948,21 +948,6 @@ bool BatchnormSpatialPersistentEnabled() { return is_enabled; } -bool RequireCudnnDeterminism(const NumericOptions& numeric_options) { - static bool cudnn_deterministic_env_var = [] { - // TODO(reedwm): Remove the TF_CUDNN_DETERMINISTIC env var. - bool cudnn_deterministic = false; - TF_CHECK_OK(tsl::ReadBoolFromEnvVar("TF_CUDNN_DETERMINISTIC", - /*default_val=*/false, - &cudnn_deterministic)); - return cudnn_deterministic; - }(); - bool require_determinism = - cudnn_deterministic_env_var || numeric_options.require_determinism; - VLOG(5) << "RequireCudnnDeterminism: " << require_determinism; - return require_determinism; -} - // A helper function to decide whether to force the default conv algorithm. bool ConvUseDefaultAlgorithm() { static bool use_default = [] { @@ -1100,7 +1085,7 @@ class CudnnPoolingDescriptor { std::transform(shape64.cbegin(), shape64.cend(), shape.begin(), &CheckedNarrowing); bool propagate_nans = pooling_descriptor.propagate_nans(); - const auto cudnn_max_pooling_mode = RequireCudnnDeterminism(numeric_options) + const auto cudnn_max_pooling_mode = numeric_options.require_determinism ? CUDNN_POOLING_MAX_DETERMINISTIC : CUDNN_POOLING_MAX; CHECK_CUDNN_OK(cudnnSetPoolingNdDescriptor( @@ -6211,16 +6196,16 @@ absl::Status CreateOpRunners( bool need_side_input, const NumericOptions& numeric_options) { cudnn_frontend::EngineConfigList filtered_configs; const bool disable_winograd = !CudnnEnvVar::IsEnabled(); - const bool disable_nondeterminism = RequireCudnnDeterminism(numeric_options); const bool disable_tensor_core = !IsTensorMathEnabled(stream, input_type, numeric_options.allow_tf32); auto generic_filter_fn = [=](cudnnBackendDescriptor_t engine_config) -> bool { return GenericEngineFilter(engine_config, disable_winograd, - disable_nondeterminism, disable_tensor_core); + numeric_options.require_determinism, + disable_tensor_core); }; VLOG(4) << "Filtering engine configs with disable_winograd=" << disable_winograd - << ", disable_nondeterminism=" << disable_nondeterminism + << ", disable_nondeterminism=" << numeric_options.require_determinism << ", disable_tensor_core=" << disable_tensor_core; std::array heur_mode = {use_fallback ? "heuristics_fallback" @@ -6299,7 +6284,7 @@ absl::Status CreateOpRunners( std::move(runner_or).value())); // We will use the first working plan when determinism is required. - if (RequireCudnnDeterminism(numeric_options)) { + if (numeric_options.require_determinism) { break; } } @@ -7359,7 +7344,7 @@ bool CudnnSupport::GetConvolveBackwardDataAlgorithms( if (CudnnEnvVar::IsEnabled()) { algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED); } - if (!RequireCudnnDeterminism(numeric_options)) { + if (numeric_options.require_determinism) { algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0); } @@ -7399,7 +7384,7 @@ bool CudnnSupport::GetConvolveBackwardFilterAlgorithms( if (CudnnEnvVar::IsEnabled()) { algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED); } - if (!RequireCudnnDeterminism(numeric_options)) { + if (!numeric_options.require_determinism) { algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0); algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3); } @@ -7937,7 +7922,7 @@ absl::Status CudnnSupport::DoPrepareForCtcLoss( // Try running with `algo`, if successful then pick it. The // non-deterministic algorithm is first and thus preferentially picked // when determinism is not required. - auto algo = RequireCudnnDeterminism(numeric_options) + auto algo = numeric_options.require_determinism ? CUDNN_CTC_LOSS_ALGO_DETERMINISTIC : CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC; cudnnStatus_t status = cudnnGetCTCLossWorkspaceSize( @@ -7949,7 +7934,7 @@ absl::Status CudnnSupport::DoPrepareForCtcLoss( /*algo=*/algo, /*ctcLossDesc=*/cudnn_ctc_loss_desc.handle(), /*sizeInBytes=*/&workspace_size_in_bytes); - if (RequireCudnnDeterminism(numeric_options)) { + if (numeric_options.require_determinism) { RETURN_IF_CUDNN_ERROR(status); } From 8391d5f3047f124aebc2eb5ab3a3d9b04bcdffcd Mon Sep 17 00:00:00 2001 From: Vladyslav Tsilytskyi Date: Wed, 24 Jul 2024 06:34:36 -0700 Subject: [PATCH 111/376] [xla:cpu] Allow monkey patching for compute_function_ This allows an easy IrEmitter code reuse from IrEmitter2 PiperOrigin-RevId: 655544047 --- xla/service/cpu/BUILD | 4 +-- xla/service/cpu/ir_emitter.cc | 46 +++++++++++++++++++--------------- xla/service/cpu/ir_emitter.h | 32 +++++++++++++++++++++-- xla/service/cpu/ir_function.cc | 31 +++++++++++++++++++---- xla/service/cpu/ir_function.h | 13 ++++++++++ 5 files changed, 97 insertions(+), 29 deletions(-) diff --git a/xla/service/cpu/BUILD b/xla/service/cpu/BUILD index 085f12582a9b61..4cbc0e9e5439fe 100644 --- a/xla/service/cpu/BUILD +++ b/xla/service/cpu/BUILD @@ -624,8 +624,10 @@ cc_library( hdrs = ["ir_emitter2.h"], deps = [ ":backend_config_proto_cc", + ":dot_op_emitter", ":elemental_math_emitter", ":ir_emitter", + ":ir_function", ":parallel_loop_emitter", ":shape_partition", "//xla:cpu_function_runtime", @@ -635,13 +637,11 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", "//xla/service:elemental_ir_emitter", - "//xla/service/cpu:dot_op_emitter", "//xla/service/llvm_ir:dynamic_update_slice_util", "//xla/service/llvm_ir:fused_ir_emitter", "//xla/service/llvm_ir:ir_array", "//xla/service/llvm_ir:llvm_util", "//xla/service/llvm_ir:loop_emitter", - "//xla/service/llvm_ir:tuple_ops", "//xla/stream_executor:launch_dim", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:btree", diff --git a/xla/service/cpu/ir_emitter.cc b/xla/service/cpu/ir_emitter.cc index 9929215b3c3a3f..8b4c3a204eaf52 100644 --- a/xla/service/cpu/ir_emitter.cc +++ b/xla/service/cpu/ir_emitter.cc @@ -183,8 +183,15 @@ IrEmitter::IrEmitter(mlir::MLIRContext* mlir_context, TF_CHECK_OK(s) << "Should have failed buffer assignment."; } +IrEmitter::~IrEmitter() { + if (!compute_function_.empty()) { + LOG(WARNING) << "Compute function stack is not empty: " + << compute_function_.size(); + } +}; + void IrEmitter::EmitThreadLocalFunctionEpilogue(HloComputation* computation) { - llvm::Argument* out_parameter = compute_function_->result_arg(); + llvm::Argument* out_parameter = compute_function()->result_arg(); llvm_ir::IrArray root_value = GetIrArrayFor(computation->root_instruction()); const Shape& return_shape = computation->root_instruction()->shape(); @@ -267,7 +274,7 @@ absl::StatusOr IrEmitter::EmitComputation( b()->setFastMathFlags(flags); TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, instruction_order)); - llvm::Function* ir_function = compute_function_->function(); + llvm::Function* ir_function = compute_function()->function(); for (llvm::Attribute::AttrKind attr : function_attributes) { ir_function->addFnAttr(attr); @@ -291,8 +298,8 @@ absl::StatusOr IrEmitter::EmitComputation( EmitThreadLocalFunctionEpilogue(computation); } - // Destructor for compute_function_ terminates the LLVM function definition. - compute_function_.reset(); + // Destructor for compute_function() terminates the LLVM function definition. + PopComputeFunction(); computation_root_allocation_ = BufferAllocation::Slice(); computation_parameter_allocations_.clear(); return ir_function; @@ -306,9 +313,8 @@ void IrEmitter::InitializeIrFunction(const std::string& function_name) { is_top_level_computation_ ? llvm::GlobalValue::ExternalLinkage : llvm::GlobalValue::InternalLinkage; // Create and initialize new IrFunction. - compute_function_ = - std::make_unique(function_name, linkage, hlo_module_config_, - module_, b(), num_dynamic_loop_bounds_); + compute_function_.emplace(function_name, linkage, hlo_module_config_, module_, + b(), num_dynamic_loop_bounds_); } absl::Status IrEmitter::HandleBitcast(HloInstruction* bitcast) { @@ -1760,7 +1766,7 @@ IrEmitter::ShardedVectorType IrEmitter::CreateShardedVectorType( PrimitiveType element_type, unsigned element_count) { int vector_register_size_in_elements = target_machine_features_.vector_register_byte_size( - *compute_function_->function()) / + *compute_function()->function()) / ShapeUtil::ByteSizeOfPrimitiveType(element_type); ShardedVectorType sharded_vector_type; @@ -1929,7 +1935,7 @@ absl::StatusOr IrEmitter::EmitVectorizedReduce( int vector_register_size_in_elements = target_machine_features_.vector_register_byte_size( - *compute_function_->function()) / + *compute_function()->function()) / ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()); if (vector_register_size_in_elements == 0) { // Either we don't know the vector register width for the target or the @@ -3150,7 +3156,7 @@ absl::Status IrEmitter::HandleWhile(HloInstruction* xla_while) { // Terminates the current block with a branch to a while header. llvm::BasicBlock* header_bb = llvm::BasicBlock::Create( module_->getContext(), IrName(xla_while, "header"), - compute_function_->function()); + compute_function()->function()); Br(header_bb); b()->SetInsertPoint(header_bb); @@ -3166,7 +3172,7 @@ absl::Status IrEmitter::HandleWhile(HloInstruction* xla_while) { // Branches to the body or to the while exit depending on the condition. llvm::BasicBlock* body_bb = llvm::BasicBlock::Create(module_->getContext(), IrName(xla_while, "body"), - compute_function_->function()); + compute_function()->function()); llvm::BasicBlock* exit_bb = llvm::BasicBlock::Create( module_->getContext(), IrName(xla_while, "exit")); CondBr(while_predicate, body_bb, exit_bb); @@ -3181,7 +3187,7 @@ absl::Status IrEmitter::HandleWhile(HloInstruction* xla_while) { Br(header_bb); // Adds the exit block to the function and sets the insert point there. - llvm::Function* llvm_fn = compute_function_->function(); + llvm::Function* llvm_fn = compute_function()->function(); llvm_fn->insert(llvm_fn->end(), exit_bb); b()->SetInsertPoint(exit_bb); @@ -3967,23 +3973,23 @@ llvm::Type* IrEmitter::IrShapeType(const Shape& shape) { } llvm::Value* IrEmitter::GetProfileCountersArgument() { - return compute_function_->profile_counters_arg(); + return compute_function()->profile_counters_arg(); } llvm::Value* IrEmitter::GetStatusArgument() { - return compute_function_->status_arg(); + return compute_function()->status_arg(); } llvm::Value* IrEmitter::GetBufferTableArgument() { - return compute_function_->buffer_table_arg(); + return compute_function()->buffer_table_arg(); } llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() { - return compute_function_->exec_run_options_arg(); + return compute_function()->exec_run_options_arg(); } llvm::BasicBlock* IrEmitter::GetReturnBlock() { - return compute_function_->return_block(); + return compute_function()->return_block(); } void IrEmitter::EmitEarlyReturnIfErrorStatus() { @@ -4011,7 +4017,7 @@ llvm::Value* IrEmitter::EmitThreadLocalBufferPointer( // // Where Param is the actual element type of the underlying buffer (for // example, float for an XLA F32 element type). - llvm::Value* params = compute_function_->parameters_arg(); + llvm::Value* params = compute_function()->parameters_arg(); llvm::Value* param_address_offset = llvm_ir::EmitBufferIndexingGEP( params, b()->getPtrTy(), param_number, b()); llvm::LoadInst* param_address_untyped = @@ -4031,7 +4037,7 @@ llvm::Value* IrEmitter::EmitThreadLocalBufferPointer( const Shape& shape = assigned_buffers.begin()->first->shape(); std::pair key = { - compute_function_->function(), slice}; + compute_function()->function(), slice}; auto buf_it = thread_local_buffers_.find(key); if (buf_it == thread_local_buffers_.end()) { llvm::Value* buffer = llvm_ir::EmitAllocaAtFunctionEntry( @@ -4137,7 +4143,7 @@ absl::Status IrEmitter::EmitTargetElementLoop( if (ShouldEmitParallelLoopFor(*target_op)) { // Emit code to read dynamic loop bounds from compute function argument. std::vector> dynamic_loop_bounds = - compute_function_->GetDynamicLoopBounds(); + compute_function()->GetDynamicLoopBounds(); // Emit parallel loop with dynamic loop bounds for most-major dimensions. TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, target_array, &dynamic_loop_bounds, b()) diff --git a/xla/service/cpu/ir_emitter.h b/xla/service/cpu/ir_emitter.h index b4003b1ea227da..4566a7ccdf2f35 100644 --- a/xla/service/cpu/ir_emitter.h +++ b/xla/service/cpu/ir_emitter.h @@ -22,6 +22,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -101,7 +102,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, computation_transitively_contains_custom_call, const TargetMachineFeatures* target_machine, bool emit_code_for_msan); - ~IrEmitter() override = default; + ~IrEmitter() override; // Emit and return the given HLO computation as an LLVM IR // function. @@ -136,6 +137,33 @@ class IrEmitter : public DfsHloVisitorWithDefault, llvm::IRBuilder<>* builder() { return current_builder_; } const llvm::IRBuilder<>* builder() const { return current_builder_; } + IrFunction* compute_function() { return &compute_function_.top(); } + + // Used by IrEmitter + void PushComputeFunction(const std::string& function_name, + llvm::Function::LinkageTypes linkage, + const HloModuleConfig& module_config, + llvm::Module* llvm_module, + int64_t num_dynamic_loop_bounds) { + compute_function_.emplace(function_name, linkage, module_config, + llvm_module, b(), num_dynamic_loop_bounds); + } + + // Used by IrEmitter2 + void PushComputeFunction(std::shared_ptr> b, + llvm::Module* llvm_module, + int64_t num_dynamic_loop_bounds, + llvm::Function* function, + llvm::Value* dynamic_loop_bounds_arg, + llvm::BasicBlock* return_block) { + b->SetInsertPoint(llvm::BasicBlock::Create(llvm_module->getContext(), + "insertion_point", function)); + compute_function_.emplace(b.get(), llvm_module, num_dynamic_loop_bounds, + function, dynamic_loop_bounds_arg, return_block); + } + + void PopComputeFunction() { compute_function_.pop(); } + // Emit an LLVM global variable for every constant buffer allocation. absl::Status EmitConstantGlobals(); @@ -591,7 +619,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, // The current builder to use for IR emission. This is either `main_builder_` // or a temporary builder that replaces it. llvm::IRBuilder<>* current_builder_; - std::unique_ptr compute_function_; + std::stack compute_function_; mlir::MLIRContext* mlir_context_; bool allow_reassociation_; diff --git a/xla/service/cpu/ir_function.cc b/xla/service/cpu/ir_function.cc index 1c6180604b6d88..2537bf2e72c526 100644 --- a/xla/service/cpu/ir_function.cc +++ b/xla/service/cpu/ir_function.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include "absl/strings/str_cat.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Value.h" #include "xla/service/cpu/cpu_runtime.h" #include "xla/service/cpu/shape_partition.h" #include "xla/service/llvm_ir/llvm_util.h" @@ -50,6 +52,25 @@ IrFunction::IrFunction(const std::string& function_name, Initialize(function_name, linkage, module_config); } +IrFunction::IrFunction(llvm::IRBuilder<>* b, llvm::Module* llvm_module, + int64_t num_dynamic_loop_bounds, + llvm::Function* function, + llvm::Value* dynamic_loop_bounds_arg, + llvm::BasicBlock* return_block) + : b_(b), + llvm_module_(llvm_module), + caller_insert_point_guard_(*b), + num_dynamic_loop_bounds_(num_dynamic_loop_bounds), + function_(function), + result_arg_(nullptr), + exec_run_options_arg_(nullptr), + parameters_arg_(nullptr), + buffer_table_arg_(nullptr), + dynamic_loop_bounds_arg_(dynamic_loop_bounds_arg), + profile_counters_arg_(nullptr), + status_arg_(nullptr), + return_block_(return_block) {}; + IrFunction::~IrFunction() { // Branch to function return. b_->CreateBr(return_block_); @@ -83,9 +104,9 @@ void IrFunction::Initialize(const std::string& function_name, // buffer_table: address of an array with pointers to temporary buffers and // entry computation parameters (but not to constant buffers). // - // Therefore, the generated function's signature (FunctionType) is statically - // determined - parameter unpacking is done in code generated into the - // function, rather than by a prologue dictated by the platform ABI. + // Therefore, the generated function's signature (FunctionType) is + // statically determined - parameter unpacking is done in code generated + // into the function, rather than by a prologue dictated by the platform ABI. // // /--------------\ // retval ----------> | return value | @@ -126,8 +147,8 @@ void IrFunction::Initialize(const std::string& function_name, // \---------------------------------------------/ // Even though the type of params and buffer_table is void** in the host's - // view, in LLVM IR this is represented by i8*, similarly to void*. It's up to - // the code to use GEPs to unravel the indirection layers. + // view, in LLVM IR this is represented by i8*, similarly to void*. It's up + // to the code to use GEPs to unravel the indirection layers. llvm::FunctionType* function_type = llvm::FunctionType::get( /*Result=*/llvm::Type::getVoidTy(llvm_module_->getContext()), /*Params=*/ diff --git a/xla/service/cpu/ir_function.h b/xla/service/cpu/ir_function.h index a9ae4ce1a817a2..9c0d6c6fa7d637 100644 --- a/xla/service/cpu/ir_function.h +++ b/xla/service/cpu/ir_function.h @@ -56,6 +56,19 @@ class IrFunction { llvm::Function::LinkageTypes linkage, const HloModuleConfig& module_config, llvm::Module* llvm_module, llvm::IRBuilder<>* b, int64_t num_dynamic_loop_bounds); + + // Initialize an llvm::Function with existing function, created somewhere + // else, omit any extra work. + IrFunction(llvm::IRBuilder<>* b, llvm::Module* llvm_module, + int64_t num_dynamic_loop_bounds, llvm::Function* function, + // Function argument IR values. + // llvm::Argument* result_arg, llvm::Value* exec_run_options_arg, + // llvm::Value* parameters_arg, llvm::Value* buffer_table_arg, + llvm::Value* dynamic_loop_bounds_arg, + // llvm::Value* profile_counters_arg, llvm::Value* status_arg, + // Basic block containing return. + llvm::BasicBlock* return_block); + ~IrFunction(); // Emit IR to read and return the set of IR values representing the dynamic From 2ae1586dfc79be5e648ef965b7b9bbb1a0b8d7e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Bana=C5=9B?= Date: Wed, 24 Jul 2024 06:37:05 -0700 Subject: [PATCH 112/376] [XLA:CPU] Turn on another batch of tests for thunks runtime. PiperOrigin-RevId: 655544772 --- xla/tests/BUILD | 58 +++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 56 insertions(+), 2 deletions(-) diff --git a/xla/tests/BUILD b/xla/tests/BUILD index 3345e5a92bc2e4..048c47594dbc3f 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -381,6 +381,7 @@ cc_library( xla_test( name = "bad_rng_shape_validation_test", srcs = ["bad_rng_shape_validation_test.cc"], + tags = ["test_xla_cpu_thunks"], deps = [ ":client_library_test_base", ":xla_internal_test_main", @@ -398,6 +399,7 @@ xla_test( xla_test( name = "buffer_donation_test", srcs = ["buffer_donation_test.cc"], + tags = ["test_xla_cpu_thunks"], deps = [ ":hlo_test_base", ":literal_test_util", @@ -422,6 +424,7 @@ xla_test( "conv_depthwise_test.cc", ], shard_count = 50, + tags = ["test_xla_cpu_thunks"], deps = [ ":client_library_test_base", ":conv_depthwise_common", @@ -442,6 +445,7 @@ xla_test( timeout = "long", srcs = ["conv_depthwise_backprop_filter_test.cc"], shard_count = 40, + tags = ["test_xla_cpu_thunks"], deps = [ ":client_library_test_base", ":hlo_test_base", @@ -465,6 +469,7 @@ xla_test( "cpu", ], shard_count = 50, + tags = ["test_xla_cpu_thunks"], deps = [ ":client_library_test_base", ":hlo_test_base", @@ -484,6 +489,7 @@ xla_test( xla_test( name = "check_execution_arity_test", srcs = ["check_execution_arity_test.cc"], + tags = ["test_xla_cpu_thunks"], deps = [ ":client_library_test_base", ":test_macros_header", @@ -503,6 +509,7 @@ xla_test( xla_test( name = "query_inferred_shape_test", srcs = ["query_inferred_shape_test.cc"], + tags = ["test_xla_cpu_thunks"], deps = [ ":client_library_test_base", ":xla_internal_test_main", @@ -519,6 +526,7 @@ xla_test( xla_test( name = "while_test", srcs = ["while_test.cc"], + tags = ["test_xla_cpu_thunks"], deps = [ ":client_library_test_base", ":literal_test_util", @@ -623,6 +631,7 @@ xla_test( shard_count = 30, tags = [ "optonly", + "test_xla_cpu_thunks", ], deps = [ ":client_library_test_base", @@ -647,6 +656,7 @@ xla_test( xla_test( name = "pred_test", srcs = ["pred_test.cc"], + tags = ["test_xla_cpu_thunks"], deps = [ ":client_library_test_base", ":xla_internal_test_main", @@ -760,6 +770,7 @@ xla_test( xla_test( name = "deallocation_test", srcs = ["deallocation_test.cc"], + tags = ["test_xla_cpu_thunks"], deps = [ ":client_library_test_base", ":test_macros_header", @@ -879,6 +890,7 @@ xla_test( xla_test( name = "fft_test", srcs = ["fft_test.cc"], + tags = ["test_xla_cpu_thunks"], deps = [ ":hlo_test_base", ":test_macros_header", @@ -901,6 +913,7 @@ xla_test( }, tags = [ "optonly", + "test_xla_cpu_thunks", ], deps = [ ":hlo_test_base", @@ -957,6 +970,7 @@ xla_test( "optonly", # TODO(b/151340488): Timed out on 2020-03-12. "nozapfhahn", + "test_xla_cpu_thunks", ], deps = [ ":client_library_test_base", @@ -1000,6 +1014,7 @@ xla_test( tags = [ "nozapfhahn", "optonly", + "test_xla_cpu_thunks", ], deps = [ ":client_library_test_base", @@ -1030,7 +1045,10 @@ xla_test( name = "gather_operation_test", srcs = ["gather_operation_test.cc"], shard_count = 20, - tags = ["test_hlo_pjrt_runner"], + tags = [ + "test_hlo_pjrt_runner", + "test_xla_cpu_thunks", + ], deps = [ ":client_library_test_base", ":hlo_test_base", @@ -1049,6 +1067,7 @@ xla_test( xla_test( name = "scatter_test", srcs = ["scatter_test.cc"], + tags = ["test_xla_cpu_thunks"], # TODO(b/245550554): enable Pjrt runner for scatter test once it's fixed. deps = [ ":client_library_test_base", @@ -1078,6 +1097,7 @@ xla_test( shard_count = 20, tags = [ "optonly", + "test_xla_cpu_thunks", ], deps = [ ":client_library_test_base", @@ -1127,6 +1147,7 @@ xla_test( xla_test( name = "constants_test", srcs = ["constants_test.cc"], + tags = ["test_xla_cpu_thunks"], deps = [ ":client_library_test_base", ":hlo_test_base", @@ -1223,6 +1244,7 @@ xla_test( tags = [ "no_rocm", "optonly", + "test_xla_cpu_thunks", ], deps = CONVOLUTION_TEST_DEPS + [ "@com_google_absl//absl/memory", @@ -1320,6 +1342,7 @@ xla_test( "cpu": ["nomsan"], }, shard_count = 30, + tags = ["test_xla_cpu_thunks"], deps = [ ":client_library_test_base", ":literal_test_util", @@ -1342,6 +1365,7 @@ xla_test( timeout = "long", srcs = ["convolution_dimension_numbers_test.cc"], shard_count = 20, + tags = ["test_xla_cpu_thunks"], deps = [ ":client_library_test_base", ":literal_test_util", @@ -1449,6 +1473,7 @@ xla_test( xla_test( name = "float8_test", srcs = ["float8_test.cc"], + tags = ["test_xla_cpu_thunks"], deps = [ ":client_library_test_base", ":xla_internal_test_main", @@ -1941,6 +1966,7 @@ xla_test( xla_test( name = "pad_test", srcs = ["pad_test.cc"], + tags = ["test_xla_cpu_thunks"], deps = [ ":client_library_test_base", ":literal_test_util", @@ -1974,6 +2000,7 @@ xla_test( xla_test( name = "log_test", srcs = ["log_test.cc"], + tags = ["test_xla_cpu_thunks"], deps = [ ":client_library_test_base", ":literal_test_util", @@ -1989,6 +2016,7 @@ xla_test( name = "matrix_ops_simple_test", timeout = "long", srcs = ["matrix_ops_simple_test.cc"], + tags = ["test_xla_cpu_thunks"], deps = [ ":client_library_test_base", ":literal_test_util", @@ -2289,6 +2317,7 @@ xla_test( "gpu", "cpu", ], + tags = ["test_xla_cpu_thunks"], deps = [ ":hlo_test_base", ":literal_test_util", @@ -2344,6 +2373,7 @@ xla_test( xla_test( name = "collective_pipeliner_execution_test", srcs = ["collective_pipeliner_execution_test.cc"], + tags = ["test_xla_cpu_thunks"], deps = [ ":hlo_test_base", ":xla_internal_test_main", @@ -2450,6 +2480,7 @@ xla_test( "cpu", "gpu", ], + tags = ["test_xla_cpu_thunks"], deps = [ ":hlo_test_base", ":xla_internal_test_main", @@ -2471,6 +2502,7 @@ xla_test( xla_test( name = "value_inference_test", srcs = ["value_inference_test.cc"], + tags = ["test_xla_cpu_thunks"], deps = [ ":literal_test_util", ":test_macros_header", @@ -2500,6 +2532,7 @@ xla_test( xla_test( name = "compute_constant_test", srcs = ["compute_constant_test.cc"], + tags = ["test_xla_cpu_thunks"], deps = [ ":literal_test_util", ":test_macros_header", @@ -2523,6 +2556,7 @@ xla_test( xla_test( name = "client_test", srcs = ["client_test.cc"], + tags = ["test_xla_cpu_thunks"], deps = [ ":client_library_test_base", ":literal_test_util", @@ -2545,6 +2579,7 @@ xla_test( xla_test( name = "replay_test", srcs = ["replay_test.cc"], + tags = ["test_xla_cpu_thunks"], deps = [ ":client_library_test_base", ":literal_test_util", @@ -2616,6 +2651,7 @@ xla_test( xla_test( name = "round_trip_packed_literal_test", srcs = ["round_trip_packed_literal_test.cc"], + tags = ["test_xla_cpu_thunks"], deps = [ ":client_library_test_base", ":literal_test_util", @@ -2643,6 +2679,7 @@ xla_test( "gpu", "interpreter", ], + tags = ["test_xla_cpu_thunks"], deps = [ ":client_library_test_base", ":hlo_test_base", @@ -2699,7 +2736,10 @@ xla_cc_test( ":local_client_aot_test_computation.o", ], linkstatic = 1, - tags = ["not_run:arm"], # b/341355246 + tags = [ + "not_run:arm", # b/341355246 + "test_xla_cpu_thunks", + ], deps = [ "//xla:executable_run_options", "@com_google_absl//absl/base:dynamic_annotations", @@ -2712,6 +2752,7 @@ xla_cc_test( xla_test( name = "local_client_allocation_test", srcs = ["local_client_allocation_test.cc"], + tags = ["test_xla_cpu_thunks"], deps = [ ":literal_test_util", ":local_client_test_base", @@ -2770,6 +2811,7 @@ xla_test( # Outfeed ops are not supported on the interpreter backend. "interpreter", ], + tags = ["test_xla_cpu_thunks"], deps = [ ":local_client_test_base", ":test_macros_header", @@ -2783,6 +2825,7 @@ xla_cc_test( srcs = [ "hlo_metadata_test.cc", ], + tags = ["test_xla_cpu_thunks"], deps = [ ":local_client_test_base", "//xla:test_helpers", @@ -2797,6 +2840,7 @@ xla_cc_test( xla_test( name = "round_trip_transfer_test", srcs = ["round_trip_transfer_test.cc"], + tags = ["test_xla_cpu_thunks"], deps = [ ":client_library_test_base", ":literal_test_util", @@ -2839,6 +2883,7 @@ xla_test( xla_test( name = "deep_graph_test", srcs = ["deep_graph_test.cc"], + tags = ["test_xla_cpu_thunks"], deps = [ ":client_library_test_base", ":xla_internal_test_main", @@ -2849,6 +2894,7 @@ xla_test( xla_cc_test( name = "literal_test_util_test", srcs = ["literal_test_util_test.cc"], + tags = ["test_xla_cpu_thunks"], deps = [ ":literal_test_util", "//xla:literal", @@ -2866,6 +2912,7 @@ xla_test( name = "transfer_manager_test", srcs = ["transfer_manager_test.cc"], shard_count = 50, + tags = ["test_xla_cpu_thunks"], deps = [ ":literal_test_util", ":local_client_test_base", @@ -2896,6 +2943,7 @@ xla_test( "cpu", "gpu", ], + tags = ["test_xla_cpu_thunks"], deps = [ ":hlo_test_base", ":literal_test_util", @@ -2928,6 +2976,7 @@ xla_test( srcs = ["test_utils_test.cc"], # There is nothing backend specific in this test, so just pick an arbitrary backend. backends = ["cpu"], + tags = ["test_xla_cpu_thunks"], deps = [ ":local_client_test_base", ":test_macros_header", @@ -2965,6 +3014,7 @@ xla_cc_test( name = "multiple_devices_on_host_test", srcs = ["multiple_devices_on_host_test.cc"], args = ["--xla_force_host_platform_device_count=4"], + tags = ["test_xla_cpu_thunks"], deps = [ ":xla_internal_test_main", # fixdeps: keep "//xla:shape_util", @@ -2985,6 +3035,7 @@ xla_test( tags = [ # Disabled in OSS until nvidia publicly releases a fixed ptxas. "no_oss", + "test_xla_cpu_thunks", ], deps = [ ":hlo_test_base", @@ -3035,6 +3086,7 @@ xla_test( tags = [ "enable_for_xla_interpreter", "optonly", + "test_xla_cpu_thunks", ], deps = [ ":client_library_test_base", @@ -3062,6 +3114,7 @@ xla_test( shard_count = 10, tags = [ "optonly", + "test_xla_cpu_thunks", ], deps = [ ":client_library_test_base", @@ -3097,6 +3150,7 @@ xla_test( xla_cc_test( name = "tile_assignment_test", srcs = ["tile_assignment_test.cc"], + tags = ["test_xla_cpu_thunks"], deps = [ ":xla_internal_test_main", "//xla:array3d", From 808cf4a74d046da1b698108d4159bfbec806039b Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Wed, 24 Jul 2024 06:59:58 -0700 Subject: [PATCH 113/376] [XLA:GPU] Add a version of HloAnyOf for HloFusionAdaptor that iterates over instruction. We don't need to run BFS to iterate over all the instruction is a fusion adaptor. Performing a linear scan is much more efficient is cases when the order doesn't matter and we want to look through the all the instruction in the fusion. There is a use-case for BFS when we start not from the root or the order matters. Added 'Bfs' to the name of existing 'HloFindIf' and 'HloFindAny' function to indicate the difference. PiperOrigin-RevId: 655550255 --- xla/service/gpu/command_buffer_scheduling.cc | 2 +- .../gpu/dynamic_slice_fusion_rewriter.cc | 4 +- xla/service/gpu/fusions/custom.cc | 6 +- xla/service/gpu/fusions/loop.cc | 37 +++--- .../gpu/fusions/mlir/elemental_hlo_to_mlir.cc | 117 +++++++++--------- .../fusions/triton/triton_fusion_emitter.cc | 4 +- xla/service/gpu/gemm_fusion_autotuner.cc | 2 +- xla/service/gpu/gpu_fusible.cc | 2 +- xla/service/gpu/hlo_traversal.cc | 50 ++++++-- xla/service/gpu/hlo_traversal.h | 38 ++++-- xla/service/gpu/hlo_traversal_test.cc | 37 ++++-- xla/service/gpu/ir_emission_utils.cc | 4 +- xla/service/gpu/priority_fusion.cc | 2 +- 13 files changed, 182 insertions(+), 123 deletions(-) diff --git a/xla/service/gpu/command_buffer_scheduling.cc b/xla/service/gpu/command_buffer_scheduling.cc index 793beea9317a84..d81046b9534331 100644 --- a/xla/service/gpu/command_buffer_scheduling.cc +++ b/xla/service/gpu/command_buffer_scheduling.cc @@ -180,7 +180,7 @@ static bool IsCommand(const HloInstruction* hlo, auto fusion_analysis = HloFusionAnalysis::Create(fusion, &config.device_description); const HloFusionAdaptor& adaptor = fusion_analysis.fusion(); - auto custom_call_adaptor = HloFindIf( + auto custom_call_adaptor = HloBfsFindIf( adaptor.GetRoots(), adaptor, [](auto node) { return node.opcode() == HloOpcode::kCustomCall; }); const auto* custom_call = static_cast( diff --git a/xla/service/gpu/dynamic_slice_fusion_rewriter.cc b/xla/service/gpu/dynamic_slice_fusion_rewriter.cc index 9c5db58fe283fd..09192416db5217 100644 --- a/xla/service/gpu/dynamic_slice_fusion_rewriter.cc +++ b/xla/service/gpu/dynamic_slice_fusion_rewriter.cc @@ -176,7 +176,7 @@ UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr) { // empty: if the operand is a tuple, it might have different data flows // (i.e. 1 for each element). auto maybe_slice_instr = - HloFindIf({operand}, [&](const HloInstruction* cur) { + HloBfsFindIf({operand}, [&](const HloInstruction* cur) { // If the node is a match that has been processed, stop the traversal. if (processed_instrs.contains(cur)) return true; @@ -223,7 +223,7 @@ DefUseDataflowPaths GetSlicedUserPaths(const HloInstruction* instr) { auto traverse_hlo_and_collect = [&](HloInstruction* start) { DefUseDataflowPath maybe_sliced_user_path; bool dus_found = false; - auto maybe_dus_instr = HloFindIf( + auto maybe_dus_instr = HloBfsFindIf( {start}, [&](const HloInstruction* cur) { // If the node is a match that has been processed, stop the diff --git a/xla/service/gpu/fusions/custom.cc b/xla/service/gpu/fusions/custom.cc index 7e6ccb6743e5bf..6e1ea915631701 100644 --- a/xla/service/gpu/fusions/custom.cc +++ b/xla/service/gpu/fusions/custom.cc @@ -116,7 +116,7 @@ absl::StatusOr GetOperandSlice( /*index*/ {}); } - auto slice_adaptor = HloFindIf( + auto slice_adaptor = HloBfsFindIf( {HloInstructionAdaptor(*start, &adaptor)}, adaptor, [](HloInstructionAdaptor node) { return IsOpcodeAnyOf( @@ -302,7 +302,7 @@ absl::StatusOr GetResultSlice( } } - auto slice_adaptor = HloFindIf( + auto slice_adaptor = HloBfsFindIf( {HloInstructionAdaptor(*start, &adaptor)}, adaptor, [](auto node) { return node.opcode() == HloOpcode::kDynamicUpdateSlice; }, /*visit_operands=*/false); @@ -812,7 +812,7 @@ absl::StatusOr DynamicSliceFusion::Emit( IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion) const { const HloFusionAdaptor& adaptor = analysis_.fusion(); - auto maybe_custom_call_adaptor = HloFindIf( + auto maybe_custom_call_adaptor = HloBfsFindIf( adaptor.GetRoots(), adaptor, [](auto node) { return node.opcode() == HloOpcode::kCustomCall; }); if (maybe_custom_call_adaptor == std::nullopt) { diff --git a/xla/service/gpu/fusions/loop.cc b/xla/service/gpu/fusions/loop.cc index e1186ba63dacb8..e9b7933b1c7895 100644 --- a/xla/service/gpu/fusions/loop.cc +++ b/xla/service/gpu/fusions/loop.cc @@ -193,25 +193,24 @@ LaunchDimensionsConfig ComputeLoopFusionConfig( int num_big_inputs; std::tie(row_vectorized, num_big_inputs) = RowVectorizationEnabled(analysis.fusion(), element_shape.rank()); - bool few_waves = !HloAnyOf( - analysis.fusion().GetRoots(), analysis.fusion(), [&](auto instr) { - if (instr.opcode() == HloOpcode::kParameter || - instr.opcode() == HloOpcode::kConstant || - HloInstruction::IsOpElementwise(instr.opcode())) { - return false; - } - if (auto broadcast = - DynCast(&instr.instruction())) { - if (broadcast->dimensions().empty() || - // More than 3 big inputs cause a speed regression. - (row_vectorized && num_big_inputs <= 3)) { - return false; - } - } - VLOG(2) << "few_waves not enabled due to: " - << instr.instruction().ToString(); - return true; - }); + bool few_waves = !HloAnyOf(analysis.fusion(), [&](auto instr) { + if (instr.opcode() == HloOpcode::kParameter || + instr.opcode() == HloOpcode::kConstant || + HloInstruction::IsOpElementwise(instr.opcode())) { + return false; + } + if (auto broadcast = + DynCast(&instr.instruction())) { + if (broadcast->dimensions().empty() || + // More than 3 big inputs cause a speed regression. + (row_vectorized && num_big_inputs <= 3)) { + return false; + } + } + VLOG(2) << "few_waves not enabled due to: " + << instr.instruction().ToString(); + return true; + }); LaunchDimensionsConfig launch_config{unroll_factor, few_waves, row_vectorized}; diff --git a/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc b/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc index aef5d7139bdd89..59471a3fb337ea 100644 --- a/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc +++ b/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc @@ -118,56 +118,56 @@ namespace scf = ::mlir::scf; // HLO opcodes that we never support. static auto& kUnsupportedOps = - *new absl::flat_hash_set{HloOpcode::kAddDependency, - HloOpcode::kAfterAll, - HloOpcode::kAllGather, - HloOpcode::kAllGatherDone, - HloOpcode::kAllGatherStart, - HloOpcode::kAllReduce, - HloOpcode::kAllReduceDone, - HloOpcode::kAllReduceStart, - HloOpcode::kAllToAll, - HloOpcode::kAsyncDone, - HloOpcode::kAsyncStart, - HloOpcode::kAsyncUpdate, - HloOpcode::kBatchNormGrad, - HloOpcode::kBatchNormInference, - HloOpcode::kBatchNormTraining, - HloOpcode::kCholesky, - HloOpcode::kCollectivePermute, - HloOpcode::kCollectivePermuteDone, - HloOpcode::kCollectivePermuteStart, - HloOpcode::kCopyDone, - HloOpcode::kCopyStart, - HloOpcode::kCustomCall, - HloOpcode::kDomain, - HloOpcode::kDynamicReshape, - HloOpcode::kFft, - HloOpcode::kFusion, - HloOpcode::kGetDimensionSize, - HloOpcode::kOptimizationBarrier, - HloOpcode::kInfeed, - HloOpcode::kOutfeed, - HloOpcode::kPartitionId, - HloOpcode::kRecv, - HloOpcode::kRecvDone, - HloOpcode::kReduceScatter, - HloOpcode::kReplicaId, - HloOpcode::kRng, - HloOpcode::kRngBitGenerator, - HloOpcode::kRngGetAndUpdateState, - HloOpcode::kScatter, - HloOpcode::kSelectAndScatter, - HloOpcode::kSend, - HloOpcode::kSendDone, - HloOpcode::kSetDimensionSize, - HloOpcode::kSort, - HloOpcode::kTopK, - HloOpcode::kTriangularSolve, - HloOpcode::kWhile, - HloOpcode::kConditional, - HloOpcode::kStochasticConvert, - HloOpcode::kCall}; + *new llvm::DenseSet{HloOpcode::kAddDependency, + HloOpcode::kAfterAll, + HloOpcode::kAllGather, + HloOpcode::kAllGatherDone, + HloOpcode::kAllGatherStart, + HloOpcode::kAllReduce, + HloOpcode::kAllReduceDone, + HloOpcode::kAllReduceStart, + HloOpcode::kAllToAll, + HloOpcode::kAsyncDone, + HloOpcode::kAsyncStart, + HloOpcode::kAsyncUpdate, + HloOpcode::kBatchNormGrad, + HloOpcode::kBatchNormInference, + HloOpcode::kBatchNormTraining, + HloOpcode::kCholesky, + HloOpcode::kCollectivePermute, + HloOpcode::kCollectivePermuteDone, + HloOpcode::kCollectivePermuteStart, + HloOpcode::kCopyDone, + HloOpcode::kCopyStart, + HloOpcode::kCustomCall, + HloOpcode::kDomain, + HloOpcode::kDynamicReshape, + HloOpcode::kFft, + HloOpcode::kFusion, + HloOpcode::kGetDimensionSize, + HloOpcode::kOptimizationBarrier, + HloOpcode::kInfeed, + HloOpcode::kOutfeed, + HloOpcode::kPartitionId, + HloOpcode::kRecv, + HloOpcode::kRecvDone, + HloOpcode::kReduceScatter, + HloOpcode::kReplicaId, + HloOpcode::kRng, + HloOpcode::kRngBitGenerator, + HloOpcode::kRngGetAndUpdateState, + HloOpcode::kScatter, + HloOpcode::kSelectAndScatter, + HloOpcode::kSend, + HloOpcode::kSendDone, + HloOpcode::kSetDimensionSize, + HloOpcode::kSort, + HloOpcode::kTopK, + HloOpcode::kTriangularSolve, + HloOpcode::kWhile, + HloOpcode::kConditional, + HloOpcode::kStochasticConvert, + HloOpcode::kCall}; bool IsUnsupportedGather(const HloInstruction* instr) { // We assume gather simplifier ran, so we don't need to support all gather @@ -1217,15 +1217,14 @@ bool IsHloConversionSupported(const HloFusionAdaptor& fusion, auto cuda_compute_capability = std::get(compute_capability); - return !HloFindIf( - fusion.GetRoots(), fusion, [=](HloInstructionAdaptor instr) { - return !absl::c_all_of(instr.instruction().called_computations(), - [&](const HloComputation* called) { - return IsHloConversionSupported( - called, compute_capability); - }) || - !IsHloOpSupported(&instr.instruction(), cuda_compute_capability); - }); + return !HloAnyOf(fusion, [=](HloInstructionAdaptor instr) { + return !absl::c_all_of(instr.instruction().called_computations(), + [&](const HloComputation* called) { + return IsHloConversionSupported( + called, compute_capability); + }) || + !IsHloOpSupported(&instr.instruction(), cuda_compute_capability); + }); } ValueRange ProvideParameter(const PartitionedComputation& computation, diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc index e21483462afeea..5d3fa949307bfe 100644 --- a/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc +++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc @@ -1842,7 +1842,7 @@ class MatMulEmitterHelper { absl::StatusOr GetMatMulLaunchDimensions( const TritonFusionAnalysis& analysis, const HloFusionAdaptor& fusion, const TritonGemmConfig& config) { - auto dot = HloFindIf(fusion.GetRoots(), fusion, [](auto node) { + auto dot = HloBfsFindIf(fusion.GetRoots(), fusion, [](auto node) { return node.opcode() == HloOpcode::kDot; }); TF_RET_CHECK(dot != std::nullopt); @@ -2177,7 +2177,7 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, // TODO(b/320659359) Allow TF32 for 8-bit or less types with F32. bool is_unsupported_bitwidth = - HloAnyOf({dot_instr}, [&](const HloInstruction* node) { + HloBfsAnyOf({dot_instr}, [&](const HloInstruction* node) { if (node->opcode() != HloOpcode::kConvert) { return false; } diff --git a/xla/service/gpu/gemm_fusion_autotuner.cc b/xla/service/gpu/gemm_fusion_autotuner.cc index 179f7387b9b83b..8bc3650c3325df 100644 --- a/xla/service/gpu/gemm_fusion_autotuner.cc +++ b/xla/service/gpu/gemm_fusion_autotuner.cc @@ -633,7 +633,7 @@ GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) { // to avoid autotuning configurations that are not supported by Triton. This // is used to restrict the values for tile_k. std::vector converts = - HloFindAll({&dot}, [&](const HloInstruction* node) { + HloBfsFindAll({&dot}, [&](const HloInstruction* node) { return node->opcode() == HloOpcode::kConvert; }); int minBitWidth = primitive_util::BitWidth(dot.shape().element_type()); diff --git a/xla/service/gpu/gpu_fusible.cc b/xla/service/gpu/gpu_fusible.cc index 19ca91b00de663..c40be168a7d44a 100644 --- a/xla/service/gpu/gpu_fusible.cc +++ b/xla/service/gpu/gpu_fusible.cc @@ -970,7 +970,7 @@ bool MayPreventVectorization(const HloFusionAdaptor& fusion) { // An empirically chosen constant: unrolling concat with a large amount of // arguments causes excessive register spilling. static constexpr int kMaxConcatArgumentsForUnrolling = 10; - return HloAnyOf(fusion.GetRoots(), fusion, [&](auto node) { + return HloAnyOf(fusion, [&](auto node) { switch (node.opcode()) { case HloOpcode::kReduceWindow: case HloOpcode::kSort: diff --git a/xla/service/gpu/hlo_traversal.cc b/xla/service/gpu/hlo_traversal.cc index 7bfdbdbeabf03d..b529997f15f962 100644 --- a/xla/service/gpu/hlo_traversal.cc +++ b/xla/service/gpu/hlo_traversal.cc @@ -129,6 +129,11 @@ class SingleInstructionFusion : public internal::HloFusionInstructionAdaptor { return {HloInstructionAdaptor{*instruction_, parent_}}; } + void ForEach( + const std::function& fn) const override { + fn(HloInstructionAdaptor{*instruction_, parent_}); + } + std::string ToString() const override { return instruction_->ToString(); } private: @@ -219,6 +224,20 @@ class HloComputationFusion : public internal::HloFusionInstructionAdaptor { return result; } + void ForEach( + const std::function& fn) const override { + for (const HloInstruction* instr : computation_->instructions()) { + // HloFusionAdaptor hides existence of parameters, tuples and gte + // instructions. + if (instr->opcode() == HloOpcode::kParameter || + instr->opcode() == HloOpcode::kTuple || + instr->opcode() == HloOpcode::kGetTupleElement) { + continue; + } + fn(HloInstructionAdaptor{*instr, parent_}); + } + } + std::string ToString() const override { return computation_->ToString(); } private: @@ -397,6 +416,13 @@ HloFusionAdaptor::MakeInstructionPostOrder() const { return result_post_order; } +void HloFusionAdaptor::ForEach( + const std::function& fn) const { + for (const auto& fusion_instruction : fusion_instructions_) { + fusion_instruction->ForEach(fn); + } +} + std::string HloFusionAdaptor::ToString() const { std::ostringstream ss; for (const auto& fusion_instruction : fusion_instructions_) { @@ -536,20 +562,20 @@ void HloBfsProducersFirstTraversal( /*visit_operands=*/false); } -bool HloAnyOf(absl::Span roots, - const HloFusionAdaptor& fusion, - const std::function& visit, - bool visit_operands) { - return HloFindIf(roots, fusion, visit, visit_operands).has_value(); +bool HloBfsAnyOf(absl::Span roots, + const HloFusionAdaptor& fusion, + const std::function& visit, + bool visit_operands) { + return HloBfsFindIf(roots, fusion, visit, visit_operands).has_value(); } -bool HloAnyOf(absl::Span roots, - const std::function& visit, - bool visit_operands) { - return HloFindIf(roots, visit, visit_operands).has_value(); +bool HloBfsAnyOf(absl::Span roots, + const std::function& visit, + bool visit_operands) { + return HloBfsFindIf(roots, visit, visit_operands).has_value(); } -std::optional HloFindIf( +std::optional HloBfsFindIf( absl::Span roots, const HloFusionAdaptor& fusion, const std::function& visit, @@ -609,7 +635,7 @@ std::vector HloFindAllImpl( return result; } -std::optional HloFindIf( +std::optional HloBfsFindIf( absl::Span roots, const std::function& visit, bool visit_operands) { @@ -621,7 +647,7 @@ std::optional HloFindIf( return result[0]; } -std::vector HloFindAll( +std::vector HloBfsFindAll( absl::Span roots, const std::function& visit, bool visit_operands) { diff --git a/xla/service/gpu/hlo_traversal.h b/xla/service/gpu/hlo_traversal.h index bee9d06af8cd42..b49d4efc9377ce 100644 --- a/xla/service/gpu/hlo_traversal.h +++ b/xla/service/gpu/hlo_traversal.h @@ -95,6 +95,8 @@ class HloFusionInstructionAdaptor { virtual const HloInstruction& FusionInstruction() const = 0; virtual absl::InlinedVector MakeInstructionPostOrder() const = 0; + virtual void ForEach( + const std::function& fn) const = 0; virtual std::string ToString() const = 0; }; @@ -108,6 +110,10 @@ class HloFusionAdaptor { absl::InlinedVector GetParameters() const; absl::InlinedVector MakeInstructionPostOrder() const; + + // Calls `fn` for each instruction in the fusion. + void ForEach(const std::function& fn) const; + std::string ToString() const; static std::unique_ptr ForInstruction( @@ -157,24 +163,24 @@ void HloBfsProducersFirstTraversal( // of `visit` for any of nodes is true. Uses the same order as // `HloBfsConsumersFirstTraversal` if `visit_operands` is true. Otherwise the // same order as `HloBfsProducersFirstTraversal` is used. -bool HloAnyOf(absl::Span roots, - const HloFusionAdaptor& fusion, - const std::function& visit, - bool visit_operands = true); +bool HloBfsAnyOf(absl::Span roots, + const HloFusionAdaptor& fusion, + const std::function& visit, + bool visit_operands = true); // Visit the HLO nodes starting from `roots`, returning true if the return value // of `visit` for any of nodes is true. If `visit_operands` is true, the // search is going towards the operands, otherwise towards the users. Doesn't // require instruction and fusion adaptors. -bool HloAnyOf(absl::Span roots, - const std::function& visit, - bool visit_operands = true); +bool HloBfsAnyOf(absl::Span roots, + const std::function& visit, + bool visit_operands = true); // Visit the HLO nodes starting from `roots`, returning the first // node for which `visit` returns true, or `nullopt` if no node matches. Uses // the same order as `HloBfsConsumersFirstTraversal` if `visit_operands` is // true. Otherwise the same order as `HloBfsProducersFirstTraversal` is used. -std::optional HloFindIf( +std::optional HloBfsFindIf( absl::Span roots, const HloFusionAdaptor& fusion, const std::function& visit, @@ -184,7 +190,7 @@ std::optional HloFindIf( // search is going towards the operands, otherwise towards the users. Returns // the first node for which `visit` returns true, or `nullopt` if no node // matches. -std::optional HloFindIf( +std::optional HloBfsFindIf( absl::Span roots, const std::function& visit, bool visit_operands = true); @@ -193,11 +199,23 @@ std::optional HloFindIf( // search is going towards the operands, otherwise towards the users. Returns // all nodes for which `visit` returns true. If no node matches, returns an // empty vector. -std::vector HloFindAll( +std::vector HloBfsFindAll( absl::Span roots, const std::function& visit, bool visit_operands = true); +// Returns true if any instruction in the fusion adaptor matches the predicate. +template +bool HloAnyOf(const HloFusionAdaptor& fusion, Pred&& pred) { + bool is_any = false; + fusion.ForEach([&](HloInstructionAdaptor node) { + if (pred(node)) { + is_any = true; + } + }); + return is_any; +} + // Find a use chain from `parent` to `root`. Empty if no chain exists. // `[parent]` if `parent` is `root`. std::vector HloFindUseChain(HloInstructionAdaptor parent, diff --git a/xla/service/gpu/hlo_traversal_test.cc b/xla/service/gpu/hlo_traversal_test.cc index d8168d2687b3d7..43c7a9e75dc04d 100644 --- a/xla/service/gpu/hlo_traversal_test.cc +++ b/xla/service/gpu/hlo_traversal_test.cc @@ -248,27 +248,44 @@ TEST_F(HloTraversalTest, FindArgumentsAfterFusion) { EXPECT_THAT(producers, ElementsAre("p0", "log")); } -TEST_F(HloTraversalTest, FindIf) { +TEST_F(HloTraversalTest, HloBfsFindIf_Found) { auto module = ParseAndReturnVerifiedModule(kTestModule).value(); auto fusion = HloFusionAdaptor::ForInstruction( module->entry_computation()->GetInstructionWithName("fusion")); - auto result = - HloFindIf(fusion->GetRoots(), *fusion, [&](HloInstructionAdaptor node) { - return node.opcode() == HloOpcode::kMultiply; - }); + auto result = HloBfsFindIf(fusion->GetRoots(), *fusion, + [&](HloInstructionAdaptor node) { + return node.opcode() == HloOpcode::kMultiply; + }); ASSERT_NE(result, std::nullopt); ASSERT_EQ(result->name(), "mul"); } -TEST_F(HloTraversalTest, NotFound) { +TEST_F(HloTraversalTest, HloBfsFindIf_NotFound) { auto module = ParseAndReturnVerifiedModule(kTestModule).value(); auto fusion = HloFusionAdaptor::ForInstruction( module->entry_computation()->GetInstructionWithName("fusion")); - auto result = HloFindIf(fusion->GetRoots(), *fusion, - [&](HloInstructionAdaptor node) { return false; }); + auto result = HloBfsFindIf(fusion->GetRoots(), *fusion, + [&](HloInstructionAdaptor node) { return false; }); ASSERT_EQ(result, std::nullopt); } +TEST_F(HloTraversalTest, HloAnyOf_Found) { + auto module = ParseAndReturnVerifiedModule(kTestModule).value(); + auto fusion = HloFusionAdaptor::ForInstruction( + module->entry_computation()->GetInstructionWithName("fusion")); + EXPECT_TRUE(HloAnyOf(*fusion, [&](HloInstructionAdaptor node) { + return node.opcode() == HloOpcode::kMultiply; + })); +} + +TEST_F(HloTraversalTest, HloAnyOf_NotFound) { + auto module = ParseAndReturnVerifiedModule(kTestModule).value(); + auto fusion = HloFusionAdaptor::ForInstruction( + module->entry_computation()->GetInstructionWithName("fusion")); + EXPECT_FALSE( + HloAnyOf(*fusion, [&](HloInstructionAdaptor node) { return false; })); +} + TEST_F(HloTraversalTest, FindAllMultiple) { const char kConverts[] = R"( HloModule test @@ -285,7 +302,7 @@ TEST_F(HloTraversalTest, FindAllMultiple) { auto module = ParseAndReturnVerifiedModule(kConverts).value(); auto root = module->entry_computation()->GetInstructionWithName("diff"); std::vector converts = - HloFindAll({root}, [&](const HloInstruction* node) { + HloBfsFindAll({root}, [&](const HloInstruction* node) { return node->opcode() == HloOpcode::kConvert; }); @@ -309,7 +326,7 @@ TEST_F(HloTraversalTest, FindAllNotFound) { auto module = ParseAndReturnVerifiedModule(kConverts).value(); auto root = module->entry_computation()->GetInstructionWithName("diff"); std::vector converts = - HloFindAll({root}, [&](const HloInstruction* node) { + HloBfsFindAll({root}, [&](const HloInstruction* node) { return node->opcode() == HloOpcode::kAdd; }); EXPECT_THAT(converts, IsEmpty()); diff --git a/xla/service/gpu/ir_emission_utils.cc b/xla/service/gpu/ir_emission_utils.cc index 437f1c3467971b..81d05f4d1347fa 100644 --- a/xla/service/gpu/ir_emission_utils.cc +++ b/xla/service/gpu/ir_emission_utils.cc @@ -692,8 +692,8 @@ static std::optional FindNonTrivialHero( /*allowed_operand_count=*/3); }; bool visit_operands = false; - if (HloAnyOf(hero->GetUsers(), hero->parent(), is_nontrivial, - visit_operands)) { + if (HloBfsAnyOf(hero->GetUsers(), hero->parent(), is_nontrivial, + visit_operands)) { return std::nullopt; } diff --git a/xla/service/gpu/priority_fusion.cc b/xla/service/gpu/priority_fusion.cc index ec5ff230eef7bf..d57a83c20ee8d1 100644 --- a/xla/service/gpu/priority_fusion.cc +++ b/xla/service/gpu/priority_fusion.cc @@ -477,7 +477,7 @@ class GpuPriorityFusionQueue { // TODO(b/312200883): Remove this. auto contains_significant_reduce = [&](const HloInstruction* instr) { auto fusion = HloFusionAdaptor::ForInstruction(instr); - return HloAnyOf(fusion->GetRoots(), *fusion, [](auto node) { + return HloAnyOf(*fusion, [](auto node) { if (!(node.opcode() == HloOpcode::kReduce && node.shape().IsArray())) { return false; } From d41d742d2b23094c956e2f9a25441773e9057561 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 24 Jul 2024 07:40:06 -0700 Subject: [PATCH 114/376] [xla:cpu] Make WhileThunk non-blocking BlockUntilReady inside Thunk:Execute is illegal and leads to deadlocks. Run while loop asynchronously relying on AndThen callbacks. PiperOrigin-RevId: 655561921 --- xla/service/cpu/runtime/BUILD | 3 + xla/service/cpu/runtime/while_thunk.cc | 99 +++++++++++++-- xla/service/cpu/runtime/while_thunk.h | 9 +- xla/service/cpu/runtime/while_thunk_test.cc | 132 +++++++++++++++++++- xla/tsl/concurrency/async_value_ref.h | 10 ++ xla/tsl/concurrency/async_value_ref_test.cc | 32 ++++- 6 files changed, 265 insertions(+), 20 deletions(-) diff --git a/xla/service/cpu/runtime/BUILD b/xla/service/cpu/runtime/BUILD index c1f1d317a9f68e..1787141d0874ee 100644 --- a/xla/service/cpu/runtime/BUILD +++ b/xla/service/cpu/runtime/BUILD @@ -928,6 +928,7 @@ cc_library( "//xla/service:buffer_assignment", "//xla/stream_executor", "//xla/tsl/concurrency:async_value", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -954,7 +955,9 @@ xla_cc_test( "//xla/stream_executor", "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/status", + "@eigen_archive//:eigen3", "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:env", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", diff --git a/xla/service/cpu/runtime/while_thunk.cc b/xla/service/cpu/runtime/while_thunk.cc index 1d7db4de408927..4e326b63a91706 100644 --- a/xla/service/cpu/runtime/while_thunk.cc +++ b/xla/service/cpu/runtime/while_thunk.cc @@ -15,9 +15,11 @@ limitations under the License. #include "xla/service/cpu/runtime/while_thunk.h" +#include #include #include +#include "absl/base/optimization.h" #include "absl/memory/memory.h" #include "absl/status/statusor.h" #include "xla/runtime/buffer_use.h" @@ -26,6 +28,7 @@ limitations under the License. #include "xla/service/cpu/runtime/thunk_executor.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" +#include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" @@ -50,6 +53,60 @@ WhileThunk::WhileThunk(Info info, BufferAllocation::Slice cond_buffer, cond_executor_(std::move(cond_executor)), body_executor_(std::move(body_executor)) {} +tsl::AsyncValueRef WhileThunk::ExecuteAsync( + const ExecuteParams& params, tsl::AsyncValueRef dependency, + bool* condition) { + auto event = tsl::MakeConstructedAsyncValueRef(); + + // Allocate while loop iteration function on heap so we can detach its life + // time from the caller stack. + auto loop_fn = std::make_shared>(); + *loop_fn = [this, condition, ¶ms, event, + loop = loop_fn.get()](absl::Status status) { + // Dependency completed with an error. Forward it to the result event. + if (ABSL_PREDICT_FALSE(!status.ok())) { + event.SetError(std::move(status)); + return; + } + + while (*condition) { + auto body_event = body_executor_.Execute(params); + auto cond_event = body_event.FlatMap([this, ¶ms](ExecuteEvent) { + return cond_executor_.Execute(params); + }); + + // Immediately forward error to the caller. + if (ABSL_PREDICT_FALSE(cond_event.IsError())) { + event.SetError(cond_event.GetError()); + return; + } + + // If we don't know yet wether we should execute the next iteration or + // not, attach `AndThen` continuation to the `cond_event`. + if (!cond_event.IsAvailable()) { + cond_event.AndThen( + [loop](absl::Status status) { (*loop)(std::move(status)); }); + return; + } + + // At this point `*condition` should have been updated and we may continue + // executing the while loop in the current thread. + DCHECK(cond_event.IsAvailable()); + } + + // Successfully completed while loop iterations. + event.SetStateConcrete(); + }; + + // Kick-off loop execution once dependency event is available. + dependency.AndThen(*loop_fn); + + // Keep `loop_fn` alive until the end of the while loop execution. + event.AndThen([loop_fn = std::move(loop_fn)]() {}); + + return event; +} + tsl::AsyncValueRef WhileThunk::Execute( const ExecuteParams& params) { tsl::profiler::TraceMe trace([&] { return TraceMeEncode(); }); @@ -60,23 +117,43 @@ tsl::AsyncValueRef WhileThunk::Execute( bool* condition = reinterpret_cast(cond_data.opaque()); - // TODO(ezhulenev): Remove `BlockUntilReady` calls and make WhileThunk - // asynchronous by chaining `Execute` calls via `AndThen` callbacks. - + // Execute `cond` thunk sequence to initialize the loop condition. auto init_event = cond_executor_.Execute(params); - tsl::BlockUntilReady(init_event); - if (init_event.IsError()) return init_event.GetError(); + + // Immediately forward error to the caller. + if (ABSL_PREDICT_FALSE(init_event.IsError())) { + return init_event.GetError(); + } + + // If we don't know if we should continue or not, switch to async execution + // mode using `init_event` as a dependency. + if (ABSL_PREDICT_FALSE(!init_event.IsAvailable())) { + return ExecuteAsync(params, std::move(init_event), condition); + } while (*condition) { auto body_event = body_executor_.Execute(params); - tsl::BlockUntilReady(body_event); - if (body_event.IsError()) return body_event.GetError(); - - auto cond_event = cond_executor_.Execute(params); - tsl::BlockUntilReady(cond_event); - if (cond_event.IsError()) return cond_event.GetError(); + auto cond_event = body_event.FlatMap([this, ¶ms](ExecuteEvent) { + return cond_executor_.Execute(params); + }); + + // Immediately forward error to the caller. + if (ABSL_PREDICT_FALSE(cond_event.IsError())) { + return cond_event.GetError(); + } + + // If we don't know if we should continue or not, switch to async execution + // mode using `cond_event` as a dependency. + if (ABSL_PREDICT_FALSE(!cond_event.IsAvailable())) { + return ExecuteAsync(params, std::move(cond_event), condition); + } + + // At this point `*condition` should have been updated and we may continue + // executing the while loop in the current thread. + DCHECK(cond_event.IsAvailable()); } + // Successfully completed while loop iterations. return OkExecuteEvent(); } diff --git a/xla/service/cpu/runtime/while_thunk.h b/xla/service/cpu/runtime/while_thunk.h index 29bc27dba0bc61..9c5a7af272468c 100644 --- a/xla/service/cpu/runtime/while_thunk.h +++ b/xla/service/cpu/runtime/while_thunk.h @@ -18,7 +18,6 @@ limitations under the License. #include -#include "absl/status/status.h" #include "absl/status/statusor.h" #include "xla/service/buffer_assignment.h" #include "xla/service/cpu/runtime/thunk.h" @@ -49,6 +48,14 @@ class WhileThunk final : public Thunk { WhileThunk(Info info, BufferAllocation::Slice cond_buffer, ThunkExecutor cond_executor, ThunkExecutor body_executor); + // If `cond` or `body` thunk sequence return unavailable async values, then + // we execute the while loop asynchronously by chaining `Execute` calls via + // `AndThen` callbacks. This execution mode adds significant overheads, so we + // try to avoid it when possible and run everything in the caller thread. + tsl::AsyncValueRef ExecuteAsync( + const ExecuteParams& params, tsl::AsyncValueRef dependency, + bool* condition); + BufferAllocation::Slice cond_buffer_; ThunkExecutor cond_executor_; ThunkExecutor body_executor_; diff --git a/xla/service/cpu/runtime/while_thunk_test.cc b/xla/service/cpu/runtime/while_thunk_test.cc index fbeef4d7c8697e..5da7202f7d9b7f 100644 --- a/xla/service/cpu/runtime/while_thunk_test.cc +++ b/xla/service/cpu/runtime/while_thunk_test.cc @@ -15,6 +15,8 @@ limitations under the License. #include "xla/service/cpu/runtime/while_thunk.h" +#include +#include #include #include #include @@ -22,18 +24,29 @@ limitations under the License. #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" +#include "xla/service/cpu/runtime/buffer_allocations.h" #include "xla/service/cpu/runtime/resource_use.h" #include "xla/service/cpu/runtime/thunk.h" #include "xla/service/cpu/runtime/thunk_testlib.h" +#include "xla/service/maybe_owning_device_memory.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/tsl/concurrency/async_value_ref.h" +#include "tsl/platform/env.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" +#include "tsl/platform/threadpool.h" + +#define EIGEN_USE_THREADS + +#include "Eigen/ThreadPool" +#include "unsupported/Eigen/CXX11/Tensor" namespace xla::cpu { namespace { TEST(WhileThunkTest, BufferUses) { BufferAllocation alloc(0, 1024, 0); - BufferAllocation::Slice predicate_slice(&alloc, 0, sizeof(int32_t)); + BufferAllocation::Slice pred_slice(&alloc, 0, sizeof(char)); BufferAllocation::Slice cond_read_slice(&alloc, 10, 10); BufferAllocation::Slice body_read_slice(&alloc, 20, 10); @@ -47,18 +60,18 @@ TEST(WhileThunkTest, BufferUses) { TF_ASSERT_OK_AND_ASSIGN( auto thunk, - WhileThunk::Create({"while"}, predicate_slice, std::move(cond_sequence), + WhileThunk::Create({"while"}, pred_slice, std::move(cond_sequence), std::move(body_sequence))); EXPECT_EQ(thunk->buffer_uses().size(), 3); - EXPECT_EQ(thunk->buffer_uses()[0], BufferUse::Write(predicate_slice)); + EXPECT_EQ(thunk->buffer_uses()[0], BufferUse::Write(pred_slice)); EXPECT_EQ(thunk->buffer_uses()[1], BufferUse::Read(cond_read_slice)); EXPECT_EQ(thunk->buffer_uses()[2], BufferUse::Read(body_read_slice)); } TEST(WhileThunkTest, ResourceUses) { BufferAllocation alloc(0, 1024, 0); - BufferAllocation::Slice predicate_slice(&alloc, 0, sizeof(int32_t)); + BufferAllocation::Slice pred_slice(&alloc, 0, sizeof(char)); auto token0 = Resource::Create(Resource::kToken); auto token1 = Resource::Create(Resource::kToken); @@ -73,7 +86,7 @@ TEST(WhileThunkTest, ResourceUses) { TF_ASSERT_OK_AND_ASSIGN( auto thunk, - WhileThunk::Create({"while"}, predicate_slice, std::move(cond_sequence), + WhileThunk::Create({"while"}, pred_slice, std::move(cond_sequence), std::move(body_sequence))); EXPECT_EQ(thunk->resource_uses().size(), 2); @@ -81,5 +94,114 @@ TEST(WhileThunkTest, ResourceUses) { EXPECT_EQ(thunk->resource_uses()[1], ResourceUse::Read(token1)); } +// Below are fake thunks that always launch tasks into the intra-op thread pool, +// so that we can test that WhileThunk::Execute correctly handles asynchronous +// cond and body thunk sequences. + +class CondThunk : public Thunk { + public: + CondThunk(size_t counter, BufferAllocation::Slice pred_slice) + : Thunk(Kind::kKernel, {"cond"}), + counter_(counter + 1), + pred_slice_(pred_slice) {} + + tsl::AsyncValueRef Execute(const ExecuteParams& params) final { + auto event = tsl::MakeConstructedAsyncValueRef(); + + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase predicate_mem, + params.buffer_allocations->GetDeviceAddress(pred_slice_)); + bool* predicate = reinterpret_cast(predicate_mem.opaque()); + + // Continue while loop until counter reaches 0. + *predicate = counter_.fetch_sub(1) > 1; + + params.intra_op_threadpool->getPool()->Schedule( + [event] { event.SetStateConcrete(); }); + + return event; + } + + BufferUses buffer_uses() const final { + return {BufferUse::Write(pred_slice_)}; + } + + private: + std::atomic counter_; + BufferAllocation::Slice pred_slice_; +}; + +class BodyThunk : public Thunk { + public: + explicit BodyThunk(BufferAllocation::Slice counter_slice) + : Thunk(Kind::kKernel, {"body"}), counter_slice_(counter_slice) {} + + tsl::AsyncValueRef Execute(const ExecuteParams& params) final { + auto event = tsl::MakeConstructedAsyncValueRef(); + + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase counter_mem, + params.buffer_allocations->GetDeviceAddress(counter_slice_)); + + int32_t* counter = reinterpret_cast(counter_mem.opaque()); + ++*counter; + + params.intra_op_threadpool->getPool()->Schedule( + [event] { event.SetStateConcrete(); }); + + return event; + } + + BufferUses buffer_uses() const final { return {}; } + + private: + BufferAllocation::Slice counter_slice_; +}; + +TEST(WhileThunkTest, NonBlockingExecute) { + static constexpr size_t kNumIterations = 100; + + BufferAllocation pred_alloc(0, sizeof(char), 0); + BufferAllocation cnt_alloc(1, sizeof(int32_t), 0); + + BufferAllocation::Slice pred_slice(&pred_alloc, 0, sizeof(char)); + BufferAllocation::Slice cnt_slice(&cnt_alloc, 0, sizeof(int32_t)); + + std::vector buffers; + std::vector predicate = {false}; + std::vector counter = {0}; + + buffers.emplace_back(se::DeviceMemoryBase(predicate.data(), sizeof(char))); + buffers.emplace_back(se::DeviceMemoryBase(counter.data(), sizeof(int32_t))); + + BufferAllocations allocations(buffers); + + ThunkSequence cond_sequence; + cond_sequence.push_back( + std::make_unique(kNumIterations, pred_slice)); + + ThunkSequence body_sequence; + body_sequence.push_back(std::make_unique(cnt_slice)); + + TF_ASSERT_OK_AND_ASSIGN( + auto thunk, + WhileThunk::Create({"while"}, pred_slice, std::move(cond_sequence), + std::move(body_sequence))); + + tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "while-test", 8); + Eigen::ThreadPoolDevice device(thread_pool.AsEigenThreadPool(), + thread_pool.NumThreads()); + + Thunk::ExecuteParams params; + params.buffer_allocations = &allocations; + params.intra_op_threadpool = &device; + + auto execute_event = thunk->Execute(params); + tsl::BlockUntilReady(execute_event); + ASSERT_FALSE(execute_event.IsError()); + + EXPECT_EQ(counter[0], kNumIterations); +} + } // namespace } // namespace xla::cpu diff --git a/xla/tsl/concurrency/async_value_ref.h b/xla/tsl/concurrency/async_value_ref.h index ebd5a0c5a7221c..ca1f4133dad564 100644 --- a/xla/tsl/concurrency/async_value_ref.h +++ b/xla/tsl/concurrency/async_value_ref.h @@ -719,6 +719,13 @@ class AsyncValuePtr { template , std::enable_if_t>* = nullptr> AsyncValueRef FlatMap(F&& f) { + // If async value is in concrete state, we can immediately call the functor. + // We don't handle errors here and prefer a generic code path below because + // error handling is never on a performance critical path. + if (ABSL_PREDICT_TRUE(IsConcrete())) { + return f(get()); + } + auto promise = MakePromise(); AndThen([f = std::forward(f), promise, ptr = *this]() mutable { if (ABSL_PREDICT_FALSE(ptr.IsError())) { @@ -735,6 +742,9 @@ class AsyncValuePtr { std::enable_if_t>* = nullptr> AsyncValueRef FlatMap(AsyncValue::Executor& executor, F&& f) { + // We don't have a special handling for concrete values here because + // we must execute user functor on a separate executor and can't call it in + // the caller thread. auto promise = MakePromise(); // We don't know when the executor will run the callback, so we need to // copy the AsyncValueRef to keep the underlying value alive. diff --git a/xla/tsl/concurrency/async_value_ref_test.cc b/xla/tsl/concurrency/async_value_ref_test.cc index 6e73dc4a850834..2c4ce86933dbf9 100644 --- a/xla/tsl/concurrency/async_value_ref_test.cc +++ b/xla/tsl/concurrency/async_value_ref_test.cc @@ -354,7 +354,7 @@ TEST(AsyncValueRefTest, FlatMapAvailable) { AsyncValueRef ref = MakeAvailableAsyncValueRef(42); AsyncValueRef fmapped_to_float = ref.FlatMap([](int32_t value) { - return MakeAvailableAsyncValueRef(1.0f * value); + return MakeAvailableAsyncValueRef(static_cast(value)); }); EXPECT_TRUE(fmapped_to_float.IsAvailable()); @@ -365,7 +365,7 @@ TEST(AsyncValueRefTest, FlatMapUnavailable) { AsyncValueRef ref = MakeConstructedAsyncValueRef(42); AsyncValueRef fmapped_to_float = ref.FlatMap([](int32_t value) { - return MakeAvailableAsyncValueRef(1.0f * value); + return MakeAvailableAsyncValueRef(static_cast(value)); }); EXPECT_FALSE(fmapped_to_float.IsAvailable()); @@ -375,6 +375,32 @@ TEST(AsyncValueRefTest, FlatMapUnavailable) { EXPECT_EQ(fmapped_to_float.get(), 42.0f); } +TEST(AsyncValueRefTest, FlatMapAvailableError) { + AsyncValueRef ref = + MakeErrorAsyncValueRef(absl::InternalError("error")); + + AsyncValueRef fmapped_to_float = ref.FlatMap([](int32_t value) { + return MakeAvailableAsyncValueRef(static_cast(value)); + }); + + EXPECT_TRUE(fmapped_to_float.IsError()); + EXPECT_EQ(fmapped_to_float.GetError(), absl::InternalError("error")); +} + +TEST(AsyncValueRefTest, FlatMapUnavailableError) { + AsyncValueRef ref = MakeConstructedAsyncValueRef(42); + + AsyncValueRef fmapped_to_float = ref.FlatMap([](int32_t value) { + return MakeAvailableAsyncValueRef(static_cast(value)); + }); + + EXPECT_FALSE(fmapped_to_float.IsAvailable()); + ref.SetError(absl::InternalError("error")); + + EXPECT_TRUE(fmapped_to_float.IsError()); + EXPECT_EQ(fmapped_to_float.GetError(), absl::InternalError("error")); +} + struct DeferredExecutor : public AsyncValue::Executor { void Execute(Task task) final { tasks.push_back(std::move(task)); } @@ -480,7 +506,7 @@ TEST(AsyncValueRefTest, FlatMapAvailableOnExecutor) { DeferredExecutor executor; AsyncValueRef fmapped_to_float = ref.FlatMap(executor, [](int32_t value) { - return MakeAvailableAsyncValueRef(1.0f * value); + return MakeAvailableAsyncValueRef(static_cast(value)); }); ref.SetStateConcrete(); From bc4b6cb73b10124f16213e89d5758b637a72b8b6 Mon Sep 17 00:00:00 2001 From: Berkin Ilbeyi Date: Wed, 24 Jul 2024 08:08:34 -0700 Subject: [PATCH 115/376] [XLA] Add option to make cross-program prefetch more permissive. PiperOrigin-RevId: 655570179 --- .../memory_space_assignment/algorithm.cc | 20 +++++++------ .../memory_space_assignment_test.cc | 28 +++++++++++++++++++ xla/service/memory_space_assignment/options.h | 5 ++++ 3 files changed, 45 insertions(+), 8 deletions(-) diff --git a/xla/service/memory_space_assignment/algorithm.cc b/xla/service/memory_space_assignment/algorithm.cc index eeda3a47b676a2..d18a69bc5635dd 100644 --- a/xla/service/memory_space_assignment/algorithm.cc +++ b/xla/service/memory_space_assignment/algorithm.cc @@ -110,7 +110,7 @@ std::string VectorToString(const std::vector& v, return absl::StrCat("[ ", absl::StrJoin(elements, ", "), " ]"); } -bool LooksLikeAnActivation(const HloInstruction* inst) { +bool LooksLikeAnActivation(const HloInstruction* inst, bool permissive_mode) { for (HloInstruction* user : inst->users()) { switch (user->opcode()) { case HloOpcode::kConvolution: @@ -127,7 +127,8 @@ bool LooksLikeAnActivation(const HloInstruction* inst) { case HloOpcode::kFusion: for (int i = 0; i < user->operand_count(); ++i) { if (user->operand(i) == inst && - LooksLikeAnActivation(user->fused_parameter(i))) { + LooksLikeAnActivation(user->fused_parameter(i), + permissive_mode)) { return true; } } @@ -135,14 +136,14 @@ bool LooksLikeAnActivation(const HloInstruction* inst) { case HloOpcode::kBitcast: case HloOpcode::kBroadcast: case HloOpcode::kTranspose: - if (LooksLikeAnActivation(user)) { + if (LooksLikeAnActivation(user, permissive_mode)) { return true; } break; case HloOpcode::kCopy: if (user->IsFused() && (user == user->parent()->root_instruction())) { user = user->parent()->FusionInstruction(); - if (LooksLikeAnActivation(user)) { + if (LooksLikeAnActivation(user, permissive_mode)) { return true; } else { break; @@ -155,7 +156,7 @@ bool LooksLikeAnActivation(const HloInstruction* inst) { inst) != user->operands().end()) { return true; } - if (LooksLikeAnActivation(user)) { + if (LooksLikeAnActivation(user, permissive_mode)) { return true; } break; @@ -165,12 +166,14 @@ bool LooksLikeAnActivation(const HloInstruction* inst) { user->operands().end(), inst) != user->operands().end()) { return true; } - if (LooksLikeAnActivation(user)) { + if (LooksLikeAnActivation(user, permissive_mode)) { return true; } break; default: - return true; + // Permissive mode assumes the tensor is not an activation when we + // couldn't explicitly determine that it is not an activation. + return !permissive_mode; } } return false; @@ -262,7 +265,8 @@ bool IsCrossProgramPrefetchCandidate(const HloValue& value, return (inst->opcode() == HloOpcode::kGetTupleElement || inst->opcode() == HloOpcode::kParameter) && - !LooksLikeAnActivation(inst); + !LooksLikeAnActivation( + inst, options.cross_program_prefetch_permissive_mode); }); } diff --git a/xla/service/memory_space_assignment/memory_space_assignment_test.cc b/xla/service/memory_space_assignment/memory_space_assignment_test.cc index 2a8a7cfe69826e..3888dfb47fbe0c 100644 --- a/xla/service/memory_space_assignment/memory_space_assignment_test.cc +++ b/xla/service/memory_space_assignment/memory_space_assignment_test.cc @@ -9864,6 +9864,34 @@ ENTRY %main { op::Fusion())); } +TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchPermissiveMode) { + absl::string_view hlo_string = R"( +HloModule module, is_scheduled=true + +fused_computation { + param_0 = f32[2] parameter(0) + param_1 = f32[4,2] parameter(1) + broadcast = f32[4,2] broadcast(param_0), dimensions={1} + ROOT multiply = f32[4,2] multiply(broadcast, param_1) +} + +ENTRY entry { + p0 = f32[2] parameter(0) + p1 = f32[4,2] parameter(1) + fusion = f32[4,2] fusion(p0, p1), kind=kLoop, calls=fused_computation + ROOT negate = f32[4,2] negate(fusion) +} + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + Options options = DefaultMemorySpaceOptions(); + options.cross_program_prefetch_permissive_mode = true; + AssignMemorySpace(module.get(), options); + auto cross_program_prefetches = module->CrossProgramPrefetches(); + EXPECT_EQ(cross_program_prefetches.size(), 1); +} + // Test description: // - Setup: Make sure p1 can not be prefetched to alternate memory until after // instruction c. We do this by causing p0 to be prefetched to alternate diff --git a/xla/service/memory_space_assignment/options.h b/xla/service/memory_space_assignment/options.h index 7926d5611c4958..3a1d8488118afb 100644 --- a/xla/service/memory_space_assignment/options.h +++ b/xla/service/memory_space_assignment/options.h @@ -172,6 +172,11 @@ struct Options { // TODO(tjablin): Use a heuristic to determine this automatically. int max_cross_program_prefetches = 1; + // If false, we assume tensors that we couldn't explicitly determine to be + // activations are activations. If true, we assume these aren't activations, + // so they may be cross-program-prefetch candidates. + bool cross_program_prefetch_permissive_mode = false; + // Enable redundant eviction optimization in/around while loops. If enabled, // this optimization would keep a copy of the buffer in the default memory in // addition to alternate memory to eliminate redundant evictions. From 8223a042557f5202fc677030c233f4d46edc38ad Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 24 Jul 2024 08:17:48 -0700 Subject: [PATCH 116/376] Integrate LLVM at llvm/llvm-project@84658fb82b67 Updates LLVM usage to match [84658fb82b67](https://github.com/llvm/llvm-project/commit/84658fb82b67) PiperOrigin-RevId: 655573384 --- third_party/llvm/generated.patch | 1926 ++++++----------- third_party/llvm/workspace.bzl | 4 +- third_party/shardy/workspace.bzl | 4 +- .../tsl/third_party/llvm/generated.patch | 1926 ++++++----------- .../tsl/third_party/llvm/workspace.bzl | 4 +- 5 files changed, 1214 insertions(+), 2650 deletions(-) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 7af6db90fd2b4f..c7f7475c35588c 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -2,1355 +2,637 @@ Auto generated patch. Do not edit or delete it, even if empty. diff -ruN --strip-trailing-cr a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst --- a/clang/docs/ReleaseNotes.rst +++ b/clang/docs/ReleaseNotes.rst -@@ -40,8 +40,6 @@ - - Setting the deprecated CMake variable ``GCC_INSTALL_PREFIX`` (which sets the - default ``--gcc-toolchain=``) now leads to a fatal error. +@@ -750,9 +750,6 @@ + - Clang now specifies that using ``auto`` in a lambda parameter is a C++14 extension when + appropriate. (`#46059: `_). --- The ``le32`` and ``le64`` targets have been removed. +-- Clang now adds source file infomation for template instantiations as ``event["args"]["filename"]``. This +- added behind an option ``-ftime-trace-verbose``. This is expected to increase the size of trace by 2-3 times. - - C/C++ Language Potentially Breaking Changes - ------------------------------------------- - -diff -ruN --strip-trailing-cr a/clang/docs/tools/clang-formatted-files.txt b/clang/docs/tools/clang-formatted-files.txt ---- a/clang/docs/tools/clang-formatted-files.txt -+++ b/clang/docs/tools/clang-formatted-files.txt -@@ -362,6 +362,7 @@ - clang/lib/Basic/Targets/BPF.h - clang/lib/Basic/Targets/Hexagon.h - clang/lib/Basic/Targets/Lanai.h -+clang/lib/Basic/Targets/Le64.h - clang/lib/Basic/Targets/M68k.h - clang/lib/Basic/Targets/MSP430.h - clang/lib/Basic/Targets/NVPTX.cpp -diff -ruN --strip-trailing-cr a/clang/lib/Basic/CMakeLists.txt b/clang/lib/Basic/CMakeLists.txt ---- a/clang/lib/Basic/CMakeLists.txt -+++ b/clang/lib/Basic/CMakeLists.txt -@@ -102,6 +102,7 @@ - Targets/DirectX.cpp - Targets/Hexagon.cpp - Targets/Lanai.cpp -+ Targets/Le64.cpp - Targets/LoongArch.cpp - Targets/M68k.cpp - Targets/MSP430.cpp -diff -ruN --strip-trailing-cr a/clang/lib/Basic/Targets/Le64.cpp b/clang/lib/Basic/Targets/Le64.cpp ---- a/clang/lib/Basic/Targets/Le64.cpp -+++ b/clang/lib/Basic/Targets/Le64.cpp -@@ -0,0 +1,30 @@ -+//===--- Le64.cpp - Implement Le64 target feature support -----------------===// -+// -+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -+// See https://llvm.org/LICENSE.txt for license information. -+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -+// -+//===----------------------------------------------------------------------===// -+// -+// This file implements Le64 TargetInfo objects. -+// -+//===----------------------------------------------------------------------===// -+ -+#include "Le64.h" -+#include "Targets.h" -+#include "clang/Basic/Builtins.h" -+#include "clang/Basic/MacroBuilder.h" -+#include "clang/Basic/TargetBuiltins.h" -+ -+using namespace clang; -+using namespace clang::targets; -+ -+ArrayRef Le64TargetInfo::getTargetBuiltins() const { -+ return {}; -+} -+ -+void Le64TargetInfo::getTargetDefines(const LangOptions &Opts, -+ MacroBuilder &Builder) const { -+ DefineStd(Builder, "unix", Opts); -+ defineCPUMacros(Builder, "le64", /*Tuning=*/false); -+} -diff -ruN --strip-trailing-cr a/clang/lib/Basic/Targets/Le64.h b/clang/lib/Basic/Targets/Le64.h ---- a/clang/lib/Basic/Targets/Le64.h -+++ b/clang/lib/Basic/Targets/Le64.h -@@ -0,0 +1,64 @@ -+//===--- Le64.h - Declare Le64 target feature support -----------*- C++ -*-===// -+// -+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -+// See https://llvm.org/LICENSE.txt for license information. -+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -+// -+//===----------------------------------------------------------------------===// -+// -+// This file declares Le64 TargetInfo objects. -+// -+//===----------------------------------------------------------------------===// -+ -+#ifndef LLVM_CLANG_LIB_BASIC_TARGETS_LE64_H -+#define LLVM_CLANG_LIB_BASIC_TARGETS_LE64_H -+ -+#include "clang/Basic/TargetInfo.h" -+#include "clang/Basic/TargetOptions.h" -+#include "llvm/Support/Compiler.h" -+#include "llvm/TargetParser/Triple.h" -+ -+namespace clang { -+namespace targets { -+ -+class LLVM_LIBRARY_VISIBILITY Le64TargetInfo : public TargetInfo { -+ -+public: -+ Le64TargetInfo(const llvm::Triple &Triple, const TargetOptions &) -+ : TargetInfo(Triple) { -+ NoAsmVariants = true; -+ LongWidth = LongAlign = PointerWidth = PointerAlign = 64; -+ MaxAtomicPromoteWidth = MaxAtomicInlineWidth = 64; -+ resetDataLayout("e-m:e-v128:32-v16:16-v32:32-v96:32-n8:16:32:64-S128"); -+ } -+ -+ void getTargetDefines(const LangOptions &Opts, -+ MacroBuilder &Builder) const override; -+ -+ ArrayRef getTargetBuiltins() const override; -+ -+ BuiltinVaListKind getBuiltinVaListKind() const override { -+ return TargetInfo::PNaClABIBuiltinVaList; -+ } -+ -+ std::string_view getClobbers() const override { return ""; } -+ -+ ArrayRef getGCCRegNames() const override { -+ return std::nullopt; -+ } -+ -+ ArrayRef getGCCRegAliases() const override { -+ return std::nullopt; -+ } -+ -+ bool validateAsmConstraint(const char *&Name, -+ TargetInfo::ConstraintInfo &Info) const override { -+ return false; -+ } -+ -+ bool hasProtectedVisibility() const override { return false; } -+}; -+ -+} // namespace targets -+} // namespace clang -+#endif // LLVM_CLANG_LIB_BASIC_TARGETS_LE64_H -diff -ruN --strip-trailing-cr a/clang/lib/Basic/Targets/OSTargets.h b/clang/lib/Basic/Targets/OSTargets.h ---- a/clang/lib/Basic/Targets/OSTargets.h -+++ b/clang/lib/Basic/Targets/OSTargets.h -@@ -841,6 +841,9 @@ - "i64:64-i128:128-n8:16:32:64-S128"); - } else if (Triple.getArch() == llvm::Triple::mipsel) { - // Handled on mips' setDataLayout. -+ } else { -+ assert(Triple.getArch() == llvm::Triple::le32); -+ this->resetDataLayout("e-p:32:32-i64:64"); - } - } - }; -diff -ruN --strip-trailing-cr a/clang/lib/Basic/Targets.cpp b/clang/lib/Basic/Targets.cpp ---- a/clang/lib/Basic/Targets.cpp -+++ b/clang/lib/Basic/Targets.cpp -@@ -23,6 +23,7 @@ - #include "Targets/DirectX.h" - #include "Targets/Hexagon.h" - #include "Targets/Lanai.h" -+#include "Targets/Le64.h" - #include "Targets/LoongArch.h" - #include "Targets/M68k.h" - #include "Targets/MSP430.h" -@@ -343,6 +344,17 @@ - return std::make_unique(Triple, Opts); - } - -+ case llvm::Triple::le32: -+ switch (os) { -+ case llvm::Triple::NaCl: -+ return std::make_unique>(Triple, Opts); -+ default: -+ return nullptr; -+ } -+ -+ case llvm::Triple::le64: -+ return std::make_unique(Triple, Opts); -+ - case llvm::Triple::ppc: - switch (os) { - case llvm::Triple::Linux: -diff -ruN --strip-trailing-cr a/clang/lib/CodeGen/CodeGenModule.cpp b/clang/lib/CodeGen/CodeGenModule.cpp ---- a/clang/lib/CodeGen/CodeGenModule.cpp -+++ b/clang/lib/CodeGen/CodeGenModule.cpp -@@ -116,6 +116,8 @@ - default: - return createDefaultTargetCodeGenInfo(CGM); - -+ case llvm::Triple::le32: -+ return createPNaClTargetCodeGenInfo(CGM); - case llvm::Triple::m68k: - return createM68kTargetCodeGenInfo(CGM); - case llvm::Triple::mips: -diff -ruN --strip-trailing-cr a/clang/lib/CodeGen/ItaniumCXXABI.cpp b/clang/lib/CodeGen/ItaniumCXXABI.cpp ---- a/clang/lib/CodeGen/ItaniumCXXABI.cpp -+++ b/clang/lib/CodeGen/ItaniumCXXABI.cpp -@@ -576,6 +576,13 @@ - return new XLCXXABI(CGM); - - case TargetCXXABI::GenericItanium: -+ if (CGM.getContext().getTargetInfo().getTriple().getArch() -+ == llvm::Triple::le32) { -+ // For PNaCl, use ARM-style method pointers so that PNaCl code -+ // does not assume anything about the alignment of function -+ // pointers. -+ return new ItaniumCXXABI(CGM, /*UseARMMethodPtrABI=*/true); -+ } - return new ItaniumCXXABI(CGM); - - case TargetCXXABI::Microsoft: + Improvements to Coverage Mapping + -------------------------------- + +diff -ruN --strip-trailing-cr a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td +--- a/clang/include/clang/Driver/Options.td ++++ b/clang/include/clang/Driver/Options.td +@@ -3998,10 +3998,6 @@ + HelpText<"Minimum time granularity (in microseconds) traced by time profiler">, + Visibility<[ClangOption, CC1Option, CLOption, DXCOption]>, + MarshallingInfoInt, "500u">; +-def ftime_trace_verbose : Joined<["-"], "ftime-trace-verbose">, Group, +- HelpText<"Make time trace capture verbose event details (e.g. source filenames). This can increase the size of the output by 2-3 times">, +- Visibility<[ClangOption, CC1Option, CLOption, DXCOption]>, +- MarshallingInfoFlag>; + def ftime_trace_EQ : Joined<["-"], "ftime-trace=">, Group, + HelpText<"Similar to -ftime-trace. Specify the JSON file or a directory which will contain the JSON file">, + Visibility<[ClangOption, CC1Option, CLOption, DXCOption]>, +diff -ruN --strip-trailing-cr a/clang/include/clang/Frontend/FrontendOptions.h b/clang/include/clang/Frontend/FrontendOptions.h +--- a/clang/include/clang/Frontend/FrontendOptions.h ++++ b/clang/include/clang/Frontend/FrontendOptions.h +@@ -580,11 +580,6 @@ + /// Minimum time granularity (in microseconds) traced by time profiler. + unsigned TimeTraceGranularity; + +- /// Make time trace capture verbose event details (e.g. source filenames). +- /// This can increase the size of the output by 2-3 times. +- LLVM_PREFERRED_TYPE(bool) +- unsigned TimeTraceVerbose : 1; +- + /// Path which stores the output files for -ftime-trace + std::string TimeTracePath; + +@@ -606,8 +601,7 @@ + EmitSymbolGraph(false), EmitExtensionSymbolGraphs(false), + EmitSymbolGraphSymbolLabelsForTesting(false), + EmitPrettySymbolGraphs(false), GenReducedBMI(false), +- UseClangIRPipeline(false), TimeTraceGranularity(500), +- TimeTraceVerbose(false) {} ++ UseClangIRPipeline(false), TimeTraceGranularity(500) {} + + /// getInputKindForExtension - Return the appropriate input kind for a file + /// extension. For example, "c" would return Language::C. diff -ruN --strip-trailing-cr a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp --- a/clang/lib/Driver/ToolChains/Clang.cpp +++ b/clang/lib/Driver/ToolChains/Clang.cpp -@@ -3815,6 +3815,12 @@ - if (UseBuiltins) - A->render(Args, CmdArgs); +@@ -6757,7 +6757,6 @@ + if (const char *Name = C.getTimeTraceFile(&JA)) { + CmdArgs.push_back(Args.MakeArgString("-ftime-trace=" + Twine(Name))); + Args.AddLastArg(CmdArgs, options::OPT_ftime_trace_granularity_EQ); +- Args.AddLastArg(CmdArgs, options::OPT_ftime_trace_verbose); } -+ -+ // le32-specific flags: -+ // -fno-math-builtin: clang should not convert math builtins to intrinsics -+ // by default. -+ if (TC.getArch() == llvm::Triple::le32) -+ CmdArgs.push_back("-fno-math-builtin"); - } - bool Driver::getDefaultModuleCachePath(SmallVectorImpl &Result) { -diff -ruN --strip-trailing-cr a/clang/test/CodeGen/bitfield-access-pad.c b/clang/test/CodeGen/bitfield-access-pad.c ---- a/clang/test/CodeGen/bitfield-access-pad.c -+++ b/clang/test/CodeGen/bitfield-access-pad.c -@@ -16,6 +16,7 @@ - // Configs that have expensive unaligned access - // Little Endian - // RUN: %clang_cc1 -triple=hexagon-elf %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT-T %s -+// RUN: %clang_cc1 -triple=le64-elf %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT-T %s - - // Big endian - // RUN: %clang_cc1 -triple=m68k-elf %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT-T %s -diff -ruN --strip-trailing-cr a/clang/test/CodeGen/bitfield-access-unit.c b/clang/test/CodeGen/bitfield-access-unit.c ---- a/clang/test/CodeGen/bitfield-access-unit.c -+++ b/clang/test/CodeGen/bitfield-access-unit.c -@@ -53,8 +53,8 @@ - // RUN: %clang_cc1 -triple=sparc-elf %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT,LAYOUT-STRICT %s - // RUN: %clang_cc1 -triple=tce-elf %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT,LAYOUT-STRICT %s - --// m68-elf is a strict alignment ISA with 4-byte aligned 64-bit or 2-byte --// aligned 32-bit integer types. This more compex to describe here. -+// Both le64-elf and m68-elf are strict alignment ISAs with 4-byte aligned -+// 64-bit or 2-byte aligned 32-bit integer types. This more compex to describe here. - - // If unaligned access is expensive don't stick these together. - struct A { -diff -ruN --strip-trailing-cr a/clang/test/CodeGenCXX/bitfield-access-empty.cpp b/clang/test/CodeGenCXX/bitfield-access-empty.cpp ---- a/clang/test/CodeGenCXX/bitfield-access-empty.cpp -+++ b/clang/test/CodeGenCXX/bitfield-access-empty.cpp -@@ -26,6 +26,7 @@ - // RUN: %clang_cc1 -triple=bpf %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT %s - // RUN: %clang_cc1 -triple=csky %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT %s - // RUN: %clang_cc1 -triple=hexagon-elf %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT %s -+// RUN: %clang_cc1 -triple=le64-elf %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT %s - // RUN: %clang_cc1 -triple=loongarch32-elf %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT %s - // RUN: %clang_cc1 -triple=nvptx-elf %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT %s - // RUN: %clang_cc1 -triple=riscv32 %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT %s -diff -ruN --strip-trailing-cr a/clang/test/CodeGenCXX/bitfield-access-tail.cpp b/clang/test/CodeGenCXX/bitfield-access-tail.cpp ---- a/clang/test/CodeGenCXX/bitfield-access-tail.cpp -+++ b/clang/test/CodeGenCXX/bitfield-access-tail.cpp -@@ -26,6 +26,7 @@ - // RUN: %clang_cc1 -triple=bpf %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT,LAYOUT64 %s - // RUN: %clang_cc1 -triple=csky %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT,LAYOUT32 %s - // RUN: %clang_cc1 -triple=hexagon-elf %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT,LAYOUT32 %s -+// RUN: %clang_cc1 -triple=le64-elf %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT,LAYOUT64 %s - // RUN: %clang_cc1 -triple=loongarch32-elf %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT,LAYOUT32 %s - // RUN: %clang_cc1 -triple=nvptx-elf %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT,LAYOUT32 %s - // RUN: %clang_cc1 -triple=riscv32 %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT,LAYOUT32 %s -diff -ruN --strip-trailing-cr a/clang/test/Preprocessor/predefined-macros-no-warnings.c b/clang/test/Preprocessor/predefined-macros-no-warnings.c ---- a/clang/test/Preprocessor/predefined-macros-no-warnings.c -+++ b/clang/test/Preprocessor/predefined-macros-no-warnings.c -@@ -75,6 +75,8 @@ - // RUN: %clang_cc1 %s -Eonly -Wsystem-headers -Werror -triple m68k - // RUN: %clang_cc1 %s -Eonly -Wsystem-headers -Werror -triple m68k-linux - // RUN: %clang_cc1 %s -Eonly -Wsystem-headers -Werror -triple m68k-netbsd -+// RUN: %clang_cc1 %s -Eonly -Wsystem-headers -Werror -triple le32-nacl -+// RUN: %clang_cc1 %s -Eonly -Wsystem-headers -Werror -triple le64 - // RUN: %clang_cc1 %s -Eonly -Wsystem-headers -Werror -triple ppc - // RUN: %clang_cc1 %s -Eonly -Wsystem-headers -Werror -triple ppc-freebsd - // RUN: %clang_cc1 %s -Eonly -Wsystem-headers -Werror -triple ppc-netbsd -diff -ruN --strip-trailing-cr a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h ---- a/llvm/include/llvm/IR/PatternMatch.h -+++ b/llvm/include/llvm/IR/PatternMatch.h -@@ -1550,27 +1550,23 @@ - template - struct CmpClass_match { -- PredicateTy *Predicate; -+ PredicateTy &Predicate; - LHS_t L; - RHS_t R; - - // The evaluation order is always stable, regardless of Commutability. - // The LHS is always matched first. - CmpClass_match(PredicateTy &Pred, const LHS_t &LHS, const RHS_t &RHS) -- : Predicate(&Pred), L(LHS), R(RHS) {} -- CmpClass_match(const LHS_t &LHS, const RHS_t &RHS) -- : Predicate(nullptr), L(LHS), R(RHS) {} -+ : Predicate(Pred), L(LHS), R(RHS) {} - - template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) { - if (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) { -- if (Predicate) -- *Predicate = I->getPredicate(); -+ Predicate = I->getPredicate(); - return true; - } else if (Commutable && L.match(I->getOperand(1)) && - R.match(I->getOperand(0))) { -- if (Predicate) -- *Predicate = I->getSwappedPredicate(); -+ Predicate = I->getSwappedPredicate(); - return true; - } - } -@@ -1599,19 +1595,22 @@ - template - inline CmpClass_match - m_Cmp(const LHS &L, const RHS &R) { -- return CmpClass_match(L, R); -+ CmpInst::Predicate Unused; -+ return CmpClass_match(Unused, L, R); - } + if (Arg *A = Args.getLastArg(options::OPT_ftrapv_handler_EQ)) { +diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaTemplateInstantiate.cpp b/clang/lib/Sema/SemaTemplateInstantiate.cpp +--- a/clang/lib/Sema/SemaTemplateInstantiate.cpp ++++ b/clang/lib/Sema/SemaTemplateInstantiate.cpp +@@ -3426,16 +3426,11 @@ + return true; + + llvm::TimeTraceScope TimeScope("InstantiateClass", [&]() { +- llvm::TimeTraceMetadata M; +- llvm::raw_string_ostream OS(M.Detail); ++ std::string Name; ++ llvm::raw_string_ostream OS(Name); + Instantiation->getNameForDiagnostic(OS, getPrintingPolicy(), + /*Qualified=*/true); +- if (llvm::isTimeTraceVerbose()) { +- auto Loc = SourceMgr.getExpansionLoc(Instantiation->getLocation()); +- M.File = SourceMgr.getFilename(Loc); +- M.Line = SourceMgr.getExpansionLineNumber(Loc); +- } +- return M; ++ return Name; + }); + + Pattern = PatternDef; +diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp +--- a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp ++++ b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp +@@ -4966,16 +4966,11 @@ + } - template - inline CmpClass_match - m_ICmp(const LHS &L, const RHS &R) { -- return CmpClass_match(L, R); -+ ICmpInst::Predicate Unused; -+ return CmpClass_match(Unused, L, R); + llvm::TimeTraceScope TimeScope("InstantiateFunction", [&]() { +- llvm::TimeTraceMetadata M; +- llvm::raw_string_ostream OS(M.Detail); ++ std::string Name; ++ llvm::raw_string_ostream OS(Name); + Function->getNameForDiagnostic(OS, getPrintingPolicy(), + /*Qualified=*/true); +- if (llvm::isTimeTraceVerbose()) { +- auto Loc = SourceMgr.getExpansionLoc(Function->getLocation()); +- M.File = SourceMgr.getFilename(Loc); +- M.Line = SourceMgr.getExpansionLineNumber(Loc); +- } +- return M; ++ return Name; + }); + + // If we're performing recursive template instantiation, create our own +diff -ruN --strip-trailing-cr a/clang/test/Driver/ftime-trace-sections.cpp b/clang/test/Driver/ftime-trace-sections.cpp +--- a/clang/test/Driver/ftime-trace-sections.cpp ++++ b/clang/test/Driver/ftime-trace-sections.cpp +@@ -1,5 +1,5 @@ + // RUN: rm -rf %t && mkdir %t && cd %t +-// RUN: %clangxx -S -ftime-trace -ftime-trace-granularity=0 -ftime-trace-verbose -o out %s ++// RUN: %clangxx -S -ftime-trace -ftime-trace-granularity=0 -o out %s + // RUN: %python %S/ftime-trace-sections.py < out.json + + template +diff -ruN --strip-trailing-cr a/clang/test/Driver/ftime-trace.cpp b/clang/test/Driver/ftime-trace.cpp +--- a/clang/test/Driver/ftime-trace.cpp ++++ b/clang/test/Driver/ftime-trace.cpp +@@ -1,18 +1,18 @@ + // RUN: rm -rf %t && mkdir -p %t && cd %t +-// RUN: %clangxx -S -no-canonical-prefixes -ftime-trace -ftime-trace-granularity=0 -ftime-trace-verbose -o out %s ++// RUN: %clangxx -S -no-canonical-prefixes -ftime-trace -ftime-trace-granularity=0 -o out %s + // RUN: cat out.json \ + // RUN: | %python -c 'import json, sys; json.dump(json.loads(sys.stdin.read()), sys.stdout, sort_keys=True, indent=2)' \ + // RUN: | FileCheck %s +-// RUN: %clangxx -S -no-canonical-prefixes -ftime-trace=new-name.json -ftime-trace-granularity=0 -ftime-trace-verbose -o out %s ++// RUN: %clangxx -S -no-canonical-prefixes -ftime-trace=new-name.json -ftime-trace-granularity=0 -o out %s + // RUN: cat new-name.json \ + // RUN: | %python -c 'import json, sys; json.dump(json.loads(sys.stdin.read()), sys.stdout, sort_keys=True, indent=2)' \ + // RUN: | FileCheck %s + // RUN: mkdir dir1 dir2 +-// RUN: %clangxx -S -no-canonical-prefixes -ftime-trace=dir1 -ftime-trace-granularity=0 -ftime-trace-verbose -o out %s ++// RUN: %clangxx -S -no-canonical-prefixes -ftime-trace=dir1 -ftime-trace-granularity=0 -o out %s + // RUN: cat dir1/out.json \ + // RUN: | %python -c 'import json, sys; json.dump(json.loads(sys.stdin.read()), sys.stdout, sort_keys=True, indent=2)' \ + // RUN: | FileCheck %s +-// RUN: %clangxx -S -no-canonical-prefixes -ftime-trace=dir2/ -ftime-trace-granularity=0 -ftime-trace-verbose -o out %s ++// RUN: %clangxx -S -no-canonical-prefixes -ftime-trace=dir2/ -ftime-trace-granularity=0 -o out %s + // RUN: cat dir2/out.json \ + // RUN: | %python -c 'import json, sys; json.dump(json.loads(sys.stdin.read()), sys.stdout, sort_keys=True, indent=2)' \ + // RUN: | FileCheck %s +@@ -34,33 +34,32 @@ + // RUN: mkdir d e f && cp %s d/a.cpp && touch d/b.c + + /// TODO: Support -fno-integrated-as. +-// RUN: %clang -### -c -ftime-trace -ftime-trace-granularity=0 -ftime-trace-verbose -fintegrated-as d/a.cpp -o e/a.o 2>&1 | FileCheck %s --check-prefix=COMPILE1 +-// COMPILE1: -cc1{{.*}} "-ftime-trace=e/a.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" ++// RUN: %clang -### -c -ftime-trace -ftime-trace-granularity=0 -fintegrated-as d/a.cpp -o e/a.o 2>&1 | FileCheck %s --check-prefix=COMPILE1 ++// COMPILE1: -cc1{{.*}} "-ftime-trace=e/a.json" "-ftime-trace-granularity=0" + +-// RUN: %clang -### -c -ftime-trace -ftime-trace-granularity=0 -ftime-trace-verbose d/a.cpp d/b.c -dumpdir f/ 2>&1 | FileCheck %s --check-prefix=COMPILE2 +-// COMPILE2: -cc1{{.*}} "-ftime-trace=f/a.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" +-// COMPILE2: -cc1{{.*}} "-ftime-trace=f/b.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" ++// RUN: %clang -### -c -ftime-trace -ftime-trace-granularity=0 d/a.cpp d/b.c -dumpdir f/ 2>&1 | FileCheck %s --check-prefix=COMPILE2 ++// COMPILE2: -cc1{{.*}} "-ftime-trace=f/a.json" "-ftime-trace-granularity=0" ++// COMPILE2: -cc1{{.*}} "-ftime-trace=f/b.json" "-ftime-trace-granularity=0" + + /// -o specifies the link output. Create ${output}-${basename}.json. +-// RUN: %clang -### -ftime-trace -ftime-trace-granularity=0 -ftime-trace-verbose d/a.cpp d/b.c -o e/x 2>&1 | FileCheck %s --check-prefix=LINK1 +-// LINK1: -cc1{{.*}} "-ftime-trace=e/x-a.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" +-// LINK1: -cc1{{.*}} "-ftime-trace=e/x-b.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" ++// RUN: %clang -### -ftime-trace -ftime-trace-granularity=0 d/a.cpp d/b.c -o e/x 2>&1 | FileCheck %s --check-prefix=LINK1 ++// LINK1: -cc1{{.*}} "-ftime-trace=e/x-a.json" "-ftime-trace-granularity=0" ++// LINK1: -cc1{{.*}} "-ftime-trace=e/x-b.json" "-ftime-trace-granularity=0" + + /// -dumpdir is f/g, not ending with a path separator. We create f/g${basename}.json. +-// RUN: %clang -### -ftime-trace -ftime-trace-granularity=0 -ftime-trace-verbose d/a.cpp d/b.c -o e/x -dumpdir f/g 2>&1 | FileCheck %s --check-prefix=LINK2 +-// LINK2: -cc1{{.*}} "-ftime-trace=f/ga.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" +-// LINK2: -cc1{{.*}} "-ftime-trace=f/gb.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" +- +-// RUN: %clang -### -ftime-trace=e -ftime-trace-granularity=0 -ftime-trace-verbose d/a.cpp d/b.c -o f/x -dumpdir f/ 2>&1 | FileCheck %s --check-prefix=LINK3 +-// LINK3: -cc1{{.*}} "-ftime-trace=e{{/|\\\\}}a-{{[^.]*}}.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" +-// LINK3: -cc1{{.*}} "-ftime-trace=e{{/|\\\\}}b-{{[^.]*}}.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" ++// RUN: %clang -### -ftime-trace -ftime-trace-granularity=0 d/a.cpp d/b.c -o e/x -dumpdir f/g 2>&1 | FileCheck %s --check-prefix=LINK2 ++// LINK2: -cc1{{.*}} "-ftime-trace=f/ga.json" "-ftime-trace-granularity=0" ++// LINK2: -cc1{{.*}} "-ftime-trace=f/gb.json" "-ftime-trace-granularity=0" ++ ++// RUN: %clang -### -ftime-trace=e -ftime-trace-granularity=0 d/a.cpp d/b.c -o f/x -dumpdir f/ 2>&1 | FileCheck %s --check-prefix=LINK3 ++// LINK3: -cc1{{.*}} "-ftime-trace=e{{/|\\\\}}a-{{[^.]*}}.json" "-ftime-trace-granularity=0" ++// LINK3: -cc1{{.*}} "-ftime-trace=e{{/|\\\\}}b-{{[^.]*}}.json" "-ftime-trace-granularity=0" + +-// RUN: %clang -### -ftime-trace -ftime-trace=e -ftime-trace-granularity=1 -ftime-trace-verbose -xassembler d/a.cpp 2>&1 | \ ++// RUN: %clang -### -ftime-trace -ftime-trace=e -ftime-trace-granularity=1 -xassembler d/a.cpp 2>&1 | \ + // RUN: FileCheck %s --check-prefix=UNUSED + // UNUSED: warning: argument unused during compilation: '-ftime-trace' + // UNUSED-NEXT: warning: argument unused during compilation: '-ftime-trace=e' + // UNUSED-NEXT: warning: argument unused during compilation: '-ftime-trace-granularity=1' +-// UNUSED-NEXT: warning: argument unused during compilation: '-ftime-trace-verbose' + // UNUSED-NOT: warning: + + template +diff -ruN --strip-trailing-cr a/clang/tools/driver/cc1_main.cpp b/clang/tools/driver/cc1_main.cpp +--- a/clang/tools/driver/cc1_main.cpp ++++ b/clang/tools/driver/cc1_main.cpp +@@ -241,8 +241,7 @@ + + if (!Clang->getFrontendOpts().TimeTracePath.empty()) { + llvm::timeTraceProfilerInitialize( +- Clang->getFrontendOpts().TimeTraceGranularity, Argv0, +- Clang->getFrontendOpts().TimeTraceVerbose); ++ Clang->getFrontendOpts().TimeTraceGranularity, Argv0); + } + // --print-supported-cpus takes priority over the actual compilation. + if (Clang->getFrontendOpts().PrintSupportedCPUs) +diff -ruN --strip-trailing-cr a/clang/unittests/Support/TimeProfilerTest.cpp b/clang/unittests/Support/TimeProfilerTest.cpp +--- a/clang/unittests/Support/TimeProfilerTest.cpp ++++ b/clang/unittests/Support/TimeProfilerTest.cpp +@@ -10,15 +10,11 @@ + #include "clang/Frontend/FrontendActions.h" + #include "clang/Lex/PreprocessorOptions.h" + +-#include "llvm/ADT/StringMap.h" + #include "llvm/Support/JSON.h" +-#include "llvm/Support/Path.h" + #include "llvm/Support/TimeProfiler.h" +-#include "llvm/Support/VirtualFileSystem.h" + #include + + #include "gtest/gtest.h" +-#include + + using namespace clang; + using namespace llvm; +@@ -27,8 +23,7 @@ + + // Should be called before testing. + void setupProfiler() { +- timeTraceProfilerInitialize(/*TimeTraceGranularity=*/0, "test", +- /*TimeTraceVerbose=*/true); ++ timeTraceProfilerInitialize(/*TimeTraceGranularity=*/0, "test"); } - template - inline CmpClass_match - m_FCmp(const LHS &L, const RHS &R) { -- return CmpClass_match(L, R); -+ FCmpInst::Predicate Unused; -+ return CmpClass_match(Unused, L, R); + // Should be called after `compileFromString()`. +@@ -43,24 +38,14 @@ + + // Returns true if code compiles successfully. + // We only parse AST here. This is enough for constexpr evaluation. +-bool compileFromString(StringRef Code, StringRef Standard, StringRef File, +- llvm::StringMap Headers = {}) { ++bool compileFromString(StringRef Code, StringRef Standard, StringRef FileName) { + CompilerInstance Compiler; + Compiler.createDiagnostics(); + +- llvm::IntrusiveRefCntPtr FS( +- new llvm::vfs::InMemoryFileSystem()); +- FS->addFile(File, 0, MemoryBuffer::getMemBuffer(Code)); +- for (const auto &Header : Headers) { +- FS->addFile(Header.getKey(), 0, +- MemoryBuffer::getMemBuffer(Header.getValue())); +- } +- llvm::IntrusiveRefCntPtr Files( +- new FileManager(FileSystemOptions(), FS)); +- Compiler.setFileManager(Files.get()); +- + auto Invocation = std::make_shared(); +- std::vector Args = {Standard.data(), File.data()}; ++ Invocation->getPreprocessorOpts().addRemappedFile( ++ FileName, MemoryBuffer::getMemBuffer(Code).release()); ++ const char *Args[] = {Standard.data(), FileName.data()}; + CompilerInvocation::CreateFromArgs(*Invocation, Args, + Compiler.getDiagnostics()); + Compiler.setInvocation(std::move(Invocation)); +@@ -75,28 +60,13 @@ + return Compiler.ExecuteAction(Action); } - // Same as CmpClass, but instead of saving Pred as out output variable, match a -diff -ruN --strip-trailing-cr a/llvm/include/llvm/TargetParser/Triple.h b/llvm/include/llvm/TargetParser/Triple.h ---- a/llvm/include/llvm/TargetParser/Triple.h -+++ b/llvm/include/llvm/TargetParser/Triple.h -@@ -88,6 +88,8 @@ - xtensa, // Tensilica: Xtensa - nvptx, // NVPTX: 32-bit - nvptx64, // NVPTX: 64-bit -+ le32, // le32: generic little-endian 32-bit CPU (PNaCl) -+ le64, // le64: generic little-endian 64-bit CPU (PNaCl) - amdil, // AMDIL - amdil64, // AMDIL with 64-bit pointers - hsail, // AMD HSAIL -diff -ruN --strip-trailing-cr a/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp b/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp ---- a/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp -+++ b/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp -@@ -128,7 +128,6 @@ - bool visitINSviGPR(MachineInstr &MI, unsigned Opc); - bool visitINSvi64lane(MachineInstr &MI); - bool visitFMOVDr(MachineInstr &MI); -- bool visitCopy(MachineInstr &MI); - bool runOnMachineFunction(MachineFunction &MF) override; +-std::string GetMetadata(json::Object *Event) { +- std::string Metadata; +- llvm::raw_string_ostream OS(Metadata); +- if (json::Object *Args = Event->getObject("args")) { +- if (auto Detail = Args->getString("detail")) +- OS << Detail; +- // Use only filename to not include os-specific path separators. +- if (auto File = Args->getString("file")) +- OS << ", " << llvm::sys::path::filename(*File); +- if (auto Line = Args->getInteger("line")) +- OS << ":" << *Line; +- } +- return Metadata; +-} +- + // Returns pretty-printed trace graph. + std::string buildTraceGraph(StringRef Json) { + struct EventRecord { + int64_t TimestampBegin; + int64_t TimestampEnd; +- std::string Name; +- std::string Metadata; ++ StringRef Name; ++ StringRef Detail; + }; + std::vector Events; + +@@ -111,13 +81,10 @@ + int64_t TimestampBegin = TraceEventObj->getInteger("ts").value_or(0); + int64_t TimestampEnd = + TimestampBegin + TraceEventObj->getInteger("dur").value_or(0); +- std::string Name = TraceEventObj->getString("name").value_or("").str(); +- std::string Metadata = GetMetadata(TraceEventObj); +- +- // Source events are asynchronous events and may not perfectly nest the +- // synchronous events. Skip testing them. +- if (Name == "Source") +- continue; ++ StringRef Name = TraceEventObj->getString("name").value_or(""); ++ StringRef Detail = ""; ++ if (json::Object *Args = TraceEventObj->getObject("args")) ++ Detail = Args->getString("detail").value_or(""); + + // This is a "summary" event, like "Total PerformPendingInstantiations", + // skip it +@@ -125,7 +92,7 @@ + continue; - StringRef getPassName() const override { -@@ -691,34 +690,6 @@ - return true; - } + Events.emplace_back( +- EventRecord{TimestampBegin, TimestampEnd, Name, Metadata}); ++ EventRecord{TimestampBegin, TimestampEnd, Name, Detail}); + } --// Across a basic-block we might have in i32 extract from a value that only --// operates on upper bits (for example a sxtw). We can replace the COPY with a --// new version skipping the sxtw. --bool AArch64MIPeepholeOpt::visitCopy(MachineInstr &MI) { -- Register InputReg = MI.getOperand(1).getReg(); -- if (MI.getOperand(1).getSubReg() != AArch64::sub_32 || -- !MRI->hasOneNonDBGUse(InputReg)) -- return false; + // There can be nested events that are very fast, for example: +@@ -165,9 +132,9 @@ + Stream << "| "; + } + Stream.write(Event.Name.data(), Event.Name.size()); +- if (!Event.Metadata.empty()) { ++ if (!Event.Detail.empty()) { + Stream << " ("; +- Stream.write(Event.Metadata.data(), Event.Metadata.size()); ++ Stream.write(Event.Detail.data(), Event.Detail.size()); + Stream << ")"; + } + Stream << "\n"; +@@ -178,7 +145,7 @@ + } // namespace + + TEST(TimeProfilerTest, ConstantEvaluationCxx20) { +- std::string Code = R"( ++ constexpr StringRef Code = R"( + void print(double value); + + namespace slow_namespace { +@@ -208,7 +175,8 @@ + setupProfiler(); + ASSERT_TRUE(compileFromString(Code, "-std=c++20", "test.cc")); + std::string Json = teardownProfiler(); +- ASSERT_EQ(R"( ++ std::string TraceGraph = buildTraceGraph(Json); ++ ASSERT_TRUE(TraceGraph == R"( + Frontend + | ParseDeclarationOrFunctionDefinition (test.cc:2:1) + | ParseDeclarationOrFunctionDefinition (test.cc:6:1) +@@ -234,54 +202,14 @@ + | ParseDeclarationOrFunctionDefinition (test.cc:25:1) + | | EvaluateAsInitializer (slow_init_list) + | PerformPendingInstantiations +-)", +- buildTraceGraph(Json)); +-} - -- MachineInstr *SrcMI = MRI->getUniqueVRegDef(InputReg); -- MachineInstr *CopyMI = SrcMI; -- while (SrcMI && SrcMI->isFullCopy() && -- MRI->hasOneNonDBGUse(SrcMI->getOperand(1).getReg())) -- SrcMI = MRI->getUniqueVRegDef(SrcMI->getOperand(1).getReg()); +-TEST(TimeProfilerTest, TemplateInstantiations) { +- std::string B_H = R"( +- template +- T fooB(T t) { +- return T(); +- } ++)"); + +- #define MacroTemp(x) template void foo##x(T) { T(); } +- )"; - -- if (!SrcMI || SrcMI->getOpcode() != AArch64::SBFMXri || -- SrcMI->getOperand(2).getImm() != 0 || SrcMI->getOperand(3).getImm() != 31) -- return false; +- std::string A_H = R"( +- #include "b.h" - -- Register SrcReg = SrcMI->getOperand(1).getReg(); -- MRI->constrainRegClass(SrcReg, MRI->getRegClass(InputReg)); -- MI.getOperand(1).setReg(SrcReg); -- if (CopyMI != SrcMI) -- CopyMI->eraseFromParent(); -- SrcMI->eraseFromParent(); -- return true; --} +- MacroTemp(MTA) - - bool AArch64MIPeepholeOpt::runOnMachineFunction(MachineFunction &MF) { - if (skipFunction(MF.getFunction())) - return false; -@@ -800,9 +771,6 @@ - case AArch64::FMOVDr: - Changed |= visitFMOVDr(MI); - break; -- case AArch64::COPY: -- Changed |= visitCopy(MI); -- break; - } - } - } -diff -ruN --strip-trailing-cr a/llvm/lib/Target/AArch64/peephole-sxtw.mir b/llvm/lib/Target/AArch64/peephole-sxtw.mir ---- a/llvm/lib/Target/AArch64/peephole-sxtw.mir -+++ b/llvm/lib/Target/AArch64/peephole-sxtw.mir -@@ -1,46 +0,0 @@ --# NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py --# RUN: llc -run-pass=aarch64-mi-peephole-opt -o - -mtriple=aarch64-unknown-linux -verify-machineinstrs %s | FileCheck %s +- template +- void fooA(T t) { fooB(t); fooMTA(t); } +- )"; +- std::string Code = R"( +- #include "a.h" +- void user() { fooA(0); } +- )"; - ----- --name: removeSxtw --tracksRegLiveness: true --body: | -- bb.0.entry: -- liveins: $x0 -- ; CHECK-LABEL: name: removeSxtw -- ; CHECK: liveins: $x0 -- ; CHECK-NEXT: {{ $}} -- ; CHECK-NEXT: [[COPY:%[0-9]+]]:gpr64 = COPY $x0 -- ; CHECK-NEXT: [[COPY1:%[0-9]+]]:gpr32sp = COPY [[COPY]].sub_32 -- ; CHECK-NEXT: [[ADDWri:%[0-9]+]]:gpr32sp = ADDWri [[COPY1]], 1, 0 -- ; CHECK-NEXT: $w0 = COPY [[ADDWri]] -- ; CHECK-NEXT: RET_ReallyLR implicit $w0 -- %0:gpr64 = COPY $x0 -- %1:gpr64 = SBFMXri %0:gpr64, 0, 31 -- %2:gpr32sp = COPY %1.sub_32:gpr64 -- %3:gpr32sp = ADDWri %2:gpr32sp, 1, 0 -- $w0 = COPY %3:gpr32sp -- RET_ReallyLR implicit $w0 --... ----- --name: extraCopy --tracksRegLiveness: true --body: | -- bb.0.entry: -- liveins: $x0 -- ; CHECK-LABEL: name: extraCopy -- ; CHECK: liveins: $x0 -- ; CHECK-NEXT: {{ $}} -- ; CHECK-NEXT: [[COPY:%[0-9]+]]:gpr64 = COPY $x0 -- ; CHECK-NEXT: [[COPY1:%[0-9]+]]:gpr32sp = COPY [[COPY]].sub_32 -- ; CHECK-NEXT: [[ADDWri:%[0-9]+]]:gpr32sp = ADDWri [[COPY1]], 1, 0 -- ; CHECK-NEXT: $w0 = COPY [[ADDWri]] -- ; CHECK-NEXT: RET_ReallyLR implicit $w0 -- %0:gpr64 = COPY $x0 -- %1:gpr64 = SBFMXri %0:gpr64, 0, 31 -- %2:gpr64all = COPY %1:gpr64 -- %3:gpr32sp = COPY %2.sub_32:gpr64all -- %4:gpr32sp = ADDWri %3:gpr32sp, 1, 0 -- $w0 = COPY %4:gpr32sp -- RET_ReallyLR implicit $w0 --... -diff -ruN --strip-trailing-cr a/llvm/lib/TargetParser/Triple.cpp b/llvm/lib/TargetParser/Triple.cpp ---- a/llvm/lib/TargetParser/Triple.cpp -+++ b/llvm/lib/TargetParser/Triple.cpp -@@ -44,6 +44,8 @@ - case hsail: return "hsail"; - case kalimba: return "kalimba"; - case lanai: return "lanai"; -+ case le32: return "le32"; -+ case le64: return "le64"; - case loongarch32: return "loongarch32"; - case loongarch64: return "loongarch64"; - case m68k: return "m68k"; -@@ -197,6 +199,9 @@ - case nvptx: return "nvvm"; - case nvptx64: return "nvvm"; +- setupProfiler(); +- ASSERT_TRUE(compileFromString(Code, "-std=c++20", "test.cc", +- /*Headers=*/{{"a.h", A_H}, {"b.h", B_H}})); +- std::string Json = teardownProfiler(); +- ASSERT_EQ(R"( +-Frontend +-| ParseFunctionDefinition (fooB) +-| ParseFunctionDefinition (fooMTA) +-| ParseFunctionDefinition (fooA) +-| ParseDeclarationOrFunctionDefinition (test.cc:3:5) +-| | ParseFunctionDefinition (user) +-| PerformPendingInstantiations +-| | InstantiateFunction (fooA, a.h:7) +-| | | InstantiateFunction (fooB, b.h:3) +-| | | InstantiateFunction (fooMTA, a.h:4) +-)", +- buildTraceGraph(Json)); ++ // NOTE: If this test is failing, run this test with ++ // `llvm::errs() << TraceGraph;` and change the assert above. + } -+ case le32: return "le32"; -+ case le64: return "le64"; + TEST(TimeProfilerTest, ConstantEvaluationC99) { +- std::string Code = R"( ++ constexpr StringRef Code = R"( + struct { + short quantval[4]; // 3rd line + } value; +@@ -290,12 +218,15 @@ + setupProfiler(); + ASSERT_TRUE(compileFromString(Code, "-std=c99", "test.c")); + std::string Json = teardownProfiler(); +- ASSERT_EQ(R"( ++ std::string TraceGraph = buildTraceGraph(Json); ++ ASSERT_TRUE(TraceGraph == R"( + Frontend + | ParseDeclarationOrFunctionDefinition (test.c:2:1) + | | isIntegerConstantExpr () + | | EvaluateKnownConstIntCheckOverflow () + | PerformPendingInstantiations +-)", +- buildTraceGraph(Json)); ++)"); + - case amdil: - case amdil64: return "amdil"; - -@@ -427,6 +432,8 @@ - .Case("xcore", xcore) - .Case("nvptx", nvptx) - .Case("nvptx64", nvptx64) -+ .Case("le32", le32) -+ .Case("le64", le64) - .Case("amdil", amdil) - .Case("amdil64", amdil64) - .Case("hsail", hsail) -@@ -567,6 +574,8 @@ - .Case("xcore", Triple::xcore) - .Case("nvptx", Triple::nvptx) - .Case("nvptx64", Triple::nvptx64) -+ .Case("le32", Triple::le32) -+ .Case("le64", Triple::le64) - .Case("amdil", Triple::amdil) - .Case("amdil64", Triple::amdil64) - .Case("hsail", Triple::hsail) -@@ -896,6 +905,8 @@ - case Triple::hsail: - case Triple::kalimba: - case Triple::lanai: -+ case Triple::le32: -+ case Triple::le64: - case Triple::loongarch32: - case Triple::loongarch64: - case Triple::m68k: -@@ -1592,6 +1603,7 @@ - case llvm::Triple::hsail: - case llvm::Triple::kalimba: - case llvm::Triple::lanai: -+ case llvm::Triple::le32: - case llvm::Triple::loongarch32: - case llvm::Triple::m68k: - case llvm::Triple::mips: -@@ -1624,6 +1636,7 @@ - case llvm::Triple::bpfeb: - case llvm::Triple::bpfel: - case llvm::Triple::hsail64: -+ case llvm::Triple::le64: - case llvm::Triple::loongarch64: - case llvm::Triple::mips64: - case llvm::Triple::mips64el: -@@ -1682,6 +1695,7 @@ - case Triple::hsail: - case Triple::kalimba: - case Triple::lanai: -+ case Triple::le32: - case Triple::loongarch32: - case Triple::m68k: - case Triple::mips: -@@ -1712,6 +1726,7 @@ - case Triple::aarch64_be: T.setArch(Triple::armeb); break; - case Triple::amdil64: T.setArch(Triple::amdil); break; - case Triple::hsail64: T.setArch(Triple::hsail); break; -+ case Triple::le64: T.setArch(Triple::le32); break; - case Triple::loongarch64: T.setArch(Triple::loongarch32); break; - case Triple::mips64: - T.setArch(Triple::mips, getSubArch()); -@@ -1766,6 +1781,7 @@ - case Triple::bpfeb: - case Triple::bpfel: - case Triple::hsail64: -+ case Triple::le64: - case Triple::loongarch64: - case Triple::mips64: - case Triple::mips64el: -@@ -1789,6 +1805,7 @@ - case Triple::arm: T.setArch(Triple::aarch64); break; - case Triple::armeb: T.setArch(Triple::aarch64_be); break; - case Triple::hsail: T.setArch(Triple::hsail64); break; -+ case Triple::le32: T.setArch(Triple::le64); break; - case Triple::loongarch32: T.setArch(Triple::loongarch64); break; - case Triple::mips: - T.setArch(Triple::mips64, getSubArch()); -@@ -1831,6 +1848,8 @@ - case Triple::hsail64: - case Triple::hsail: - case Triple::kalimba: -+ case Triple::le32: -+ case Triple::le64: - case Triple::loongarch32: - case Triple::loongarch64: - case Triple::msp430: -@@ -1934,6 +1953,8 @@ - case Triple::hsail64: - case Triple::hsail: - case Triple::kalimba: -+ case Triple::le32: -+ case Triple::le64: - case Triple::loongarch32: - case Triple::loongarch64: - case Triple::mips64el: -diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp ---- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp -+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp -@@ -6995,7 +6995,7 @@ - // Ignore ephemeral values. - CodeMetrics::collectEphemeralValues(TheLoop, AC, ValuesToIgnore); ++ // NOTE: If this test is failing, run this test with ++ // `llvm::errs() << TraceGraph;` and change the assert above. + } +diff -ruN --strip-trailing-cr a/lld/test/MachO/reproduce-thin-archive-objc.s b/lld/test/MachO/reproduce-thin-archive-objc.s +--- a/lld/test/MachO/reproduce-thin-archive-objc.s ++++ b/lld/test/MachO/reproduce-thin-archive-objc.s +@@ -4,20 +4,19 @@ + ## during linking. However, we need to iterate over all members for -ObjC, check that we don't + ## crash when we encounter a missing member. + +-# RUN: rm -rf %t; mkdir %t +-# RUN: sed s/SYM/_main/ %s | llvm-mc -filetype=obj -triple=x86_64-apple-macos -o %t/main.o +-# RUN: sed s/SYM/_unused/ %s | llvm-mc -filetype=obj -triple=x86_64-apple-macos -o %t/unused.o ++# RUN: rm -rf %t && mkdir %t && cd %t ++# RUN: sed s/SYM/_main/ %s | llvm-mc -filetype=obj -triple=x86_64-apple-macos -o main.o ++# RUN: sed s/SYM/_unused/ %s | llvm-mc -filetype=obj -triple=x86_64-apple-macos -o unused.o + +-# RUN: cd %t; llvm-ar rcsT unused.a unused.o; rm unused.o ++# RUN: llvm-ar rcsT unused.a unused.o; rm unused.o + ## FIXME: Absolute paths don't end up relativized in the repro file. + + # RUN: %no-fatal-warnings-lld %t/main.o %t/unused.a -ObjC -o /dev/null 2>&1 \ + # RUN: | FileCheck %s --check-prefix=WARN + +-# RUN: %lld %t/main.o %t/unused.a -ObjC --no-warn-thin-archive-missing-members -o /dev/null \ +-# RUN: | FileCheck %s --implicit-check-not 'warning' --allow-empty ++# RUN: %lld main.o unused.a -ObjC --no-warn-thin-archive-missing-members 2>&1 | count 0 + +-# WARN: ld64.lld: warning: {{.*}}unused.a: -ObjC failed to open archive member: 'unused.o' ++# WARN: warning: {{.*}}unused.a: -ObjC failed to open archive member: 'unused.o' + + .text + .globl SYM +diff -ruN --strip-trailing-cr a/llvm/include/llvm/Support/TimeProfiler.h b/llvm/include/llvm/Support/TimeProfiler.h +--- a/llvm/include/llvm/Support/TimeProfiler.h ++++ b/llvm/include/llvm/Support/TimeProfiler.h +@@ -83,28 +83,16 @@ + + class raw_pwrite_stream; + +-struct TimeTraceMetadata { +- std::string Detail; +- // Source file and line number information for the event. +- std::string File; +- int Line; +- +- bool isEmpty() const { return Detail.empty() && File.empty(); } +-}; +- + struct TimeTraceProfiler; + TimeTraceProfiler *getTimeTraceProfilerInstance(); -- SmallSetVector DeadInterleavePointerOps; -+ SmallVector DeadInterleavePointerOps; - for (BasicBlock *BB : TheLoop->blocks()) - for (Instruction &I : *BB) { - // Find all stores to invariant variables. Since they are going to sink -@@ -7013,7 +7013,7 @@ - if (Group->getInsertPos() == &I) - continue; - Value *PointerOp = getLoadStorePointerOperand(&I); -- DeadInterleavePointerOps.insert(PointerOp); -+ DeadInterleavePointerOps.push_back(PointerOp); - } - } +-bool isTimeTraceVerbose(); +- + struct TimeTraceProfilerEntry; + + /// Initialize the time trace profiler. + /// This sets up the global \p TimeTraceProfilerInstance + /// variable to be the profiler instance. + void timeTraceProfilerInitialize(unsigned TimeTraceGranularity, +- StringRef ProcName, +- bool TimeTraceVerbose = false); ++ StringRef ProcName); + + /// Cleanup the time trace profiler, if it was initialized. + void timeTraceProfilerCleanup(); +@@ -140,10 +128,6 @@ + timeTraceProfilerBegin(StringRef Name, + llvm::function_ref Detail); + +-TimeTraceProfilerEntry * +-timeTraceProfilerBegin(StringRef Name, +- llvm::function_ref MetaData); +- + /// Manually begin a time section, with the given \p Name and \p Detail. + /// This starts Async Events having \p Name as a category which is shown + /// separately from other traces. See +@@ -180,11 +164,6 @@ + if (getTimeTraceProfilerInstance() != nullptr) + Entry = timeTraceProfilerBegin(Name, Detail); + } +- TimeTraceScope(StringRef Name, +- llvm::function_ref Metadata) { +- if (getTimeTraceProfilerInstance() != nullptr) +- Entry = timeTraceProfilerBegin(Name, Metadata); +- } + ~TimeTraceScope() { + if (getTimeTraceProfilerInstance() != nullptr) + timeTraceProfilerEnd(Entry); +diff -ruN --strip-trailing-cr a/llvm/lib/Support/TimeProfiler.cpp b/llvm/lib/Support/TimeProfiler.cpp +--- a/llvm/lib/Support/TimeProfiler.cpp ++++ b/llvm/lib/Support/TimeProfiler.cpp +@@ -73,20 +73,12 @@ + const TimePointType Start; + TimePointType End; + const std::string Name; +- TimeTraceMetadata Metadata; +- ++ const std::string Detail; + const bool AsyncEvent = false; + TimeTraceProfilerEntry(TimePointType &&S, TimePointType &&E, std::string &&N, + std::string &&Dt, bool Ae) +- : Start(std::move(S)), End(std::move(E)), Name(std::move(N)), Metadata(), +- AsyncEvent(Ae) { +- Metadata.Detail = std::move(Dt); +- } +- +- TimeTraceProfilerEntry(TimePointType &&S, TimePointType &&E, std::string &&N, +- TimeTraceMetadata &&Mt, bool Ae) + : Start(std::move(S)), End(std::move(E)), Name(std::move(N)), +- Metadata(std::move(Mt)), AsyncEvent(Ae) {} ++ Detail(std::move(Dt)), AsyncEvent(Ae) {} + + // Calculate timings for FlameGraph. Cast time points to microsecond precision + // rather than casting duration. This avoids truncation issues causing inner +@@ -105,12 +97,10 @@ + }; -@@ -7029,7 +7029,7 @@ - })) - continue; - VecValuesToIgnore.insert(Op); -- DeadInterleavePointerOps.insert(Op->op_begin(), Op->op_end()); -+ DeadInterleavePointerOps.append(Op->op_begin(), Op->op_end()); + struct llvm::TimeTraceProfiler { +- TimeTraceProfiler(unsigned TimeTraceGranularity = 0, StringRef ProcName = "", +- bool TimeTraceVerbose = false) ++ TimeTraceProfiler(unsigned TimeTraceGranularity = 0, StringRef ProcName = "") + : BeginningOfTime(system_clock::now()), StartTime(ClockType::now()), + ProcName(ProcName), Pid(sys::Process::getProcessId()), +- Tid(llvm::get_threadid()), TimeTraceGranularity(TimeTraceGranularity), +- TimeTraceVerbose(TimeTraceVerbose) { ++ Tid(llvm::get_threadid()), TimeTraceGranularity(TimeTraceGranularity) { + llvm::get_thread_name(ThreadName); } - // Ignore type-promoting instructions we identified during reduction -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AArch64/aarch64-mull-masks.ll b/llvm/test/CodeGen/AArch64/aarch64-mull-masks.ll ---- a/llvm/test/CodeGen/AArch64/aarch64-mull-masks.ll -+++ b/llvm/test/CodeGen/AArch64/aarch64-mull-masks.ll -@@ -281,7 +281,8 @@ - ; CHECK-LABEL: smull_ldrsw_shift: - ; CHECK: // %bb.0: // %entry - ; CHECK-NEXT: ldrsw x8, [x0] --; CHECK-NEXT: smull x0, w8, w1 -+; CHECK-NEXT: sxtw x9, w1 -+; CHECK-NEXT: smull x0, w8, w9 - ; CHECK-NEXT: ret - entry: - %ext64 = load i32, ptr %x0 -@@ -489,7 +490,8 @@ - ; CHECK-LABEL: smaddl_ldrsw_shift: - ; CHECK: // %bb.0: // %entry - ; CHECK-NEXT: ldrsw x8, [x0] --; CHECK-NEXT: smaddl x0, w8, w1, x2 -+; CHECK-NEXT: sxtw x9, w1 -+; CHECK-NEXT: smaddl x0, w8, w9, x2 - ; CHECK-NEXT: ret - entry: - %ext64 = load i32, ptr %x0 -@@ -652,7 +654,8 @@ - ; CHECK-LABEL: smnegl_ldrsw_shift: - ; CHECK: // %bb.0: // %entry - ; CHECK-NEXT: ldrsw x8, [x0] --; CHECK-NEXT: smnegl x0, w8, w1 -+; CHECK-NEXT: sxtw x9, w1 -+; CHECK-NEXT: smnegl x0, w8, w9 - ; CHECK-NEXT: ret - entry: - %ext64 = load i32, ptr %x0 -@@ -815,7 +818,8 @@ - ; CHECK-LABEL: smsubl_ldrsw_shift: - ; CHECK: // %bb.0: // %entry - ; CHECK-NEXT: ldrsw x8, [x0] --; CHECK-NEXT: smsubl x0, w8, w1, x2 -+; CHECK-NEXT: sxtw x9, w1 -+; CHECK-NEXT: smsubl x0, w8, w9, x2 - ; CHECK-NEXT: ret - entry: - %ext64 = load i32, ptr %x0 -diff -ruN --strip-trailing-cr a/llvm/test/Transforms/LoopVectorize/X86/interleave-cost.ll b/llvm/test/Transforms/LoopVectorize/X86/interleave-cost.ll ---- a/llvm/test/Transforms/LoopVectorize/X86/interleave-cost.ll -+++ b/llvm/test/Transforms/LoopVectorize/X86/interleave-cost.ll -@@ -182,9 +182,432 @@ - exit: - ret void - } -+ -+define void @geps_feeding_interleave_groups_with_reuse(ptr %arg, i64 %arg1, ptr %arg2) #0 { -+; CHECK-LABEL: define void @geps_feeding_interleave_groups_with_reuse( -+; CHECK-SAME: ptr [[ARG:%.*]], i64 [[ARG1:%.*]], ptr [[ARG2:%.*]]) #[[ATTR0:[0-9]+]] { -+; CHECK-NEXT: [[ENTRY:.*]]: -+; CHECK-NEXT: [[TMP0:%.*]] = add i64 [[ARG1]], 1 -+; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[TMP0]], 30 -+; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label %[[SCALAR_PH:.*]], label %[[VECTOR_SCEVCHECK:.*]] -+; CHECK: [[VECTOR_SCEVCHECK]]: -+; CHECK-NEXT: [[SCEVGEP:%.*]] = getelementptr i8, ptr [[ARG2]], i64 8 -+; CHECK-NEXT: [[MUL:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 16, i64 [[ARG1]]) -+; CHECK-NEXT: [[MUL_RESULT:%.*]] = extractvalue { i64, i1 } [[MUL]], 0 -+; CHECK-NEXT: [[MUL_OVERFLOW:%.*]] = extractvalue { i64, i1 } [[MUL]], 1 -+; CHECK-NEXT: [[TMP1:%.*]] = sub i64 0, [[MUL_RESULT]] -+; CHECK-NEXT: [[TMP2:%.*]] = getelementptr i8, ptr [[SCEVGEP]], i64 [[MUL_RESULT]] -+; CHECK-NEXT: [[TMP3:%.*]] = icmp ult ptr [[TMP2]], [[SCEVGEP]] -+; CHECK-NEXT: [[TMP4:%.*]] = or i1 [[TMP3]], [[MUL_OVERFLOW]] -+; CHECK-NEXT: [[SCEVGEP1:%.*]] = getelementptr i8, ptr [[ARG2]], i64 12 -+; CHECK-NEXT: [[MUL2:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 16, i64 [[ARG1]]) -+; CHECK-NEXT: [[MUL_RESULT3:%.*]] = extractvalue { i64, i1 } [[MUL2]], 0 -+; CHECK-NEXT: [[MUL_OVERFLOW4:%.*]] = extractvalue { i64, i1 } [[MUL2]], 1 -+; CHECK-NEXT: [[TMP5:%.*]] = sub i64 0, [[MUL_RESULT3]] -+; CHECK-NEXT: [[TMP6:%.*]] = getelementptr i8, ptr [[SCEVGEP1]], i64 [[MUL_RESULT3]] -+; CHECK-NEXT: [[TMP7:%.*]] = icmp ult ptr [[TMP6]], [[SCEVGEP1]] -+; CHECK-NEXT: [[TMP8:%.*]] = or i1 [[TMP7]], [[MUL_OVERFLOW4]] -+; CHECK-NEXT: [[SCEVGEP5:%.*]] = getelementptr i8, ptr [[ARG2]], i64 4 -+; CHECK-NEXT: [[MUL6:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 16, i64 [[ARG1]]) -+; CHECK-NEXT: [[MUL_RESULT7:%.*]] = extractvalue { i64, i1 } [[MUL6]], 0 -+; CHECK-NEXT: [[MUL_OVERFLOW8:%.*]] = extractvalue { i64, i1 } [[MUL6]], 1 -+; CHECK-NEXT: [[TMP9:%.*]] = sub i64 0, [[MUL_RESULT7]] -+; CHECK-NEXT: [[TMP10:%.*]] = getelementptr i8, ptr [[SCEVGEP5]], i64 [[MUL_RESULT7]] -+; CHECK-NEXT: [[TMP11:%.*]] = icmp ult ptr [[TMP10]], [[SCEVGEP5]] -+; CHECK-NEXT: [[TMP12:%.*]] = or i1 [[TMP11]], [[MUL_OVERFLOW8]] -+; CHECK-NEXT: [[MUL9:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 16, i64 [[ARG1]]) -+; CHECK-NEXT: [[MUL_RESULT10:%.*]] = extractvalue { i64, i1 } [[MUL9]], 0 -+; CHECK-NEXT: [[MUL_OVERFLOW11:%.*]] = extractvalue { i64, i1 } [[MUL9]], 1 -+; CHECK-NEXT: [[TMP13:%.*]] = sub i64 0, [[MUL_RESULT10]] -+; CHECK-NEXT: [[TMP14:%.*]] = getelementptr i8, ptr [[ARG2]], i64 [[MUL_RESULT10]] -+; CHECK-NEXT: [[TMP15:%.*]] = icmp ult ptr [[TMP14]], [[ARG2]] -+; CHECK-NEXT: [[TMP16:%.*]] = or i1 [[TMP15]], [[MUL_OVERFLOW11]] -+; CHECK-NEXT: [[TMP17:%.*]] = or i1 [[TMP4]], [[TMP8]] -+; CHECK-NEXT: [[TMP18:%.*]] = or i1 [[TMP17]], [[TMP12]] -+; CHECK-NEXT: [[TMP19:%.*]] = or i1 [[TMP18]], [[TMP16]] -+; CHECK-NEXT: br i1 [[TMP19]], label %[[SCALAR_PH]], label %[[VECTOR_MEMCHECK:.*]] -+; CHECK: [[VECTOR_MEMCHECK]]: -+; CHECK-NEXT: [[TMP20:%.*]] = shl i64 [[ARG1]], 4 -+; CHECK-NEXT: [[TMP21:%.*]] = add i64 [[TMP20]], 16 -+; CHECK-NEXT: [[SCEVGEP12:%.*]] = getelementptr i8, ptr [[ARG2]], i64 [[TMP21]] -+; CHECK-NEXT: [[TMP22:%.*]] = shl i64 [[ARG1]], 5 -+; CHECK-NEXT: [[TMP23:%.*]] = add i64 [[TMP22]], 32 -+; CHECK-NEXT: [[SCEVGEP13:%.*]] = getelementptr i8, ptr [[ARG]], i64 [[TMP23]] -+; CHECK-NEXT: [[BOUND0:%.*]] = icmp ult ptr [[ARG2]], [[SCEVGEP13]] -+; CHECK-NEXT: [[BOUND1:%.*]] = icmp ult ptr [[ARG]], [[SCEVGEP12]] -+; CHECK-NEXT: [[FOUND_CONFLICT:%.*]] = and i1 [[BOUND0]], [[BOUND1]] -+; CHECK-NEXT: br i1 [[FOUND_CONFLICT]], label %[[SCALAR_PH]], label %[[VECTOR_PH:.*]] -+; CHECK: [[VECTOR_PH]]: -+; CHECK-NEXT: [[N_MOD_VF:%.*]] = urem i64 [[TMP0]], 2 -+; CHECK-NEXT: [[N_VEC:%.*]] = sub i64 [[TMP0]], [[N_MOD_VF]] -+; CHECK-NEXT: br label %[[VECTOR_BODY:.*]] -+; CHECK: [[VECTOR_BODY]]: -+; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ] -+; CHECK-NEXT: [[TMP24:%.*]] = add i64 [[INDEX]], 0 -+; CHECK-NEXT: [[TMP25:%.*]] = shl i64 [[TMP24]], 5 -+; CHECK-NEXT: [[TMP26:%.*]] = getelementptr i8, ptr [[ARG]], i64 [[TMP25]] -+; CHECK-NEXT: [[TMP27:%.*]] = shl i64 [[TMP24]], 4 -+; CHECK-NEXT: [[TMP28:%.*]] = getelementptr i8, ptr [[ARG2]], i64 [[TMP27]] -+; CHECK-NEXT: [[TMP29:%.*]] = getelementptr float, ptr [[TMP26]], i32 0 -+; CHECK-NEXT: [[WIDE_VEC:%.*]] = load <16 x float>, ptr [[TMP29]], align 4 -+; CHECK-NEXT: [[STRIDED_VEC:%.*]] = shufflevector <16 x float> [[WIDE_VEC]], <16 x float> poison, <2 x i32> -+; CHECK-NEXT: [[STRIDED_VEC14:%.*]] = shufflevector <16 x float> [[WIDE_VEC]], <16 x float> poison, <2 x i32> -+; CHECK-NEXT: [[STRIDED_VEC15:%.*]] = shufflevector <16 x float> [[WIDE_VEC]], <16 x float> poison, <2 x i32> -+; CHECK-NEXT: [[STRIDED_VEC16:%.*]] = shufflevector <16 x float> [[WIDE_VEC]], <16 x float> poison, <2 x i32> -+; CHECK-NEXT: [[STRIDED_VEC17:%.*]] = shufflevector <16 x float> [[WIDE_VEC]], <16 x float> poison, <2 x i32> -+; CHECK-NEXT: [[STRIDED_VEC18:%.*]] = shufflevector <16 x float> [[WIDE_VEC]], <16 x float> poison, <2 x i32> -+; CHECK-NEXT: [[STRIDED_VEC19:%.*]] = shufflevector <16 x float> [[WIDE_VEC]], <16 x float> poison, <2 x i32> -+; CHECK-NEXT: [[STRIDED_VEC20:%.*]] = shufflevector <16 x float> [[WIDE_VEC]], <16 x float> poison, <2 x i32> -+; CHECK-NEXT: [[TMP30:%.*]] = fadd <2 x float> [[STRIDED_VEC]], [[STRIDED_VEC17]] -+; CHECK-NEXT: [[TMP31:%.*]] = fmul <2 x float> [[TMP30]], zeroinitializer -+; CHECK-NEXT: [[TMP32:%.*]] = fadd <2 x float> [[STRIDED_VEC14]], [[STRIDED_VEC18]] -+; CHECK-NEXT: [[TMP33:%.*]] = fmul <2 x float> [[TMP32]], zeroinitializer -+; CHECK-NEXT: [[TMP34:%.*]] = fadd <2 x float> [[STRIDED_VEC15]], [[STRIDED_VEC19]] -+; CHECK-NEXT: [[TMP35:%.*]] = fmul <2 x float> [[TMP34]], zeroinitializer -+; CHECK-NEXT: [[TMP36:%.*]] = fadd <2 x float> [[STRIDED_VEC16]], [[STRIDED_VEC20]] -+; CHECK-NEXT: [[TMP37:%.*]] = fmul <2 x float> [[TMP36]], zeroinitializer -+; CHECK-NEXT: [[TMP38:%.*]] = getelementptr i8, ptr [[TMP28]], i64 12 -+; CHECK-NEXT: [[TMP39:%.*]] = getelementptr float, ptr [[TMP38]], i32 -3 -+; CHECK-NEXT: [[TMP40:%.*]] = shufflevector <2 x float> [[TMP31]], <2 x float> [[TMP33]], <4 x i32> -+; CHECK-NEXT: [[TMP41:%.*]] = shufflevector <2 x float> [[TMP35]], <2 x float> [[TMP37]], <4 x i32> -+; CHECK-NEXT: [[TMP42:%.*]] = shufflevector <4 x float> [[TMP40]], <4 x float> [[TMP41]], <8 x i32> -+; CHECK-NEXT: [[INTERLEAVED_VEC:%.*]] = shufflevector <8 x float> [[TMP42]], <8 x float> poison, <8 x i32> -+; CHECK-NEXT: store <8 x float> [[INTERLEAVED_VEC]], ptr [[TMP39]], align 4 -+; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 2 -+; CHECK-NEXT: [[TMP43:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]] -+; CHECK-NEXT: br i1 [[TMP43]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]] -+; CHECK: [[MIDDLE_BLOCK]]: -+; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[TMP0]], [[N_VEC]] -+; CHECK-NEXT: br i1 [[CMP_N]], label %[[EXIT:.*]], label %[[SCALAR_PH]] -+; CHECK: [[SCALAR_PH]]: -+; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ], [ 0, %[[VECTOR_SCEVCHECK]] ], [ 0, %[[VECTOR_MEMCHECK]] ] -+; CHECK-NEXT: br label %[[LOOP:.*]] -+; CHECK: [[LOOP]]: -+; CHECK-NEXT: [[IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ] -+; CHECK-NEXT: [[SHL_IV_5:%.*]] = shl i64 [[IV]], 5 -+; CHECK-NEXT: [[GEP_1:%.*]] = getelementptr i8, ptr [[ARG]], i64 [[SHL_IV_5]] -+; CHECK-NEXT: [[ADD_5:%.*]] = or disjoint i64 [[SHL_IV_5]], 16 -+; CHECK-NEXT: [[GEP_2:%.*]] = getelementptr i8, ptr [[ARG]], i64 [[ADD_5]] -+; CHECK-NEXT: [[SHL_IV_4:%.*]] = shl i64 [[IV]], 4 -+; CHECK-NEXT: [[GEP_3:%.*]] = getelementptr i8, ptr [[ARG2]], i64 [[SHL_IV_4]] -+; CHECK-NEXT: [[L_1:%.*]] = load float, ptr [[GEP_1]], align 4 -+; CHECK-NEXT: [[L_2:%.*]] = load float, ptr [[GEP_2]], align 4 -+; CHECK-NEXT: [[ADD_1:%.*]] = fadd float [[L_1]], [[L_2]] -+; CHECK-NEXT: [[MUL_1:%.*]] = fmul float [[ADD_1]], 0.000000e+00 -+; CHECK-NEXT: store float [[MUL_1]], ptr [[GEP_3]], align 4 -+; CHECK-NEXT: [[GEP_4:%.*]] = getelementptr i8, ptr [[GEP_1]], i64 4 -+; CHECK-NEXT: [[L_3:%.*]] = load float, ptr [[GEP_4]], align 4 -+; CHECK-NEXT: [[GEP_5:%.*]] = getelementptr i8, ptr [[GEP_2]], i64 4 -+; CHECK-NEXT: [[L_4:%.*]] = load float, ptr [[GEP_5]], align 4 -+; CHECK-NEXT: [[ADD_2:%.*]] = fadd float [[L_3]], [[L_4]] -+; CHECK-NEXT: [[MUL_2:%.*]] = fmul float [[ADD_2]], 0.000000e+00 -+; CHECK-NEXT: [[GEP_6:%.*]] = getelementptr i8, ptr [[GEP_3]], i64 4 -+; CHECK-NEXT: store float [[MUL_2]], ptr [[GEP_6]], align 4 -+; CHECK-NEXT: [[GEP_7:%.*]] = getelementptr i8, ptr [[GEP_1]], i64 8 -+; CHECK-NEXT: [[L_5:%.*]] = load float, ptr [[GEP_7]], align 4 -+; CHECK-NEXT: [[GEP_8:%.*]] = getelementptr i8, ptr [[GEP_2]], i64 8 -+; CHECK-NEXT: [[L_6:%.*]] = load float, ptr [[GEP_8]], align 4 -+; CHECK-NEXT: [[ADD_3:%.*]] = fadd float [[L_5]], [[L_6]] -+; CHECK-NEXT: [[MUL_3:%.*]] = fmul float [[ADD_3]], 0.000000e+00 -+; CHECK-NEXT: [[GEP_9:%.*]] = getelementptr i8, ptr [[GEP_3]], i64 8 -+; CHECK-NEXT: store float [[MUL_3]], ptr [[GEP_9]], align 4 -+; CHECK-NEXT: [[I27:%.*]] = getelementptr i8, ptr [[GEP_1]], i64 12 -+; CHECK-NEXT: [[L_7:%.*]] = load float, ptr [[I27]], align 4 -+; CHECK-NEXT: [[GEP_10:%.*]] = getelementptr i8, ptr [[GEP_2]], i64 12 -+; CHECK-NEXT: [[L_8:%.*]] = load float, ptr [[GEP_10]], align 4 -+; CHECK-NEXT: [[ADD_4:%.*]] = fadd float [[L_7]], [[L_8]] -+; CHECK-NEXT: [[MUL_4:%.*]] = fmul float [[ADD_4]], 0.000000e+00 -+; CHECK-NEXT: [[GEP_11:%.*]] = getelementptr i8, ptr [[GEP_3]], i64 12 -+; CHECK-NEXT: store float [[MUL_4]], ptr [[GEP_11]], align 4 -+; CHECK-NEXT: [[IV_NEXT]] = add i64 [[IV]], 1 -+; CHECK-NEXT: [[EC:%.*]] = icmp eq i64 [[IV]], [[ARG1]] -+; CHECK-NEXT: br i1 [[EC]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP5:![0-9]+]] -+; CHECK: [[EXIT]]: -+; CHECK-NEXT: ret void -+; -+entry: -+ br label %loop -+ -+loop: -+ %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ] -+ %shl.iv.5 = shl i64 %iv, 5 -+ %gep.1 = getelementptr i8, ptr %arg, i64 %shl.iv.5 -+ %add.5 = or disjoint i64 %shl.iv.5, 16 -+ %gep.2 = getelementptr i8, ptr %arg, i64 %add.5 -+ %shl.iv.4 = shl i64 %iv, 4 -+ %gep.3 = getelementptr i8, ptr %arg2, i64 %shl.iv.4 -+ %l.1 = load float, ptr %gep.1, align 4 -+ %l.2 = load float, ptr %gep.2, align 4 -+ %add.1 = fadd float %l.1, %l.2 -+ %mul.1 = fmul float %add.1, 0.000000e+00 -+ store float %mul.1, ptr %gep.3, align 4 -+ %gep.4 = getelementptr i8, ptr %gep.1, i64 4 -+ %l.3 = load float, ptr %gep.4, align 4 -+ %gep.5 = getelementptr i8, ptr %gep.2, i64 4 -+ %l.4 = load float, ptr %gep.5, align 4 -+ %add.2 = fadd float %l.3, %l.4 -+ %mul.2 = fmul float %add.2, 0.000000e+00 -+ %gep.6 = getelementptr i8, ptr %gep.3, i64 4 -+ store float %mul.2, ptr %gep.6, align 4 -+ %gep.7 = getelementptr i8, ptr %gep.1, i64 8 -+ %l.5 = load float, ptr %gep.7, align 4 -+ %gep.8 = getelementptr i8, ptr %gep.2, i64 8 -+ %l.6 = load float, ptr %gep.8, align 4 -+ %add.3 = fadd float %l.5, %l.6 -+ %mul.3 = fmul float %add.3, 0.000000e+00 -+ %gep.9 = getelementptr i8, ptr %gep.3, i64 8 -+ store float %mul.3, ptr %gep.9, align 4 -+ %i27 = getelementptr i8, ptr %gep.1, i64 12 -+ %l.7 = load float, ptr %i27, align 4 -+ %gep.10 = getelementptr i8, ptr %gep.2, i64 12 -+ %l.8 = load float, ptr %gep.10, align 4 -+ %add.4 = fadd float %l.7, %l.8 -+ %mul.4 = fmul float %add.4, 0.000000e+00 -+ %gep.11 = getelementptr i8, ptr %gep.3, i64 12 -+ store float %mul.4, ptr %gep.11, align 4 -+ %iv.next = add i64 %iv, 1 -+ %ec = icmp eq i64 %iv, %arg1 -+ br i1 %ec, label %exit, label %loop -+ -+exit: -+ ret void -+} -+ -+define void @geps_feeding_interleave_groups_with_reuse2(ptr %A, ptr %B, i64 %N) #1 { -+; CHECK-LABEL: define void @geps_feeding_interleave_groups_with_reuse2( -+; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]], i64 [[N:%.*]]) #[[ATTR1:[0-9]+]] { -+; CHECK-NEXT: [[ENTRY:.*]]: -+; CHECK-NEXT: [[TMP0:%.*]] = lshr i64 [[N]], 3 -+; CHECK-NEXT: [[TMP1:%.*]] = add nuw nsw i64 [[TMP0]], 1 -+; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ule i64 [[TMP1]], 28 -+; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label %[[SCALAR_PH:.*]], label %[[VECTOR_SCEVCHECK:.*]] -+; CHECK: [[VECTOR_SCEVCHECK]]: -+; CHECK-NEXT: [[TMP2:%.*]] = lshr i64 [[N]], 3 -+; CHECK-NEXT: [[SCEVGEP:%.*]] = getelementptr i8, ptr [[A]], i64 24 -+; CHECK-NEXT: [[MUL:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 32, i64 [[TMP2]]) -+; CHECK-NEXT: [[MUL_RESULT:%.*]] = extractvalue { i64, i1 } [[MUL]], 0 -+; CHECK-NEXT: [[MUL_OVERFLOW:%.*]] = extractvalue { i64, i1 } [[MUL]], 1 -+; CHECK-NEXT: [[TMP3:%.*]] = sub i64 0, [[MUL_RESULT]] -+; CHECK-NEXT: [[TMP4:%.*]] = getelementptr i8, ptr [[SCEVGEP]], i64 [[MUL_RESULT]] -+; CHECK-NEXT: [[TMP5:%.*]] = icmp ult ptr [[TMP4]], [[SCEVGEP]] -+; CHECK-NEXT: [[TMP6:%.*]] = or i1 [[TMP5]], [[MUL_OVERFLOW]] -+; CHECK-NEXT: [[SCEVGEP1:%.*]] = getelementptr i8, ptr [[A]], i64 28 -+; CHECK-NEXT: [[MUL2:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 32, i64 [[TMP2]]) -+; CHECK-NEXT: [[MUL_RESULT3:%.*]] = extractvalue { i64, i1 } [[MUL2]], 0 -+; CHECK-NEXT: [[MUL_OVERFLOW4:%.*]] = extractvalue { i64, i1 } [[MUL2]], 1 -+; CHECK-NEXT: [[TMP7:%.*]] = sub i64 0, [[MUL_RESULT3]] -+; CHECK-NEXT: [[TMP8:%.*]] = getelementptr i8, ptr [[SCEVGEP1]], i64 [[MUL_RESULT3]] -+; CHECK-NEXT: [[TMP9:%.*]] = icmp ult ptr [[TMP8]], [[SCEVGEP1]] -+; CHECK-NEXT: [[TMP10:%.*]] = or i1 [[TMP9]], [[MUL_OVERFLOW4]] -+; CHECK-NEXT: [[SCEVGEP5:%.*]] = getelementptr i8, ptr [[A]], i64 20 -+; CHECK-NEXT: [[MUL6:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 32, i64 [[TMP2]]) -+; CHECK-NEXT: [[MUL_RESULT7:%.*]] = extractvalue { i64, i1 } [[MUL6]], 0 -+; CHECK-NEXT: [[MUL_OVERFLOW8:%.*]] = extractvalue { i64, i1 } [[MUL6]], 1 -+; CHECK-NEXT: [[TMP11:%.*]] = sub i64 0, [[MUL_RESULT7]] -+; CHECK-NEXT: [[TMP12:%.*]] = getelementptr i8, ptr [[SCEVGEP5]], i64 [[MUL_RESULT7]] -+; CHECK-NEXT: [[TMP13:%.*]] = icmp ult ptr [[TMP12]], [[SCEVGEP5]] -+; CHECK-NEXT: [[TMP14:%.*]] = or i1 [[TMP13]], [[MUL_OVERFLOW8]] -+; CHECK-NEXT: [[SCEVGEP9:%.*]] = getelementptr i8, ptr [[A]], i64 16 -+; CHECK-NEXT: [[MUL10:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 32, i64 [[TMP2]]) -+; CHECK-NEXT: [[MUL_RESULT11:%.*]] = extractvalue { i64, i1 } [[MUL10]], 0 -+; CHECK-NEXT: [[MUL_OVERFLOW12:%.*]] = extractvalue { i64, i1 } [[MUL10]], 1 -+; CHECK-NEXT: [[TMP15:%.*]] = sub i64 0, [[MUL_RESULT11]] -+; CHECK-NEXT: [[TMP16:%.*]] = getelementptr i8, ptr [[SCEVGEP9]], i64 [[MUL_RESULT11]] -+; CHECK-NEXT: [[TMP17:%.*]] = icmp ult ptr [[TMP16]], [[SCEVGEP9]] -+; CHECK-NEXT: [[TMP18:%.*]] = or i1 [[TMP17]], [[MUL_OVERFLOW12]] -+; CHECK-NEXT: [[SCEVGEP13:%.*]] = getelementptr i8, ptr [[A]], i64 12 -+; CHECK-NEXT: [[MUL14:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 32, i64 [[TMP2]]) -+; CHECK-NEXT: [[MUL_RESULT15:%.*]] = extractvalue { i64, i1 } [[MUL14]], 0 -+; CHECK-NEXT: [[MUL_OVERFLOW16:%.*]] = extractvalue { i64, i1 } [[MUL14]], 1 -+; CHECK-NEXT: [[TMP19:%.*]] = sub i64 0, [[MUL_RESULT15]] -+; CHECK-NEXT: [[TMP20:%.*]] = getelementptr i8, ptr [[SCEVGEP13]], i64 [[MUL_RESULT15]] -+; CHECK-NEXT: [[TMP21:%.*]] = icmp ult ptr [[TMP20]], [[SCEVGEP13]] -+; CHECK-NEXT: [[TMP22:%.*]] = or i1 [[TMP21]], [[MUL_OVERFLOW16]] -+; CHECK-NEXT: [[SCEVGEP17:%.*]] = getelementptr i8, ptr [[A]], i64 8 -+; CHECK-NEXT: [[MUL18:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 32, i64 [[TMP2]]) -+; CHECK-NEXT: [[MUL_RESULT19:%.*]] = extractvalue { i64, i1 } [[MUL18]], 0 -+; CHECK-NEXT: [[MUL_OVERFLOW20:%.*]] = extractvalue { i64, i1 } [[MUL18]], 1 -+; CHECK-NEXT: [[TMP23:%.*]] = sub i64 0, [[MUL_RESULT19]] -+; CHECK-NEXT: [[TMP24:%.*]] = getelementptr i8, ptr [[SCEVGEP17]], i64 [[MUL_RESULT19]] -+; CHECK-NEXT: [[TMP25:%.*]] = icmp ult ptr [[TMP24]], [[SCEVGEP17]] -+; CHECK-NEXT: [[TMP26:%.*]] = or i1 [[TMP25]], [[MUL_OVERFLOW20]] -+; CHECK-NEXT: [[SCEVGEP21:%.*]] = getelementptr i8, ptr [[A]], i64 4 -+; CHECK-NEXT: [[MUL22:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 32, i64 [[TMP2]]) -+; CHECK-NEXT: [[MUL_RESULT23:%.*]] = extractvalue { i64, i1 } [[MUL22]], 0 -+; CHECK-NEXT: [[MUL_OVERFLOW24:%.*]] = extractvalue { i64, i1 } [[MUL22]], 1 -+; CHECK-NEXT: [[TMP27:%.*]] = sub i64 0, [[MUL_RESULT23]] -+; CHECK-NEXT: [[TMP28:%.*]] = getelementptr i8, ptr [[SCEVGEP21]], i64 [[MUL_RESULT23]] -+; CHECK-NEXT: [[TMP29:%.*]] = icmp ult ptr [[TMP28]], [[SCEVGEP21]] -+; CHECK-NEXT: [[TMP30:%.*]] = or i1 [[TMP29]], [[MUL_OVERFLOW24]] -+; CHECK-NEXT: [[MUL25:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 32, i64 [[TMP2]]) -+; CHECK-NEXT: [[MUL_RESULT26:%.*]] = extractvalue { i64, i1 } [[MUL25]], 0 -+; CHECK-NEXT: [[MUL_OVERFLOW27:%.*]] = extractvalue { i64, i1 } [[MUL25]], 1 -+; CHECK-NEXT: [[TMP31:%.*]] = sub i64 0, [[MUL_RESULT26]] -+; CHECK-NEXT: [[TMP32:%.*]] = getelementptr i8, ptr [[A]], i64 [[MUL_RESULT26]] -+; CHECK-NEXT: [[TMP33:%.*]] = icmp ult ptr [[TMP32]], [[A]] -+; CHECK-NEXT: [[TMP34:%.*]] = or i1 [[TMP33]], [[MUL_OVERFLOW27]] -+; CHECK-NEXT: [[TMP35:%.*]] = or i1 [[TMP6]], [[TMP10]] -+; CHECK-NEXT: [[TMP36:%.*]] = or i1 [[TMP35]], [[TMP14]] -+; CHECK-NEXT: [[TMP37:%.*]] = or i1 [[TMP36]], [[TMP18]] -+; CHECK-NEXT: [[TMP38:%.*]] = or i1 [[TMP37]], [[TMP22]] -+; CHECK-NEXT: [[TMP39:%.*]] = or i1 [[TMP38]], [[TMP26]] -+; CHECK-NEXT: [[TMP40:%.*]] = or i1 [[TMP39]], [[TMP30]] -+; CHECK-NEXT: [[TMP41:%.*]] = or i1 [[TMP40]], [[TMP34]] -+; CHECK-NEXT: br i1 [[TMP41]], label %[[SCALAR_PH]], label %[[VECTOR_MEMCHECK:.*]] -+; CHECK: [[VECTOR_MEMCHECK]]: -+; CHECK-NEXT: [[TMP42:%.*]] = lshr i64 [[N]], 3 -+; CHECK-NEXT: [[TMP43:%.*]] = shl i64 [[TMP42]], 5 -+; CHECK-NEXT: [[TMP44:%.*]] = add i64 [[TMP43]], 32 -+; CHECK-NEXT: [[SCEVGEP28:%.*]] = getelementptr i8, ptr [[A]], i64 [[TMP44]] -+; CHECK-NEXT: [[TMP45:%.*]] = add nuw nsw i64 [[TMP43]], 4 -+; CHECK-NEXT: [[SCEVGEP29:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP45]] -+; CHECK-NEXT: [[TMP46:%.*]] = shl i64 [[TMP42]], 4 -+; CHECK-NEXT: [[TMP47:%.*]] = add nuw nsw i64 [[TMP46]], 8 -+; CHECK-NEXT: [[SCEVGEP30:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP47]] -+; CHECK-NEXT: [[BOUND0:%.*]] = icmp ult ptr [[A]], [[SCEVGEP29]] -+; CHECK-NEXT: [[BOUND1:%.*]] = icmp ult ptr [[B]], [[SCEVGEP28]] -+; CHECK-NEXT: [[FOUND_CONFLICT:%.*]] = and i1 [[BOUND0]], [[BOUND1]] -+; CHECK-NEXT: [[BOUND031:%.*]] = icmp ult ptr [[A]], [[SCEVGEP30]] -+; CHECK-NEXT: [[BOUND132:%.*]] = icmp ult ptr [[B]], [[SCEVGEP28]] -+; CHECK-NEXT: [[FOUND_CONFLICT33:%.*]] = and i1 [[BOUND031]], [[BOUND132]] -+; CHECK-NEXT: [[CONFLICT_RDX:%.*]] = or i1 [[FOUND_CONFLICT]], [[FOUND_CONFLICT33]] -+; CHECK-NEXT: br i1 [[CONFLICT_RDX]], label %[[SCALAR_PH]], label %[[VECTOR_PH:.*]] -+; CHECK: [[VECTOR_PH]]: -+; CHECK-NEXT: [[N_MOD_VF:%.*]] = urem i64 [[TMP1]], 4 -+; CHECK-NEXT: [[TMP48:%.*]] = icmp eq i64 [[N_MOD_VF]], 0 -+; CHECK-NEXT: [[TMP49:%.*]] = select i1 [[TMP48]], i64 4, i64 [[N_MOD_VF]] -+; CHECK-NEXT: [[N_VEC:%.*]] = sub i64 [[TMP1]], [[TMP49]] -+; CHECK-NEXT: [[IND_END:%.*]] = mul i64 [[N_VEC]], 8 -+; CHECK-NEXT: br label %[[VECTOR_BODY:.*]] -+; CHECK: [[VECTOR_BODY]]: -+; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ] -+; CHECK-NEXT: [[VEC_IND:%.*]] = phi <4 x i64> [ , %[[VECTOR_PH]] ], [ [[VEC_IND_NEXT:%.*]], %[[VECTOR_BODY]] ] -+; CHECK-NEXT: [[OFFSET_IDX:%.*]] = mul i64 [[INDEX]], 8 -+; CHECK-NEXT: [[TMP50:%.*]] = add i64 [[OFFSET_IDX]], 0 -+; CHECK-NEXT: [[TMP51:%.*]] = lshr exact i64 [[TMP50]], 1 -+; CHECK-NEXT: [[TMP52:%.*]] = getelementptr i32, ptr [[B]], i64 [[TMP51]] -+; CHECK-NEXT: [[TMP53:%.*]] = getelementptr i32, ptr [[TMP52]], i32 0 -+; CHECK-NEXT: [[WIDE_VEC:%.*]] = load <16 x i32>, ptr [[TMP53]], align 4 -+; CHECK-NEXT: [[STRIDED_VEC:%.*]] = shufflevector <16 x i32> [[WIDE_VEC]], <16 x i32> poison, <4 x i32> -+; CHECK-NEXT: [[STRIDED_VEC34:%.*]] = shufflevector <16 x i32> [[WIDE_VEC]], <16 x i32> poison, <4 x i32> -+; CHECK-NEXT: [[TMP54:%.*]] = getelementptr i32, ptr [[B]], <4 x i64> [[VEC_IND]] -+; CHECK-NEXT: [[WIDE_MASKED_GATHER:%.*]] = call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> [[TMP54]], i32 4, <4 x i1> , <4 x i32> poison), !alias.scope [[META6:![0-9]+]] -+; CHECK-NEXT: [[TMP55:%.*]] = or disjoint i64 [[TMP50]], 7 -+; CHECK-NEXT: [[TMP56:%.*]] = getelementptr i32, ptr [[A]], i64 [[TMP55]] -+; CHECK-NEXT: [[TMP57:%.*]] = getelementptr i32, ptr [[TMP56]], i32 -7 -+; CHECK-NEXT: [[TMP58:%.*]] = shufflevector <4 x i32> [[STRIDED_VEC]], <4 x i32> zeroinitializer, <8 x i32> -+; CHECK-NEXT: [[TMP59:%.*]] = shufflevector <4 x i32> [[STRIDED_VEC34]], <4 x i32> zeroinitializer, <8 x i32> -+; CHECK-NEXT: [[TMP60:%.*]] = shufflevector <4 x i32> [[WIDE_MASKED_GATHER]], <4 x i32> zeroinitializer, <8 x i32> -+; CHECK-NEXT: [[TMP61:%.*]] = shufflevector <8 x i32> [[TMP58]], <8 x i32> [[TMP59]], <16 x i32> -+; CHECK-NEXT: [[TMP62:%.*]] = shufflevector <8 x i32> [[TMP60]], <8 x i32> zeroinitializer, <16 x i32> -+; CHECK-NEXT: [[TMP63:%.*]] = shufflevector <16 x i32> [[TMP61]], <16 x i32> [[TMP62]], <32 x i32> -+; CHECK-NEXT: [[INTERLEAVED_VEC:%.*]] = shufflevector <32 x i32> [[TMP63]], <32 x i32> poison, <32 x i32> -+; CHECK-NEXT: store <32 x i32> [[INTERLEAVED_VEC]], ptr [[TMP57]], align 4 -+; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4 -+; CHECK-NEXT: [[VEC_IND_NEXT]] = add <4 x i64> [[VEC_IND]], -+; CHECK-NEXT: [[TMP64:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]] -+; CHECK-NEXT: br i1 [[TMP64]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP9:![0-9]+]] -+; CHECK: [[MIDDLE_BLOCK]]: -+; CHECK-NEXT: br label %[[SCALAR_PH]] -+; CHECK: [[SCALAR_PH]]: -+; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i64 [ [[IND_END]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ], [ 0, %[[VECTOR_SCEVCHECK]] ], [ 0, %[[VECTOR_MEMCHECK]] ] -+; CHECK-NEXT: br label %[[LOOP:.*]] -+; CHECK: [[LOOP]]: -+; CHECK-NEXT: [[IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[IV_NEXT_7:%.*]], %[[LOOP]] ] -+; CHECK-NEXT: [[SHR_1:%.*]] = lshr exact i64 [[IV]], 1 -+; CHECK-NEXT: [[GEP_B:%.*]] = getelementptr nusw i32, ptr [[B]], i64 [[SHR_1]] -+; CHECK-NEXT: [[L:%.*]] = load i32, ptr [[GEP_B]], align 4 -+; CHECK-NEXT: [[GEP_A:%.*]] = getelementptr i32, ptr [[A]], i64 [[IV]] -+; CHECK-NEXT: store i32 [[L]], ptr [[GEP_A]], align 4 -+; CHECK-NEXT: [[IV_NEXT:%.*]] = or disjoint i64 [[IV]], 1 -+; CHECK-NEXT: [[GEP_A_1:%.*]] = getelementptr i32, ptr [[A]], i64 [[IV_NEXT]] -+; CHECK-NEXT: store i32 0, ptr [[GEP_A_1]], align 4 -+; CHECK-NEXT: [[IV_NEXT_1:%.*]] = or disjoint i64 [[IV]], 2 -+; CHECK-NEXT: [[SHR_2:%.*]] = lshr exact i64 [[IV_NEXT_1]], 1 -+; CHECK-NEXT: [[GEP_B_2:%.*]] = getelementptr i32, ptr [[B]], i64 [[SHR_2]] -+; CHECK-NEXT: [[TMP65:%.*]] = load i32, ptr [[GEP_B_2]], align 4 -+; CHECK-NEXT: [[GEP_A_2:%.*]] = getelementptr i32, ptr [[A]], i64 [[IV_NEXT_1]] -+; CHECK-NEXT: store i32 [[TMP65]], ptr [[GEP_A_2]], align 4 -+; CHECK-NEXT: [[IV_NEXT_2:%.*]] = or disjoint i64 [[IV]], 3 -+; CHECK-NEXT: [[GEP_A_3:%.*]] = getelementptr i32, ptr [[A]], i64 [[IV_NEXT_2]] -+; CHECK-NEXT: store i32 0, ptr [[GEP_A_3]], align 4 -+; CHECK-NEXT: [[IV_NEXT_3:%.*]] = or disjoint i64 [[IV]], 4 -+; CHECK-NEXT: [[GEP_B_4:%.*]] = getelementptr i32, ptr [[B]], i64 [[IV]] -+; CHECK-NEXT: [[TMP66:%.*]] = load i32, ptr [[GEP_B_4]], align 4 -+; CHECK-NEXT: [[GEP_A_4:%.*]] = getelementptr i32, ptr [[A]], i64 [[IV_NEXT_3]] -+; CHECK-NEXT: store i32 [[TMP66]], ptr [[GEP_A_4]], align 4 -+; CHECK-NEXT: [[IV_NEXT_4:%.*]] = or disjoint i64 [[IV]], 5 -+; CHECK-NEXT: [[GEP_A_5:%.*]] = getelementptr i32, ptr [[A]], i64 [[IV_NEXT_4]] -+; CHECK-NEXT: store i32 0, ptr [[GEP_A_5]], align 4 -+; CHECK-NEXT: [[IV_NEXT_5:%.*]] = or disjoint i64 [[IV]], 6 -+; CHECK-NEXT: [[GEP_A_6:%.*]] = getelementptr i32, ptr [[A]], i64 [[IV_NEXT_5]] -+; CHECK-NEXT: store i32 0, ptr [[GEP_A_6]], align 4 -+; CHECK-NEXT: [[IV_NEXT_6:%.*]] = or disjoint i64 [[IV]], 7 -+; CHECK-NEXT: [[GEP_A_7:%.*]] = getelementptr i32, ptr [[A]], i64 [[IV_NEXT_6]] -+; CHECK-NEXT: store i32 0, ptr [[GEP_A_7]], align 4 -+; CHECK-NEXT: [[IV_NEXT_7]] = add nuw nsw i64 [[IV]], 8 -+; CHECK-NEXT: [[EC:%.*]] = icmp eq i64 [[IV]], [[N]] -+; CHECK-NEXT: br i1 [[EC]], label %[[EXIT:.*]], label %[[LOOP]], !llvm.loop [[LOOP10:![0-9]+]] -+; CHECK: [[EXIT]]: -+; CHECK-NEXT: ret void -+; -+entry: -+ br label %loop -+ -+loop: -+ %iv = phi i64 [ 0, %entry ], [ %iv.next.7, %loop ] -+ %shr.1 = lshr exact i64 %iv, 1 -+ %gep.B = getelementptr nusw i32, ptr %B, i64 %shr.1 -+ %l = load i32, ptr %gep.B, align 4 -+ %gep.A = getelementptr i32, ptr %A, i64 %iv -+ store i32 %l, ptr %gep.A, align 4 -+ %iv.next = or disjoint i64 %iv, 1 -+ %gep.A.1 = getelementptr i32, ptr %A, i64 %iv.next -+ store i32 0, ptr %gep.A.1, align 4 -+ %iv.next.1 = or disjoint i64 %iv, 2 -+ %shr.2 = lshr exact i64 %iv.next.1, 1 -+ %gep.B.2 = getelementptr i32, ptr %B, i64 %shr.2 -+ %1 = load i32, ptr %gep.B.2, align 4 -+ %gep.A.2 = getelementptr i32, ptr %A, i64 %iv.next.1 -+ store i32 %1, ptr %gep.A.2, align 4 -+ %iv.next.2 = or disjoint i64 %iv, 3 -+ %gep.A.3 = getelementptr i32, ptr %A, i64 %iv.next.2 -+ store i32 0, ptr %gep.A.3, align 4 -+ %iv.next.3 = or disjoint i64 %iv, 4 -+ %gep.B.4 = getelementptr i32, ptr %B, i64 %iv -+ %2 = load i32, ptr %gep.B.4, align 4 -+ %gep.A.4 = getelementptr i32, ptr %A, i64 %iv.next.3 -+ store i32 %2, ptr %gep.A.4, align 4 -+ %iv.next.4 = or disjoint i64 %iv, 5 -+ %gep.A.5 = getelementptr i32, ptr %A, i64 %iv.next.4 -+ store i32 0, ptr %gep.A.5, align 4 -+ %iv.next.5 = or disjoint i64 %iv, 6 -+ %gep.A.6 = getelementptr i32, ptr %A, i64 %iv.next.5 -+ store i32 0, ptr %gep.A.6, align 4 -+ %iv.next.6 = or disjoint i64 %iv, 7 -+ %gep.A.7 = getelementptr i32, ptr %A, i64 %iv.next.6 -+ store i32 0, ptr %gep.A.7, align 4 -+ %iv.next.7 = add nuw nsw i64 %iv, 8 -+ %ec = icmp eq i64 %iv, %N -+ br i1 %ec, label %exit, label %loop -+ -+exit: -+ ret void -+} -+ -+attributes #0 = { "target-features"="+sse4.2" } -+attributes #1 = { "min-legal-vector-width"="0" "target-cpu"="cascadelake" } -+ - ;. - ; CHECK: [[LOOP0]] = distinct !{[[LOOP0]], [[META1:![0-9]+]], [[META2:![0-9]+]]} - ; CHECK: [[META1]] = !{!"llvm.loop.isvectorized", i32 1} - ; CHECK: [[META2]] = !{!"llvm.loop.unroll.runtime.disable"} - ; CHECK: [[LOOP3]] = distinct !{[[LOOP3]], [[META1]]} -+; CHECK: [[LOOP4]] = distinct !{[[LOOP4]], [[META1]], [[META2]]} -+; CHECK: [[LOOP5]] = distinct !{[[LOOP5]], [[META1]]} -+; CHECK: [[META6]] = !{[[META7:![0-9]+]]} -+; CHECK: [[META7]] = distinct !{[[META7]], [[META8:![0-9]+]]} -+; CHECK: [[META8]] = distinct !{[[META8]], !"LVerDomain"} -+; CHECK: [[LOOP9]] = distinct !{[[LOOP9]], [[META1]], [[META2]]} -+; CHECK: [[LOOP10]] = distinct !{[[LOOP10]], [[META1]]} - ;. -diff -ruN --strip-trailing-cr a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp ---- a/llvm/unittests/IR/PatternMatch.cpp -+++ b/llvm/unittests/IR/PatternMatch.cpp -@@ -2235,7 +2235,7 @@ - MutableConstTestTypes; - TYPED_TEST_SUITE(MutableConstTest, MutableConstTestTypes, ); +@@ -123,15 +113,6 @@ + return Stack.back().get(); + } --TYPED_TEST(MutableConstTest, ICmp) { -+TYPED_TEST(MutableConstTest, /* FIXME: UAR bug */ DISABLED_ICmp) { - auto &IRB = PatternMatchTest::IRB; +- TimeTraceProfilerEntry * +- begin(std::string Name, llvm::function_ref Metadata, +- bool AsyncEvent = false) { +- Stack.emplace_back(std::make_unique( +- ClockType::now(), TimePointType(), std::move(Name), Metadata(), +- AsyncEvent)); +- return Stack.back().get(); +- } +- + void end() { + assert(!Stack.empty() && "Must call begin() first"); + end(*Stack.back()); +@@ -203,15 +184,8 @@ + J.attribute("dur", DurUs); + } + J.attribute("name", E.Name); +- if (!E.Metadata.isEmpty()) { +- J.attributeObject("args", [&] { +- if (!E.Metadata.Detail.empty()) +- J.attribute("detail", E.Metadata.Detail); +- if (!E.Metadata.File.empty()) +- J.attribute("file", E.Metadata.File); +- if (E.Metadata.Line > 0) +- J.attribute("line", E.Metadata.Line); +- }); ++ if (!E.Detail.empty()) { ++ J.attributeObject("args", [&] { J.attribute("detail", E.Detail); }); + } + }); + +@@ -333,25 +307,14 @@ + + // Minimum time granularity (in microseconds) + const unsigned TimeTraceGranularity; +- +- // Make time trace capture verbose event details (e.g. source filenames). This +- // can increase the size of the output by 2-3 times. +- const bool TimeTraceVerbose; + }; - typedef std::tuple_element_t<0, TypeParam> ValueType; -@@ -2319,7 +2319,7 @@ - .match((InstructionType)IRB.CreateICmp(Pred, L, R))); +-bool llvm::isTimeTraceVerbose() { +- return getTimeTraceProfilerInstance() && +- getTimeTraceProfilerInstance()->TimeTraceVerbose; +-} +- + void llvm::timeTraceProfilerInitialize(unsigned TimeTraceGranularity, +- StringRef ProcName, +- bool TimeTraceVerbose) { ++ StringRef ProcName) { + assert(TimeTraceProfilerInstance == nullptr && + "Profiler should not be initialized"); + TimeTraceProfilerInstance = new TimeTraceProfiler( +- TimeTraceGranularity, llvm::sys::path::filename(ProcName), +- TimeTraceVerbose); ++ TimeTraceGranularity, llvm::sys::path::filename(ProcName)); } --TYPED_TEST(MutableConstTest, FCmp) { -+TYPED_TEST(MutableConstTest, /* FIXME: UAR bug */ DISABLED_FCmp) { - auto &IRB = PatternMatchTest::IRB; - - typedef std::tuple_element_t<0, TypeParam> ValueType; -diff -ruN --strip-trailing-cr a/llvm/utils/gn/secondary/clang/lib/Basic/BUILD.gn b/llvm/utils/gn/secondary/clang/lib/Basic/BUILD.gn ---- a/llvm/utils/gn/secondary/clang/lib/Basic/BUILD.gn -+++ b/llvm/utils/gn/secondary/clang/lib/Basic/BUILD.gn -@@ -108,6 +108,7 @@ - "Targets/DirectX.cpp", - "Targets/Hexagon.cpp", - "Targets/Lanai.cpp", -+ "Targets/Le64.cpp", - "Targets/LoongArch.cpp", - "Targets/M68k.cpp", - "Targets/MSP430.cpp", -diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel ---- a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel -+++ b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel -@@ -254,6 +254,7 @@ - hdrs = ["src/__support/macros/optimization.h"], - deps = [ - ":__support_macros_attributes", -+ ":__support_macros_config", - ":__support_macros_properties_compiler", - ], - ) -@@ -261,6 +262,9 @@ - libc_support_library( - name = "__support_macros_sanitizer", - hdrs = ["src/__support/macros/sanitizer.h"], -+ deps = [ -+ ":__support_macros_config", -+ ], - ) - - libc_support_library( -@@ -271,6 +275,7 @@ - ], - deps = [ - ":__support_macros_attributes", -+ ":__support_macros_config", - ":__support_macros_properties_architectures", - ], - ) -@@ -280,6 +285,7 @@ - hdrs = ["src/__support/CPP/algorithm.h"], - deps = [ - ":__support_macros_attributes", -+ ":__support_macros_config", - ], - ) - -@@ -317,6 +323,7 @@ - hdrs = ["src/__support/CPP/bitset.h"], - deps = [ - ":__support_macros_attributes", -+ ":__support_macros_config", - ], - ) - -@@ -334,6 +341,7 @@ - hdrs = ["src/__support/CPP/expected.h"], - deps = [ - ":__support_macros_attributes", -+ ":__support_macros_config", - ], - ) - -@@ -424,6 +432,7 @@ - ], - deps = [ - ":__support_macros_attributes", -+ ":__support_macros_config", - ":__support_macros_properties_types", - ":llvm_libc_macros_stdfix_macros", - ], -@@ -573,7 +582,10 @@ - libc_support_library( - name = "__support_str_to_num_result", - hdrs = ["src/__support/str_to_num_result.h"], -- deps = [":__support_macros_attributes"], -+ deps = [ -+ ":__support_macros_attributes", -+ ":__support_macros_config", -+ ], - ) - - libc_support_library( -@@ -612,7 +624,10 @@ - libc_support_library( - name = "__support_ctype_utils", - hdrs = ["src/__support/ctype_utils.h"], -- deps = [":__support_macros_attributes"], -+ deps = [ -+ ":__support_macros_attributes", -+ ":__support_macros_config", -+ ], - ) - - libc_support_library( -@@ -785,6 +800,7 @@ - hdrs = ["src/__support/FPUtil/rounding_mode.h"], - deps = [ - ":__support_macros_attributes", -+ ":__support_macros_config", - ":hdr_fenv_macros", - ], - ) -@@ -1126,6 +1142,7 @@ - hdrs = ["src/__support/threads/sleep.h"], - deps = [ - ":__support_macros_attributes", -+ ":__support_macros_config", - ], - ) - -@@ -3408,9 +3425,9 @@ - ":__support_arg_list", - ":__support_file_file", - ":__support_macros_attributes", -- ":types_FILE", - ":printf_main", - ":printf_writer", -+ ":types_FILE", - ], - ) - -diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/libc_build_rules.bzl b/utils/bazel/llvm-project-overlay/libc/libc_build_rules.bzl ---- a/utils/bazel/llvm-project-overlay/libc/libc_build_rules.bzl -+++ b/utils/bazel/llvm-project-overlay/libc/libc_build_rules.bzl -@@ -43,7 +43,7 @@ - name = name, - copts = copts + libc_common_copts(), - local_defines = local_defines + LIBC_CONFIGURE_OPTIONS, -- deps = deps + ["//libc:__support_macros_config"], -+ deps = deps, - linkstatic = 1, - **kwargs - ) -diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/test/src/math/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/test/src/math/BUILD.bazel ---- a/utils/bazel/llvm-project-overlay/libc/test/src/math/BUILD.bazel -+++ b/utils/bazel/llvm-project-overlay/libc/test/src/math/BUILD.bazel -@@ -298,8 +298,8 @@ - "//libc:__support_fputil_fp_bits", - "//libc:__support_fputil_manipulation_functions", - "//libc:hdr_math_macros", -- "//libc/test/UnitTest:fp_test_helpers", - "//libc/test/UnitTest:LibcUnitTest", -+ "//libc/test/UnitTest:fp_test_helpers", - ], - ) - -@@ -559,7 +559,10 @@ - libc_support_library( - name = "sdcomp26094", - hdrs = ["sdcomp26094.h"], -- deps = ["//libc:__support_cpp_array"], -+ deps = [ -+ "//libc:__support_cpp_array", -+ "//libc:__support_macros_config", -+ ], - ) - - math_test( -diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/test/src/string/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/test/src/string/BUILD.bazel ---- a/utils/bazel/llvm-project-overlay/libc/test/src/string/BUILD.bazel -+++ b/utils/bazel/llvm-project-overlay/libc/test/src/string/BUILD.bazel -@@ -121,6 +121,7 @@ - deps = [ - "//libc:__support_cpp_span", - "//libc:__support_libc_assert", -+ "//libc:__support_macros_config", - "//libc:__support_macros_sanitizer", - "//libc:string_memory_utils", - ], -diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/test/UnitTest/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/test/UnitTest/BUILD.bazel ---- a/utils/bazel/llvm-project-overlay/libc/test/UnitTest/BUILD.bazel -+++ b/utils/bazel/llvm-project-overlay/libc/test/UnitTest/BUILD.bazel -@@ -18,6 +18,7 @@ - "//libc:__support_big_int", - "//libc:__support_cpp_string", - "//libc:__support_cpp_string_view", -+ "//libc:__support_macros_config", - "//libc:__support_macros_properties_types", - "//libc:__support_osutil_io", - "//libc:__support_uint128", -@@ -52,6 +53,7 @@ - "//libc:__support_fputil_fp_bits", - "//libc:__support_fputil_fpbits_str", - "//libc:__support_fputil_rounding_mode", -+ "//libc:__support_macros_config", - "//libc:__support_macros_properties_architectures", - "//libc:__support_macros_properties_types", - "//libc:__support_stringutil", -@@ -89,10 +91,11 @@ - "//libc:__support_fputil_fp_bits", - "//libc:__support_fputil_fpbits_str", - "//libc:__support_fputil_rounding_mode", -+ "//libc:__support_macros_config", - "//libc:__support_macros_properties_architectures", -+ "//libc:hdr_fenv_macros", - "//libc:hdr_math_macros", -- "//libc:hdr_fenv_macros", -- "//libc:types_fenv_t", -+ "//libc:types_fenv_t", - ], - ) - -@@ -110,6 +113,7 @@ - "//libc:__support_cpp_bitset", - "//libc:__support_cpp_span", - "//libc:__support_cpp_type_traits", -+ "//libc:__support_macros_config", - ], - ) + // Removes all TimeTraceProfilerInstances. +@@ -418,14 +381,6 @@ + return nullptr; + } -@@ -125,6 +129,7 @@ - ":LibcUnitTest", - ":string_utils", - "//libc:__support_fputil_fp_bits", -+ "//libc:__support_macros_config", - "//libc:printf_core_structs", - ], - ) -@@ -138,5 +143,6 @@ - "//libc:__support_big_int", - "//libc:__support_cpp_string", - "//libc:__support_cpp_type_traits", -+ "//libc:__support_macros_config", - ], - ) -diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/utils/MPFRWrapper/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/utils/MPFRWrapper/BUILD.bazel ---- a/utils/bazel/llvm-project-overlay/libc/utils/MPFRWrapper/BUILD.bazel -+++ b/utils/bazel/llvm-project-overlay/libc/utils/MPFRWrapper/BUILD.bazel -@@ -48,6 +48,7 @@ - "//libc:__support_cpp_type_traits", - "//libc:__support_fputil_fp_bits", - "//libc:__support_fputil_fpbits_str", -+ "//libc:__support_macros_config", - "//libc:__support_macros_properties_types", - "//libc:hdr_math_macros", - "//libc/test/UnitTest:LibcUnitTest", -diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel ---- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel -+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel -@@ -2900,6 +2900,7 @@ - ":IR", - ":LoopLikeInterface", - ":SCFDialect", -+ ":SCFToControlFlow", - ":SCFTransformOpsIncGen", - ":SCFTransforms", - ":SCFUtils", +-TimeTraceProfilerEntry * +-llvm::timeTraceProfilerBegin(StringRef Name, +- llvm::function_ref Metadata) { +- if (TimeTraceProfilerInstance != nullptr) +- return TimeTraceProfilerInstance->begin(std::string(Name), Metadata, false); +- return nullptr; +-} +- + TimeTraceProfilerEntry *llvm::timeTraceAsyncProfilerBegin(StringRef Name, + StringRef Detail) { + if (TimeTraceProfilerInstance != nullptr) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index a108e965dd0086..2949b73a155af1 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "acc159aea1e641e3694ab8fe5faa231788077011" - LLVM_SHA256 = "ff2d0c2d9dd22eb39b3d135bcf0cf91008b395de797f543e32790df372945d13" + LLVM_COMMIT = "84658fb82b67fc22ecba1560d0cddd09f9104178" + LLVM_SHA256 = "b4a50d36a8ab0284f7022f61bbf07a2fb3ea25c6bb2cc422d2418c23b61366da" tf_http_archive( name = name, diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index ef740f479ad0f4..ff323785844790 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "05a83632728cbdf172bb92e3fd644487b74275f6" - SHARDY_SHA256 = "d89ae97cdfdbc5a192b90e7028a3b06873d2a8db5ffb092c2cd0bd4e30b29806" + SHARDY_COMMIT = "d889df1c54b8cd02d90a44aff7bd485340b4774d" + SHARDY_SHA256 = "5a6a83cbae22dfe0940825da944d48ef0968ff7f74ee38ea2d32b19443a10d8c" tf_http_archive( name = "shardy", diff --git a/third_party/tsl/third_party/llvm/generated.patch b/third_party/tsl/third_party/llvm/generated.patch index 7af6db90fd2b4f..c7f7475c35588c 100644 --- a/third_party/tsl/third_party/llvm/generated.patch +++ b/third_party/tsl/third_party/llvm/generated.patch @@ -2,1355 +2,637 @@ Auto generated patch. Do not edit or delete it, even if empty. diff -ruN --strip-trailing-cr a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst --- a/clang/docs/ReleaseNotes.rst +++ b/clang/docs/ReleaseNotes.rst -@@ -40,8 +40,6 @@ - - Setting the deprecated CMake variable ``GCC_INSTALL_PREFIX`` (which sets the - default ``--gcc-toolchain=``) now leads to a fatal error. +@@ -750,9 +750,6 @@ + - Clang now specifies that using ``auto`` in a lambda parameter is a C++14 extension when + appropriate. (`#46059: `_). --- The ``le32`` and ``le64`` targets have been removed. +-- Clang now adds source file infomation for template instantiations as ``event["args"]["filename"]``. This +- added behind an option ``-ftime-trace-verbose``. This is expected to increase the size of trace by 2-3 times. - - C/C++ Language Potentially Breaking Changes - ------------------------------------------- - -diff -ruN --strip-trailing-cr a/clang/docs/tools/clang-formatted-files.txt b/clang/docs/tools/clang-formatted-files.txt ---- a/clang/docs/tools/clang-formatted-files.txt -+++ b/clang/docs/tools/clang-formatted-files.txt -@@ -362,6 +362,7 @@ - clang/lib/Basic/Targets/BPF.h - clang/lib/Basic/Targets/Hexagon.h - clang/lib/Basic/Targets/Lanai.h -+clang/lib/Basic/Targets/Le64.h - clang/lib/Basic/Targets/M68k.h - clang/lib/Basic/Targets/MSP430.h - clang/lib/Basic/Targets/NVPTX.cpp -diff -ruN --strip-trailing-cr a/clang/lib/Basic/CMakeLists.txt b/clang/lib/Basic/CMakeLists.txt ---- a/clang/lib/Basic/CMakeLists.txt -+++ b/clang/lib/Basic/CMakeLists.txt -@@ -102,6 +102,7 @@ - Targets/DirectX.cpp - Targets/Hexagon.cpp - Targets/Lanai.cpp -+ Targets/Le64.cpp - Targets/LoongArch.cpp - Targets/M68k.cpp - Targets/MSP430.cpp -diff -ruN --strip-trailing-cr a/clang/lib/Basic/Targets/Le64.cpp b/clang/lib/Basic/Targets/Le64.cpp ---- a/clang/lib/Basic/Targets/Le64.cpp -+++ b/clang/lib/Basic/Targets/Le64.cpp -@@ -0,0 +1,30 @@ -+//===--- Le64.cpp - Implement Le64 target feature support -----------------===// -+// -+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -+// See https://llvm.org/LICENSE.txt for license information. -+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -+// -+//===----------------------------------------------------------------------===// -+// -+// This file implements Le64 TargetInfo objects. -+// -+//===----------------------------------------------------------------------===// -+ -+#include "Le64.h" -+#include "Targets.h" -+#include "clang/Basic/Builtins.h" -+#include "clang/Basic/MacroBuilder.h" -+#include "clang/Basic/TargetBuiltins.h" -+ -+using namespace clang; -+using namespace clang::targets; -+ -+ArrayRef Le64TargetInfo::getTargetBuiltins() const { -+ return {}; -+} -+ -+void Le64TargetInfo::getTargetDefines(const LangOptions &Opts, -+ MacroBuilder &Builder) const { -+ DefineStd(Builder, "unix", Opts); -+ defineCPUMacros(Builder, "le64", /*Tuning=*/false); -+} -diff -ruN --strip-trailing-cr a/clang/lib/Basic/Targets/Le64.h b/clang/lib/Basic/Targets/Le64.h ---- a/clang/lib/Basic/Targets/Le64.h -+++ b/clang/lib/Basic/Targets/Le64.h -@@ -0,0 +1,64 @@ -+//===--- Le64.h - Declare Le64 target feature support -----------*- C++ -*-===// -+// -+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -+// See https://llvm.org/LICENSE.txt for license information. -+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -+// -+//===----------------------------------------------------------------------===// -+// -+// This file declares Le64 TargetInfo objects. -+// -+//===----------------------------------------------------------------------===// -+ -+#ifndef LLVM_CLANG_LIB_BASIC_TARGETS_LE64_H -+#define LLVM_CLANG_LIB_BASIC_TARGETS_LE64_H -+ -+#include "clang/Basic/TargetInfo.h" -+#include "clang/Basic/TargetOptions.h" -+#include "llvm/Support/Compiler.h" -+#include "llvm/TargetParser/Triple.h" -+ -+namespace clang { -+namespace targets { -+ -+class LLVM_LIBRARY_VISIBILITY Le64TargetInfo : public TargetInfo { -+ -+public: -+ Le64TargetInfo(const llvm::Triple &Triple, const TargetOptions &) -+ : TargetInfo(Triple) { -+ NoAsmVariants = true; -+ LongWidth = LongAlign = PointerWidth = PointerAlign = 64; -+ MaxAtomicPromoteWidth = MaxAtomicInlineWidth = 64; -+ resetDataLayout("e-m:e-v128:32-v16:16-v32:32-v96:32-n8:16:32:64-S128"); -+ } -+ -+ void getTargetDefines(const LangOptions &Opts, -+ MacroBuilder &Builder) const override; -+ -+ ArrayRef getTargetBuiltins() const override; -+ -+ BuiltinVaListKind getBuiltinVaListKind() const override { -+ return TargetInfo::PNaClABIBuiltinVaList; -+ } -+ -+ std::string_view getClobbers() const override { return ""; } -+ -+ ArrayRef getGCCRegNames() const override { -+ return std::nullopt; -+ } -+ -+ ArrayRef getGCCRegAliases() const override { -+ return std::nullopt; -+ } -+ -+ bool validateAsmConstraint(const char *&Name, -+ TargetInfo::ConstraintInfo &Info) const override { -+ return false; -+ } -+ -+ bool hasProtectedVisibility() const override { return false; } -+}; -+ -+} // namespace targets -+} // namespace clang -+#endif // LLVM_CLANG_LIB_BASIC_TARGETS_LE64_H -diff -ruN --strip-trailing-cr a/clang/lib/Basic/Targets/OSTargets.h b/clang/lib/Basic/Targets/OSTargets.h ---- a/clang/lib/Basic/Targets/OSTargets.h -+++ b/clang/lib/Basic/Targets/OSTargets.h -@@ -841,6 +841,9 @@ - "i64:64-i128:128-n8:16:32:64-S128"); - } else if (Triple.getArch() == llvm::Triple::mipsel) { - // Handled on mips' setDataLayout. -+ } else { -+ assert(Triple.getArch() == llvm::Triple::le32); -+ this->resetDataLayout("e-p:32:32-i64:64"); - } - } - }; -diff -ruN --strip-trailing-cr a/clang/lib/Basic/Targets.cpp b/clang/lib/Basic/Targets.cpp ---- a/clang/lib/Basic/Targets.cpp -+++ b/clang/lib/Basic/Targets.cpp -@@ -23,6 +23,7 @@ - #include "Targets/DirectX.h" - #include "Targets/Hexagon.h" - #include "Targets/Lanai.h" -+#include "Targets/Le64.h" - #include "Targets/LoongArch.h" - #include "Targets/M68k.h" - #include "Targets/MSP430.h" -@@ -343,6 +344,17 @@ - return std::make_unique(Triple, Opts); - } - -+ case llvm::Triple::le32: -+ switch (os) { -+ case llvm::Triple::NaCl: -+ return std::make_unique>(Triple, Opts); -+ default: -+ return nullptr; -+ } -+ -+ case llvm::Triple::le64: -+ return std::make_unique(Triple, Opts); -+ - case llvm::Triple::ppc: - switch (os) { - case llvm::Triple::Linux: -diff -ruN --strip-trailing-cr a/clang/lib/CodeGen/CodeGenModule.cpp b/clang/lib/CodeGen/CodeGenModule.cpp ---- a/clang/lib/CodeGen/CodeGenModule.cpp -+++ b/clang/lib/CodeGen/CodeGenModule.cpp -@@ -116,6 +116,8 @@ - default: - return createDefaultTargetCodeGenInfo(CGM); - -+ case llvm::Triple::le32: -+ return createPNaClTargetCodeGenInfo(CGM); - case llvm::Triple::m68k: - return createM68kTargetCodeGenInfo(CGM); - case llvm::Triple::mips: -diff -ruN --strip-trailing-cr a/clang/lib/CodeGen/ItaniumCXXABI.cpp b/clang/lib/CodeGen/ItaniumCXXABI.cpp ---- a/clang/lib/CodeGen/ItaniumCXXABI.cpp -+++ b/clang/lib/CodeGen/ItaniumCXXABI.cpp -@@ -576,6 +576,13 @@ - return new XLCXXABI(CGM); - - case TargetCXXABI::GenericItanium: -+ if (CGM.getContext().getTargetInfo().getTriple().getArch() -+ == llvm::Triple::le32) { -+ // For PNaCl, use ARM-style method pointers so that PNaCl code -+ // does not assume anything about the alignment of function -+ // pointers. -+ return new ItaniumCXXABI(CGM, /*UseARMMethodPtrABI=*/true); -+ } - return new ItaniumCXXABI(CGM); - - case TargetCXXABI::Microsoft: + Improvements to Coverage Mapping + -------------------------------- + +diff -ruN --strip-trailing-cr a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td +--- a/clang/include/clang/Driver/Options.td ++++ b/clang/include/clang/Driver/Options.td +@@ -3998,10 +3998,6 @@ + HelpText<"Minimum time granularity (in microseconds) traced by time profiler">, + Visibility<[ClangOption, CC1Option, CLOption, DXCOption]>, + MarshallingInfoInt, "500u">; +-def ftime_trace_verbose : Joined<["-"], "ftime-trace-verbose">, Group, +- HelpText<"Make time trace capture verbose event details (e.g. source filenames). This can increase the size of the output by 2-3 times">, +- Visibility<[ClangOption, CC1Option, CLOption, DXCOption]>, +- MarshallingInfoFlag>; + def ftime_trace_EQ : Joined<["-"], "ftime-trace=">, Group, + HelpText<"Similar to -ftime-trace. Specify the JSON file or a directory which will contain the JSON file">, + Visibility<[ClangOption, CC1Option, CLOption, DXCOption]>, +diff -ruN --strip-trailing-cr a/clang/include/clang/Frontend/FrontendOptions.h b/clang/include/clang/Frontend/FrontendOptions.h +--- a/clang/include/clang/Frontend/FrontendOptions.h ++++ b/clang/include/clang/Frontend/FrontendOptions.h +@@ -580,11 +580,6 @@ + /// Minimum time granularity (in microseconds) traced by time profiler. + unsigned TimeTraceGranularity; + +- /// Make time trace capture verbose event details (e.g. source filenames). +- /// This can increase the size of the output by 2-3 times. +- LLVM_PREFERRED_TYPE(bool) +- unsigned TimeTraceVerbose : 1; +- + /// Path which stores the output files for -ftime-trace + std::string TimeTracePath; + +@@ -606,8 +601,7 @@ + EmitSymbolGraph(false), EmitExtensionSymbolGraphs(false), + EmitSymbolGraphSymbolLabelsForTesting(false), + EmitPrettySymbolGraphs(false), GenReducedBMI(false), +- UseClangIRPipeline(false), TimeTraceGranularity(500), +- TimeTraceVerbose(false) {} ++ UseClangIRPipeline(false), TimeTraceGranularity(500) {} + + /// getInputKindForExtension - Return the appropriate input kind for a file + /// extension. For example, "c" would return Language::C. diff -ruN --strip-trailing-cr a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp --- a/clang/lib/Driver/ToolChains/Clang.cpp +++ b/clang/lib/Driver/ToolChains/Clang.cpp -@@ -3815,6 +3815,12 @@ - if (UseBuiltins) - A->render(Args, CmdArgs); +@@ -6757,7 +6757,6 @@ + if (const char *Name = C.getTimeTraceFile(&JA)) { + CmdArgs.push_back(Args.MakeArgString("-ftime-trace=" + Twine(Name))); + Args.AddLastArg(CmdArgs, options::OPT_ftime_trace_granularity_EQ); +- Args.AddLastArg(CmdArgs, options::OPT_ftime_trace_verbose); } -+ -+ // le32-specific flags: -+ // -fno-math-builtin: clang should not convert math builtins to intrinsics -+ // by default. -+ if (TC.getArch() == llvm::Triple::le32) -+ CmdArgs.push_back("-fno-math-builtin"); - } - bool Driver::getDefaultModuleCachePath(SmallVectorImpl &Result) { -diff -ruN --strip-trailing-cr a/clang/test/CodeGen/bitfield-access-pad.c b/clang/test/CodeGen/bitfield-access-pad.c ---- a/clang/test/CodeGen/bitfield-access-pad.c -+++ b/clang/test/CodeGen/bitfield-access-pad.c -@@ -16,6 +16,7 @@ - // Configs that have expensive unaligned access - // Little Endian - // RUN: %clang_cc1 -triple=hexagon-elf %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT-T %s -+// RUN: %clang_cc1 -triple=le64-elf %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT-T %s - - // Big endian - // RUN: %clang_cc1 -triple=m68k-elf %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT-T %s -diff -ruN --strip-trailing-cr a/clang/test/CodeGen/bitfield-access-unit.c b/clang/test/CodeGen/bitfield-access-unit.c ---- a/clang/test/CodeGen/bitfield-access-unit.c -+++ b/clang/test/CodeGen/bitfield-access-unit.c -@@ -53,8 +53,8 @@ - // RUN: %clang_cc1 -triple=sparc-elf %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT,LAYOUT-STRICT %s - // RUN: %clang_cc1 -triple=tce-elf %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT,LAYOUT-STRICT %s - --// m68-elf is a strict alignment ISA with 4-byte aligned 64-bit or 2-byte --// aligned 32-bit integer types. This more compex to describe here. -+// Both le64-elf and m68-elf are strict alignment ISAs with 4-byte aligned -+// 64-bit or 2-byte aligned 32-bit integer types. This more compex to describe here. - - // If unaligned access is expensive don't stick these together. - struct A { -diff -ruN --strip-trailing-cr a/clang/test/CodeGenCXX/bitfield-access-empty.cpp b/clang/test/CodeGenCXX/bitfield-access-empty.cpp ---- a/clang/test/CodeGenCXX/bitfield-access-empty.cpp -+++ b/clang/test/CodeGenCXX/bitfield-access-empty.cpp -@@ -26,6 +26,7 @@ - // RUN: %clang_cc1 -triple=bpf %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT %s - // RUN: %clang_cc1 -triple=csky %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT %s - // RUN: %clang_cc1 -triple=hexagon-elf %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT %s -+// RUN: %clang_cc1 -triple=le64-elf %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT %s - // RUN: %clang_cc1 -triple=loongarch32-elf %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT %s - // RUN: %clang_cc1 -triple=nvptx-elf %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT %s - // RUN: %clang_cc1 -triple=riscv32 %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT %s -diff -ruN --strip-trailing-cr a/clang/test/CodeGenCXX/bitfield-access-tail.cpp b/clang/test/CodeGenCXX/bitfield-access-tail.cpp ---- a/clang/test/CodeGenCXX/bitfield-access-tail.cpp -+++ b/clang/test/CodeGenCXX/bitfield-access-tail.cpp -@@ -26,6 +26,7 @@ - // RUN: %clang_cc1 -triple=bpf %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT,LAYOUT64 %s - // RUN: %clang_cc1 -triple=csky %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT,LAYOUT32 %s - // RUN: %clang_cc1 -triple=hexagon-elf %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT,LAYOUT32 %s -+// RUN: %clang_cc1 -triple=le64-elf %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT,LAYOUT64 %s - // RUN: %clang_cc1 -triple=loongarch32-elf %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT,LAYOUT32 %s - // RUN: %clang_cc1 -triple=nvptx-elf %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT,LAYOUT32 %s - // RUN: %clang_cc1 -triple=riscv32 %s -emit-llvm -o /dev/null -fdump-record-layouts-simple | FileCheck --check-prefixes CHECK,LAYOUT,LAYOUT32 %s -diff -ruN --strip-trailing-cr a/clang/test/Preprocessor/predefined-macros-no-warnings.c b/clang/test/Preprocessor/predefined-macros-no-warnings.c ---- a/clang/test/Preprocessor/predefined-macros-no-warnings.c -+++ b/clang/test/Preprocessor/predefined-macros-no-warnings.c -@@ -75,6 +75,8 @@ - // RUN: %clang_cc1 %s -Eonly -Wsystem-headers -Werror -triple m68k - // RUN: %clang_cc1 %s -Eonly -Wsystem-headers -Werror -triple m68k-linux - // RUN: %clang_cc1 %s -Eonly -Wsystem-headers -Werror -triple m68k-netbsd -+// RUN: %clang_cc1 %s -Eonly -Wsystem-headers -Werror -triple le32-nacl -+// RUN: %clang_cc1 %s -Eonly -Wsystem-headers -Werror -triple le64 - // RUN: %clang_cc1 %s -Eonly -Wsystem-headers -Werror -triple ppc - // RUN: %clang_cc1 %s -Eonly -Wsystem-headers -Werror -triple ppc-freebsd - // RUN: %clang_cc1 %s -Eonly -Wsystem-headers -Werror -triple ppc-netbsd -diff -ruN --strip-trailing-cr a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h ---- a/llvm/include/llvm/IR/PatternMatch.h -+++ b/llvm/include/llvm/IR/PatternMatch.h -@@ -1550,27 +1550,23 @@ - template - struct CmpClass_match { -- PredicateTy *Predicate; -+ PredicateTy &Predicate; - LHS_t L; - RHS_t R; - - // The evaluation order is always stable, regardless of Commutability. - // The LHS is always matched first. - CmpClass_match(PredicateTy &Pred, const LHS_t &LHS, const RHS_t &RHS) -- : Predicate(&Pred), L(LHS), R(RHS) {} -- CmpClass_match(const LHS_t &LHS, const RHS_t &RHS) -- : Predicate(nullptr), L(LHS), R(RHS) {} -+ : Predicate(Pred), L(LHS), R(RHS) {} - - template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) { - if (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) { -- if (Predicate) -- *Predicate = I->getPredicate(); -+ Predicate = I->getPredicate(); - return true; - } else if (Commutable && L.match(I->getOperand(1)) && - R.match(I->getOperand(0))) { -- if (Predicate) -- *Predicate = I->getSwappedPredicate(); -+ Predicate = I->getSwappedPredicate(); - return true; - } - } -@@ -1599,19 +1595,22 @@ - template - inline CmpClass_match - m_Cmp(const LHS &L, const RHS &R) { -- return CmpClass_match(L, R); -+ CmpInst::Predicate Unused; -+ return CmpClass_match(Unused, L, R); - } + if (Arg *A = Args.getLastArg(options::OPT_ftrapv_handler_EQ)) { +diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaTemplateInstantiate.cpp b/clang/lib/Sema/SemaTemplateInstantiate.cpp +--- a/clang/lib/Sema/SemaTemplateInstantiate.cpp ++++ b/clang/lib/Sema/SemaTemplateInstantiate.cpp +@@ -3426,16 +3426,11 @@ + return true; + + llvm::TimeTraceScope TimeScope("InstantiateClass", [&]() { +- llvm::TimeTraceMetadata M; +- llvm::raw_string_ostream OS(M.Detail); ++ std::string Name; ++ llvm::raw_string_ostream OS(Name); + Instantiation->getNameForDiagnostic(OS, getPrintingPolicy(), + /*Qualified=*/true); +- if (llvm::isTimeTraceVerbose()) { +- auto Loc = SourceMgr.getExpansionLoc(Instantiation->getLocation()); +- M.File = SourceMgr.getFilename(Loc); +- M.Line = SourceMgr.getExpansionLineNumber(Loc); +- } +- return M; ++ return Name; + }); + + Pattern = PatternDef; +diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp +--- a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp ++++ b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp +@@ -4966,16 +4966,11 @@ + } - template - inline CmpClass_match - m_ICmp(const LHS &L, const RHS &R) { -- return CmpClass_match(L, R); -+ ICmpInst::Predicate Unused; -+ return CmpClass_match(Unused, L, R); + llvm::TimeTraceScope TimeScope("InstantiateFunction", [&]() { +- llvm::TimeTraceMetadata M; +- llvm::raw_string_ostream OS(M.Detail); ++ std::string Name; ++ llvm::raw_string_ostream OS(Name); + Function->getNameForDiagnostic(OS, getPrintingPolicy(), + /*Qualified=*/true); +- if (llvm::isTimeTraceVerbose()) { +- auto Loc = SourceMgr.getExpansionLoc(Function->getLocation()); +- M.File = SourceMgr.getFilename(Loc); +- M.Line = SourceMgr.getExpansionLineNumber(Loc); +- } +- return M; ++ return Name; + }); + + // If we're performing recursive template instantiation, create our own +diff -ruN --strip-trailing-cr a/clang/test/Driver/ftime-trace-sections.cpp b/clang/test/Driver/ftime-trace-sections.cpp +--- a/clang/test/Driver/ftime-trace-sections.cpp ++++ b/clang/test/Driver/ftime-trace-sections.cpp +@@ -1,5 +1,5 @@ + // RUN: rm -rf %t && mkdir %t && cd %t +-// RUN: %clangxx -S -ftime-trace -ftime-trace-granularity=0 -ftime-trace-verbose -o out %s ++// RUN: %clangxx -S -ftime-trace -ftime-trace-granularity=0 -o out %s + // RUN: %python %S/ftime-trace-sections.py < out.json + + template +diff -ruN --strip-trailing-cr a/clang/test/Driver/ftime-trace.cpp b/clang/test/Driver/ftime-trace.cpp +--- a/clang/test/Driver/ftime-trace.cpp ++++ b/clang/test/Driver/ftime-trace.cpp +@@ -1,18 +1,18 @@ + // RUN: rm -rf %t && mkdir -p %t && cd %t +-// RUN: %clangxx -S -no-canonical-prefixes -ftime-trace -ftime-trace-granularity=0 -ftime-trace-verbose -o out %s ++// RUN: %clangxx -S -no-canonical-prefixes -ftime-trace -ftime-trace-granularity=0 -o out %s + // RUN: cat out.json \ + // RUN: | %python -c 'import json, sys; json.dump(json.loads(sys.stdin.read()), sys.stdout, sort_keys=True, indent=2)' \ + // RUN: | FileCheck %s +-// RUN: %clangxx -S -no-canonical-prefixes -ftime-trace=new-name.json -ftime-trace-granularity=0 -ftime-trace-verbose -o out %s ++// RUN: %clangxx -S -no-canonical-prefixes -ftime-trace=new-name.json -ftime-trace-granularity=0 -o out %s + // RUN: cat new-name.json \ + // RUN: | %python -c 'import json, sys; json.dump(json.loads(sys.stdin.read()), sys.stdout, sort_keys=True, indent=2)' \ + // RUN: | FileCheck %s + // RUN: mkdir dir1 dir2 +-// RUN: %clangxx -S -no-canonical-prefixes -ftime-trace=dir1 -ftime-trace-granularity=0 -ftime-trace-verbose -o out %s ++// RUN: %clangxx -S -no-canonical-prefixes -ftime-trace=dir1 -ftime-trace-granularity=0 -o out %s + // RUN: cat dir1/out.json \ + // RUN: | %python -c 'import json, sys; json.dump(json.loads(sys.stdin.read()), sys.stdout, sort_keys=True, indent=2)' \ + // RUN: | FileCheck %s +-// RUN: %clangxx -S -no-canonical-prefixes -ftime-trace=dir2/ -ftime-trace-granularity=0 -ftime-trace-verbose -o out %s ++// RUN: %clangxx -S -no-canonical-prefixes -ftime-trace=dir2/ -ftime-trace-granularity=0 -o out %s + // RUN: cat dir2/out.json \ + // RUN: | %python -c 'import json, sys; json.dump(json.loads(sys.stdin.read()), sys.stdout, sort_keys=True, indent=2)' \ + // RUN: | FileCheck %s +@@ -34,33 +34,32 @@ + // RUN: mkdir d e f && cp %s d/a.cpp && touch d/b.c + + /// TODO: Support -fno-integrated-as. +-// RUN: %clang -### -c -ftime-trace -ftime-trace-granularity=0 -ftime-trace-verbose -fintegrated-as d/a.cpp -o e/a.o 2>&1 | FileCheck %s --check-prefix=COMPILE1 +-// COMPILE1: -cc1{{.*}} "-ftime-trace=e/a.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" ++// RUN: %clang -### -c -ftime-trace -ftime-trace-granularity=0 -fintegrated-as d/a.cpp -o e/a.o 2>&1 | FileCheck %s --check-prefix=COMPILE1 ++// COMPILE1: -cc1{{.*}} "-ftime-trace=e/a.json" "-ftime-trace-granularity=0" + +-// RUN: %clang -### -c -ftime-trace -ftime-trace-granularity=0 -ftime-trace-verbose d/a.cpp d/b.c -dumpdir f/ 2>&1 | FileCheck %s --check-prefix=COMPILE2 +-// COMPILE2: -cc1{{.*}} "-ftime-trace=f/a.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" +-// COMPILE2: -cc1{{.*}} "-ftime-trace=f/b.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" ++// RUN: %clang -### -c -ftime-trace -ftime-trace-granularity=0 d/a.cpp d/b.c -dumpdir f/ 2>&1 | FileCheck %s --check-prefix=COMPILE2 ++// COMPILE2: -cc1{{.*}} "-ftime-trace=f/a.json" "-ftime-trace-granularity=0" ++// COMPILE2: -cc1{{.*}} "-ftime-trace=f/b.json" "-ftime-trace-granularity=0" + + /// -o specifies the link output. Create ${output}-${basename}.json. +-// RUN: %clang -### -ftime-trace -ftime-trace-granularity=0 -ftime-trace-verbose d/a.cpp d/b.c -o e/x 2>&1 | FileCheck %s --check-prefix=LINK1 +-// LINK1: -cc1{{.*}} "-ftime-trace=e/x-a.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" +-// LINK1: -cc1{{.*}} "-ftime-trace=e/x-b.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" ++// RUN: %clang -### -ftime-trace -ftime-trace-granularity=0 d/a.cpp d/b.c -o e/x 2>&1 | FileCheck %s --check-prefix=LINK1 ++// LINK1: -cc1{{.*}} "-ftime-trace=e/x-a.json" "-ftime-trace-granularity=0" ++// LINK1: -cc1{{.*}} "-ftime-trace=e/x-b.json" "-ftime-trace-granularity=0" + + /// -dumpdir is f/g, not ending with a path separator. We create f/g${basename}.json. +-// RUN: %clang -### -ftime-trace -ftime-trace-granularity=0 -ftime-trace-verbose d/a.cpp d/b.c -o e/x -dumpdir f/g 2>&1 | FileCheck %s --check-prefix=LINK2 +-// LINK2: -cc1{{.*}} "-ftime-trace=f/ga.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" +-// LINK2: -cc1{{.*}} "-ftime-trace=f/gb.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" +- +-// RUN: %clang -### -ftime-trace=e -ftime-trace-granularity=0 -ftime-trace-verbose d/a.cpp d/b.c -o f/x -dumpdir f/ 2>&1 | FileCheck %s --check-prefix=LINK3 +-// LINK3: -cc1{{.*}} "-ftime-trace=e{{/|\\\\}}a-{{[^.]*}}.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" +-// LINK3: -cc1{{.*}} "-ftime-trace=e{{/|\\\\}}b-{{[^.]*}}.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" ++// RUN: %clang -### -ftime-trace -ftime-trace-granularity=0 d/a.cpp d/b.c -o e/x -dumpdir f/g 2>&1 | FileCheck %s --check-prefix=LINK2 ++// LINK2: -cc1{{.*}} "-ftime-trace=f/ga.json" "-ftime-trace-granularity=0" ++// LINK2: -cc1{{.*}} "-ftime-trace=f/gb.json" "-ftime-trace-granularity=0" ++ ++// RUN: %clang -### -ftime-trace=e -ftime-trace-granularity=0 d/a.cpp d/b.c -o f/x -dumpdir f/ 2>&1 | FileCheck %s --check-prefix=LINK3 ++// LINK3: -cc1{{.*}} "-ftime-trace=e{{/|\\\\}}a-{{[^.]*}}.json" "-ftime-trace-granularity=0" ++// LINK3: -cc1{{.*}} "-ftime-trace=e{{/|\\\\}}b-{{[^.]*}}.json" "-ftime-trace-granularity=0" + +-// RUN: %clang -### -ftime-trace -ftime-trace=e -ftime-trace-granularity=1 -ftime-trace-verbose -xassembler d/a.cpp 2>&1 | \ ++// RUN: %clang -### -ftime-trace -ftime-trace=e -ftime-trace-granularity=1 -xassembler d/a.cpp 2>&1 | \ + // RUN: FileCheck %s --check-prefix=UNUSED + // UNUSED: warning: argument unused during compilation: '-ftime-trace' + // UNUSED-NEXT: warning: argument unused during compilation: '-ftime-trace=e' + // UNUSED-NEXT: warning: argument unused during compilation: '-ftime-trace-granularity=1' +-// UNUSED-NEXT: warning: argument unused during compilation: '-ftime-trace-verbose' + // UNUSED-NOT: warning: + + template +diff -ruN --strip-trailing-cr a/clang/tools/driver/cc1_main.cpp b/clang/tools/driver/cc1_main.cpp +--- a/clang/tools/driver/cc1_main.cpp ++++ b/clang/tools/driver/cc1_main.cpp +@@ -241,8 +241,7 @@ + + if (!Clang->getFrontendOpts().TimeTracePath.empty()) { + llvm::timeTraceProfilerInitialize( +- Clang->getFrontendOpts().TimeTraceGranularity, Argv0, +- Clang->getFrontendOpts().TimeTraceVerbose); ++ Clang->getFrontendOpts().TimeTraceGranularity, Argv0); + } + // --print-supported-cpus takes priority over the actual compilation. + if (Clang->getFrontendOpts().PrintSupportedCPUs) +diff -ruN --strip-trailing-cr a/clang/unittests/Support/TimeProfilerTest.cpp b/clang/unittests/Support/TimeProfilerTest.cpp +--- a/clang/unittests/Support/TimeProfilerTest.cpp ++++ b/clang/unittests/Support/TimeProfilerTest.cpp +@@ -10,15 +10,11 @@ + #include "clang/Frontend/FrontendActions.h" + #include "clang/Lex/PreprocessorOptions.h" + +-#include "llvm/ADT/StringMap.h" + #include "llvm/Support/JSON.h" +-#include "llvm/Support/Path.h" + #include "llvm/Support/TimeProfiler.h" +-#include "llvm/Support/VirtualFileSystem.h" + #include + + #include "gtest/gtest.h" +-#include + + using namespace clang; + using namespace llvm; +@@ -27,8 +23,7 @@ + + // Should be called before testing. + void setupProfiler() { +- timeTraceProfilerInitialize(/*TimeTraceGranularity=*/0, "test", +- /*TimeTraceVerbose=*/true); ++ timeTraceProfilerInitialize(/*TimeTraceGranularity=*/0, "test"); } - template - inline CmpClass_match - m_FCmp(const LHS &L, const RHS &R) { -- return CmpClass_match(L, R); -+ FCmpInst::Predicate Unused; -+ return CmpClass_match(Unused, L, R); + // Should be called after `compileFromString()`. +@@ -43,24 +38,14 @@ + + // Returns true if code compiles successfully. + // We only parse AST here. This is enough for constexpr evaluation. +-bool compileFromString(StringRef Code, StringRef Standard, StringRef File, +- llvm::StringMap Headers = {}) { ++bool compileFromString(StringRef Code, StringRef Standard, StringRef FileName) { + CompilerInstance Compiler; + Compiler.createDiagnostics(); + +- llvm::IntrusiveRefCntPtr FS( +- new llvm::vfs::InMemoryFileSystem()); +- FS->addFile(File, 0, MemoryBuffer::getMemBuffer(Code)); +- for (const auto &Header : Headers) { +- FS->addFile(Header.getKey(), 0, +- MemoryBuffer::getMemBuffer(Header.getValue())); +- } +- llvm::IntrusiveRefCntPtr Files( +- new FileManager(FileSystemOptions(), FS)); +- Compiler.setFileManager(Files.get()); +- + auto Invocation = std::make_shared(); +- std::vector Args = {Standard.data(), File.data()}; ++ Invocation->getPreprocessorOpts().addRemappedFile( ++ FileName, MemoryBuffer::getMemBuffer(Code).release()); ++ const char *Args[] = {Standard.data(), FileName.data()}; + CompilerInvocation::CreateFromArgs(*Invocation, Args, + Compiler.getDiagnostics()); + Compiler.setInvocation(std::move(Invocation)); +@@ -75,28 +60,13 @@ + return Compiler.ExecuteAction(Action); } - // Same as CmpClass, but instead of saving Pred as out output variable, match a -diff -ruN --strip-trailing-cr a/llvm/include/llvm/TargetParser/Triple.h b/llvm/include/llvm/TargetParser/Triple.h ---- a/llvm/include/llvm/TargetParser/Triple.h -+++ b/llvm/include/llvm/TargetParser/Triple.h -@@ -88,6 +88,8 @@ - xtensa, // Tensilica: Xtensa - nvptx, // NVPTX: 32-bit - nvptx64, // NVPTX: 64-bit -+ le32, // le32: generic little-endian 32-bit CPU (PNaCl) -+ le64, // le64: generic little-endian 64-bit CPU (PNaCl) - amdil, // AMDIL - amdil64, // AMDIL with 64-bit pointers - hsail, // AMD HSAIL -diff -ruN --strip-trailing-cr a/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp b/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp ---- a/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp -+++ b/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp -@@ -128,7 +128,6 @@ - bool visitINSviGPR(MachineInstr &MI, unsigned Opc); - bool visitINSvi64lane(MachineInstr &MI); - bool visitFMOVDr(MachineInstr &MI); -- bool visitCopy(MachineInstr &MI); - bool runOnMachineFunction(MachineFunction &MF) override; +-std::string GetMetadata(json::Object *Event) { +- std::string Metadata; +- llvm::raw_string_ostream OS(Metadata); +- if (json::Object *Args = Event->getObject("args")) { +- if (auto Detail = Args->getString("detail")) +- OS << Detail; +- // Use only filename to not include os-specific path separators. +- if (auto File = Args->getString("file")) +- OS << ", " << llvm::sys::path::filename(*File); +- if (auto Line = Args->getInteger("line")) +- OS << ":" << *Line; +- } +- return Metadata; +-} +- + // Returns pretty-printed trace graph. + std::string buildTraceGraph(StringRef Json) { + struct EventRecord { + int64_t TimestampBegin; + int64_t TimestampEnd; +- std::string Name; +- std::string Metadata; ++ StringRef Name; ++ StringRef Detail; + }; + std::vector Events; + +@@ -111,13 +81,10 @@ + int64_t TimestampBegin = TraceEventObj->getInteger("ts").value_or(0); + int64_t TimestampEnd = + TimestampBegin + TraceEventObj->getInteger("dur").value_or(0); +- std::string Name = TraceEventObj->getString("name").value_or("").str(); +- std::string Metadata = GetMetadata(TraceEventObj); +- +- // Source events are asynchronous events and may not perfectly nest the +- // synchronous events. Skip testing them. +- if (Name == "Source") +- continue; ++ StringRef Name = TraceEventObj->getString("name").value_or(""); ++ StringRef Detail = ""; ++ if (json::Object *Args = TraceEventObj->getObject("args")) ++ Detail = Args->getString("detail").value_or(""); + + // This is a "summary" event, like "Total PerformPendingInstantiations", + // skip it +@@ -125,7 +92,7 @@ + continue; - StringRef getPassName() const override { -@@ -691,34 +690,6 @@ - return true; - } + Events.emplace_back( +- EventRecord{TimestampBegin, TimestampEnd, Name, Metadata}); ++ EventRecord{TimestampBegin, TimestampEnd, Name, Detail}); + } --// Across a basic-block we might have in i32 extract from a value that only --// operates on upper bits (for example a sxtw). We can replace the COPY with a --// new version skipping the sxtw. --bool AArch64MIPeepholeOpt::visitCopy(MachineInstr &MI) { -- Register InputReg = MI.getOperand(1).getReg(); -- if (MI.getOperand(1).getSubReg() != AArch64::sub_32 || -- !MRI->hasOneNonDBGUse(InputReg)) -- return false; + // There can be nested events that are very fast, for example: +@@ -165,9 +132,9 @@ + Stream << "| "; + } + Stream.write(Event.Name.data(), Event.Name.size()); +- if (!Event.Metadata.empty()) { ++ if (!Event.Detail.empty()) { + Stream << " ("; +- Stream.write(Event.Metadata.data(), Event.Metadata.size()); ++ Stream.write(Event.Detail.data(), Event.Detail.size()); + Stream << ")"; + } + Stream << "\n"; +@@ -178,7 +145,7 @@ + } // namespace + + TEST(TimeProfilerTest, ConstantEvaluationCxx20) { +- std::string Code = R"( ++ constexpr StringRef Code = R"( + void print(double value); + + namespace slow_namespace { +@@ -208,7 +175,8 @@ + setupProfiler(); + ASSERT_TRUE(compileFromString(Code, "-std=c++20", "test.cc")); + std::string Json = teardownProfiler(); +- ASSERT_EQ(R"( ++ std::string TraceGraph = buildTraceGraph(Json); ++ ASSERT_TRUE(TraceGraph == R"( + Frontend + | ParseDeclarationOrFunctionDefinition (test.cc:2:1) + | ParseDeclarationOrFunctionDefinition (test.cc:6:1) +@@ -234,54 +202,14 @@ + | ParseDeclarationOrFunctionDefinition (test.cc:25:1) + | | EvaluateAsInitializer (slow_init_list) + | PerformPendingInstantiations +-)", +- buildTraceGraph(Json)); +-} - -- MachineInstr *SrcMI = MRI->getUniqueVRegDef(InputReg); -- MachineInstr *CopyMI = SrcMI; -- while (SrcMI && SrcMI->isFullCopy() && -- MRI->hasOneNonDBGUse(SrcMI->getOperand(1).getReg())) -- SrcMI = MRI->getUniqueVRegDef(SrcMI->getOperand(1).getReg()); +-TEST(TimeProfilerTest, TemplateInstantiations) { +- std::string B_H = R"( +- template +- T fooB(T t) { +- return T(); +- } ++)"); + +- #define MacroTemp(x) template void foo##x(T) { T(); } +- )"; - -- if (!SrcMI || SrcMI->getOpcode() != AArch64::SBFMXri || -- SrcMI->getOperand(2).getImm() != 0 || SrcMI->getOperand(3).getImm() != 31) -- return false; +- std::string A_H = R"( +- #include "b.h" - -- Register SrcReg = SrcMI->getOperand(1).getReg(); -- MRI->constrainRegClass(SrcReg, MRI->getRegClass(InputReg)); -- MI.getOperand(1).setReg(SrcReg); -- if (CopyMI != SrcMI) -- CopyMI->eraseFromParent(); -- SrcMI->eraseFromParent(); -- return true; --} +- MacroTemp(MTA) - - bool AArch64MIPeepholeOpt::runOnMachineFunction(MachineFunction &MF) { - if (skipFunction(MF.getFunction())) - return false; -@@ -800,9 +771,6 @@ - case AArch64::FMOVDr: - Changed |= visitFMOVDr(MI); - break; -- case AArch64::COPY: -- Changed |= visitCopy(MI); -- break; - } - } - } -diff -ruN --strip-trailing-cr a/llvm/lib/Target/AArch64/peephole-sxtw.mir b/llvm/lib/Target/AArch64/peephole-sxtw.mir ---- a/llvm/lib/Target/AArch64/peephole-sxtw.mir -+++ b/llvm/lib/Target/AArch64/peephole-sxtw.mir -@@ -1,46 +0,0 @@ --# NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py --# RUN: llc -run-pass=aarch64-mi-peephole-opt -o - -mtriple=aarch64-unknown-linux -verify-machineinstrs %s | FileCheck %s +- template +- void fooA(T t) { fooB(t); fooMTA(t); } +- )"; +- std::string Code = R"( +- #include "a.h" +- void user() { fooA(0); } +- )"; - ----- --name: removeSxtw --tracksRegLiveness: true --body: | -- bb.0.entry: -- liveins: $x0 -- ; CHECK-LABEL: name: removeSxtw -- ; CHECK: liveins: $x0 -- ; CHECK-NEXT: {{ $}} -- ; CHECK-NEXT: [[COPY:%[0-9]+]]:gpr64 = COPY $x0 -- ; CHECK-NEXT: [[COPY1:%[0-9]+]]:gpr32sp = COPY [[COPY]].sub_32 -- ; CHECK-NEXT: [[ADDWri:%[0-9]+]]:gpr32sp = ADDWri [[COPY1]], 1, 0 -- ; CHECK-NEXT: $w0 = COPY [[ADDWri]] -- ; CHECK-NEXT: RET_ReallyLR implicit $w0 -- %0:gpr64 = COPY $x0 -- %1:gpr64 = SBFMXri %0:gpr64, 0, 31 -- %2:gpr32sp = COPY %1.sub_32:gpr64 -- %3:gpr32sp = ADDWri %2:gpr32sp, 1, 0 -- $w0 = COPY %3:gpr32sp -- RET_ReallyLR implicit $w0 --... ----- --name: extraCopy --tracksRegLiveness: true --body: | -- bb.0.entry: -- liveins: $x0 -- ; CHECK-LABEL: name: extraCopy -- ; CHECK: liveins: $x0 -- ; CHECK-NEXT: {{ $}} -- ; CHECK-NEXT: [[COPY:%[0-9]+]]:gpr64 = COPY $x0 -- ; CHECK-NEXT: [[COPY1:%[0-9]+]]:gpr32sp = COPY [[COPY]].sub_32 -- ; CHECK-NEXT: [[ADDWri:%[0-9]+]]:gpr32sp = ADDWri [[COPY1]], 1, 0 -- ; CHECK-NEXT: $w0 = COPY [[ADDWri]] -- ; CHECK-NEXT: RET_ReallyLR implicit $w0 -- %0:gpr64 = COPY $x0 -- %1:gpr64 = SBFMXri %0:gpr64, 0, 31 -- %2:gpr64all = COPY %1:gpr64 -- %3:gpr32sp = COPY %2.sub_32:gpr64all -- %4:gpr32sp = ADDWri %3:gpr32sp, 1, 0 -- $w0 = COPY %4:gpr32sp -- RET_ReallyLR implicit $w0 --... -diff -ruN --strip-trailing-cr a/llvm/lib/TargetParser/Triple.cpp b/llvm/lib/TargetParser/Triple.cpp ---- a/llvm/lib/TargetParser/Triple.cpp -+++ b/llvm/lib/TargetParser/Triple.cpp -@@ -44,6 +44,8 @@ - case hsail: return "hsail"; - case kalimba: return "kalimba"; - case lanai: return "lanai"; -+ case le32: return "le32"; -+ case le64: return "le64"; - case loongarch32: return "loongarch32"; - case loongarch64: return "loongarch64"; - case m68k: return "m68k"; -@@ -197,6 +199,9 @@ - case nvptx: return "nvvm"; - case nvptx64: return "nvvm"; +- setupProfiler(); +- ASSERT_TRUE(compileFromString(Code, "-std=c++20", "test.cc", +- /*Headers=*/{{"a.h", A_H}, {"b.h", B_H}})); +- std::string Json = teardownProfiler(); +- ASSERT_EQ(R"( +-Frontend +-| ParseFunctionDefinition (fooB) +-| ParseFunctionDefinition (fooMTA) +-| ParseFunctionDefinition (fooA) +-| ParseDeclarationOrFunctionDefinition (test.cc:3:5) +-| | ParseFunctionDefinition (user) +-| PerformPendingInstantiations +-| | InstantiateFunction (fooA, a.h:7) +-| | | InstantiateFunction (fooB, b.h:3) +-| | | InstantiateFunction (fooMTA, a.h:4) +-)", +- buildTraceGraph(Json)); ++ // NOTE: If this test is failing, run this test with ++ // `llvm::errs() << TraceGraph;` and change the assert above. + } -+ case le32: return "le32"; -+ case le64: return "le64"; + TEST(TimeProfilerTest, ConstantEvaluationC99) { +- std::string Code = R"( ++ constexpr StringRef Code = R"( + struct { + short quantval[4]; // 3rd line + } value; +@@ -290,12 +218,15 @@ + setupProfiler(); + ASSERT_TRUE(compileFromString(Code, "-std=c99", "test.c")); + std::string Json = teardownProfiler(); +- ASSERT_EQ(R"( ++ std::string TraceGraph = buildTraceGraph(Json); ++ ASSERT_TRUE(TraceGraph == R"( + Frontend + | ParseDeclarationOrFunctionDefinition (test.c:2:1) + | | isIntegerConstantExpr () + | | EvaluateKnownConstIntCheckOverflow () + | PerformPendingInstantiations +-)", +- buildTraceGraph(Json)); ++)"); + - case amdil: - case amdil64: return "amdil"; - -@@ -427,6 +432,8 @@ - .Case("xcore", xcore) - .Case("nvptx", nvptx) - .Case("nvptx64", nvptx64) -+ .Case("le32", le32) -+ .Case("le64", le64) - .Case("amdil", amdil) - .Case("amdil64", amdil64) - .Case("hsail", hsail) -@@ -567,6 +574,8 @@ - .Case("xcore", Triple::xcore) - .Case("nvptx", Triple::nvptx) - .Case("nvptx64", Triple::nvptx64) -+ .Case("le32", Triple::le32) -+ .Case("le64", Triple::le64) - .Case("amdil", Triple::amdil) - .Case("amdil64", Triple::amdil64) - .Case("hsail", Triple::hsail) -@@ -896,6 +905,8 @@ - case Triple::hsail: - case Triple::kalimba: - case Triple::lanai: -+ case Triple::le32: -+ case Triple::le64: - case Triple::loongarch32: - case Triple::loongarch64: - case Triple::m68k: -@@ -1592,6 +1603,7 @@ - case llvm::Triple::hsail: - case llvm::Triple::kalimba: - case llvm::Triple::lanai: -+ case llvm::Triple::le32: - case llvm::Triple::loongarch32: - case llvm::Triple::m68k: - case llvm::Triple::mips: -@@ -1624,6 +1636,7 @@ - case llvm::Triple::bpfeb: - case llvm::Triple::bpfel: - case llvm::Triple::hsail64: -+ case llvm::Triple::le64: - case llvm::Triple::loongarch64: - case llvm::Triple::mips64: - case llvm::Triple::mips64el: -@@ -1682,6 +1695,7 @@ - case Triple::hsail: - case Triple::kalimba: - case Triple::lanai: -+ case Triple::le32: - case Triple::loongarch32: - case Triple::m68k: - case Triple::mips: -@@ -1712,6 +1726,7 @@ - case Triple::aarch64_be: T.setArch(Triple::armeb); break; - case Triple::amdil64: T.setArch(Triple::amdil); break; - case Triple::hsail64: T.setArch(Triple::hsail); break; -+ case Triple::le64: T.setArch(Triple::le32); break; - case Triple::loongarch64: T.setArch(Triple::loongarch32); break; - case Triple::mips64: - T.setArch(Triple::mips, getSubArch()); -@@ -1766,6 +1781,7 @@ - case Triple::bpfeb: - case Triple::bpfel: - case Triple::hsail64: -+ case Triple::le64: - case Triple::loongarch64: - case Triple::mips64: - case Triple::mips64el: -@@ -1789,6 +1805,7 @@ - case Triple::arm: T.setArch(Triple::aarch64); break; - case Triple::armeb: T.setArch(Triple::aarch64_be); break; - case Triple::hsail: T.setArch(Triple::hsail64); break; -+ case Triple::le32: T.setArch(Triple::le64); break; - case Triple::loongarch32: T.setArch(Triple::loongarch64); break; - case Triple::mips: - T.setArch(Triple::mips64, getSubArch()); -@@ -1831,6 +1848,8 @@ - case Triple::hsail64: - case Triple::hsail: - case Triple::kalimba: -+ case Triple::le32: -+ case Triple::le64: - case Triple::loongarch32: - case Triple::loongarch64: - case Triple::msp430: -@@ -1934,6 +1953,8 @@ - case Triple::hsail64: - case Triple::hsail: - case Triple::kalimba: -+ case Triple::le32: -+ case Triple::le64: - case Triple::loongarch32: - case Triple::loongarch64: - case Triple::mips64el: -diff -ruN --strip-trailing-cr a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp ---- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp -+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp -@@ -6995,7 +6995,7 @@ - // Ignore ephemeral values. - CodeMetrics::collectEphemeralValues(TheLoop, AC, ValuesToIgnore); ++ // NOTE: If this test is failing, run this test with ++ // `llvm::errs() << TraceGraph;` and change the assert above. + } +diff -ruN --strip-trailing-cr a/lld/test/MachO/reproduce-thin-archive-objc.s b/lld/test/MachO/reproduce-thin-archive-objc.s +--- a/lld/test/MachO/reproduce-thin-archive-objc.s ++++ b/lld/test/MachO/reproduce-thin-archive-objc.s +@@ -4,20 +4,19 @@ + ## during linking. However, we need to iterate over all members for -ObjC, check that we don't + ## crash when we encounter a missing member. + +-# RUN: rm -rf %t; mkdir %t +-# RUN: sed s/SYM/_main/ %s | llvm-mc -filetype=obj -triple=x86_64-apple-macos -o %t/main.o +-# RUN: sed s/SYM/_unused/ %s | llvm-mc -filetype=obj -triple=x86_64-apple-macos -o %t/unused.o ++# RUN: rm -rf %t && mkdir %t && cd %t ++# RUN: sed s/SYM/_main/ %s | llvm-mc -filetype=obj -triple=x86_64-apple-macos -o main.o ++# RUN: sed s/SYM/_unused/ %s | llvm-mc -filetype=obj -triple=x86_64-apple-macos -o unused.o + +-# RUN: cd %t; llvm-ar rcsT unused.a unused.o; rm unused.o ++# RUN: llvm-ar rcsT unused.a unused.o; rm unused.o + ## FIXME: Absolute paths don't end up relativized in the repro file. + + # RUN: %no-fatal-warnings-lld %t/main.o %t/unused.a -ObjC -o /dev/null 2>&1 \ + # RUN: | FileCheck %s --check-prefix=WARN + +-# RUN: %lld %t/main.o %t/unused.a -ObjC --no-warn-thin-archive-missing-members -o /dev/null \ +-# RUN: | FileCheck %s --implicit-check-not 'warning' --allow-empty ++# RUN: %lld main.o unused.a -ObjC --no-warn-thin-archive-missing-members 2>&1 | count 0 + +-# WARN: ld64.lld: warning: {{.*}}unused.a: -ObjC failed to open archive member: 'unused.o' ++# WARN: warning: {{.*}}unused.a: -ObjC failed to open archive member: 'unused.o' + + .text + .globl SYM +diff -ruN --strip-trailing-cr a/llvm/include/llvm/Support/TimeProfiler.h b/llvm/include/llvm/Support/TimeProfiler.h +--- a/llvm/include/llvm/Support/TimeProfiler.h ++++ b/llvm/include/llvm/Support/TimeProfiler.h +@@ -83,28 +83,16 @@ + + class raw_pwrite_stream; + +-struct TimeTraceMetadata { +- std::string Detail; +- // Source file and line number information for the event. +- std::string File; +- int Line; +- +- bool isEmpty() const { return Detail.empty() && File.empty(); } +-}; +- + struct TimeTraceProfiler; + TimeTraceProfiler *getTimeTraceProfilerInstance(); -- SmallSetVector DeadInterleavePointerOps; -+ SmallVector DeadInterleavePointerOps; - for (BasicBlock *BB : TheLoop->blocks()) - for (Instruction &I : *BB) { - // Find all stores to invariant variables. Since they are going to sink -@@ -7013,7 +7013,7 @@ - if (Group->getInsertPos() == &I) - continue; - Value *PointerOp = getLoadStorePointerOperand(&I); -- DeadInterleavePointerOps.insert(PointerOp); -+ DeadInterleavePointerOps.push_back(PointerOp); - } - } +-bool isTimeTraceVerbose(); +- + struct TimeTraceProfilerEntry; + + /// Initialize the time trace profiler. + /// This sets up the global \p TimeTraceProfilerInstance + /// variable to be the profiler instance. + void timeTraceProfilerInitialize(unsigned TimeTraceGranularity, +- StringRef ProcName, +- bool TimeTraceVerbose = false); ++ StringRef ProcName); + + /// Cleanup the time trace profiler, if it was initialized. + void timeTraceProfilerCleanup(); +@@ -140,10 +128,6 @@ + timeTraceProfilerBegin(StringRef Name, + llvm::function_ref Detail); + +-TimeTraceProfilerEntry * +-timeTraceProfilerBegin(StringRef Name, +- llvm::function_ref MetaData); +- + /// Manually begin a time section, with the given \p Name and \p Detail. + /// This starts Async Events having \p Name as a category which is shown + /// separately from other traces. See +@@ -180,11 +164,6 @@ + if (getTimeTraceProfilerInstance() != nullptr) + Entry = timeTraceProfilerBegin(Name, Detail); + } +- TimeTraceScope(StringRef Name, +- llvm::function_ref Metadata) { +- if (getTimeTraceProfilerInstance() != nullptr) +- Entry = timeTraceProfilerBegin(Name, Metadata); +- } + ~TimeTraceScope() { + if (getTimeTraceProfilerInstance() != nullptr) + timeTraceProfilerEnd(Entry); +diff -ruN --strip-trailing-cr a/llvm/lib/Support/TimeProfiler.cpp b/llvm/lib/Support/TimeProfiler.cpp +--- a/llvm/lib/Support/TimeProfiler.cpp ++++ b/llvm/lib/Support/TimeProfiler.cpp +@@ -73,20 +73,12 @@ + const TimePointType Start; + TimePointType End; + const std::string Name; +- TimeTraceMetadata Metadata; +- ++ const std::string Detail; + const bool AsyncEvent = false; + TimeTraceProfilerEntry(TimePointType &&S, TimePointType &&E, std::string &&N, + std::string &&Dt, bool Ae) +- : Start(std::move(S)), End(std::move(E)), Name(std::move(N)), Metadata(), +- AsyncEvent(Ae) { +- Metadata.Detail = std::move(Dt); +- } +- +- TimeTraceProfilerEntry(TimePointType &&S, TimePointType &&E, std::string &&N, +- TimeTraceMetadata &&Mt, bool Ae) + : Start(std::move(S)), End(std::move(E)), Name(std::move(N)), +- Metadata(std::move(Mt)), AsyncEvent(Ae) {} ++ Detail(std::move(Dt)), AsyncEvent(Ae) {} + + // Calculate timings for FlameGraph. Cast time points to microsecond precision + // rather than casting duration. This avoids truncation issues causing inner +@@ -105,12 +97,10 @@ + }; -@@ -7029,7 +7029,7 @@ - })) - continue; - VecValuesToIgnore.insert(Op); -- DeadInterleavePointerOps.insert(Op->op_begin(), Op->op_end()); -+ DeadInterleavePointerOps.append(Op->op_begin(), Op->op_end()); + struct llvm::TimeTraceProfiler { +- TimeTraceProfiler(unsigned TimeTraceGranularity = 0, StringRef ProcName = "", +- bool TimeTraceVerbose = false) ++ TimeTraceProfiler(unsigned TimeTraceGranularity = 0, StringRef ProcName = "") + : BeginningOfTime(system_clock::now()), StartTime(ClockType::now()), + ProcName(ProcName), Pid(sys::Process::getProcessId()), +- Tid(llvm::get_threadid()), TimeTraceGranularity(TimeTraceGranularity), +- TimeTraceVerbose(TimeTraceVerbose) { ++ Tid(llvm::get_threadid()), TimeTraceGranularity(TimeTraceGranularity) { + llvm::get_thread_name(ThreadName); } - // Ignore type-promoting instructions we identified during reduction -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AArch64/aarch64-mull-masks.ll b/llvm/test/CodeGen/AArch64/aarch64-mull-masks.ll ---- a/llvm/test/CodeGen/AArch64/aarch64-mull-masks.ll -+++ b/llvm/test/CodeGen/AArch64/aarch64-mull-masks.ll -@@ -281,7 +281,8 @@ - ; CHECK-LABEL: smull_ldrsw_shift: - ; CHECK: // %bb.0: // %entry - ; CHECK-NEXT: ldrsw x8, [x0] --; CHECK-NEXT: smull x0, w8, w1 -+; CHECK-NEXT: sxtw x9, w1 -+; CHECK-NEXT: smull x0, w8, w9 - ; CHECK-NEXT: ret - entry: - %ext64 = load i32, ptr %x0 -@@ -489,7 +490,8 @@ - ; CHECK-LABEL: smaddl_ldrsw_shift: - ; CHECK: // %bb.0: // %entry - ; CHECK-NEXT: ldrsw x8, [x0] --; CHECK-NEXT: smaddl x0, w8, w1, x2 -+; CHECK-NEXT: sxtw x9, w1 -+; CHECK-NEXT: smaddl x0, w8, w9, x2 - ; CHECK-NEXT: ret - entry: - %ext64 = load i32, ptr %x0 -@@ -652,7 +654,8 @@ - ; CHECK-LABEL: smnegl_ldrsw_shift: - ; CHECK: // %bb.0: // %entry - ; CHECK-NEXT: ldrsw x8, [x0] --; CHECK-NEXT: smnegl x0, w8, w1 -+; CHECK-NEXT: sxtw x9, w1 -+; CHECK-NEXT: smnegl x0, w8, w9 - ; CHECK-NEXT: ret - entry: - %ext64 = load i32, ptr %x0 -@@ -815,7 +818,8 @@ - ; CHECK-LABEL: smsubl_ldrsw_shift: - ; CHECK: // %bb.0: // %entry - ; CHECK-NEXT: ldrsw x8, [x0] --; CHECK-NEXT: smsubl x0, w8, w1, x2 -+; CHECK-NEXT: sxtw x9, w1 -+; CHECK-NEXT: smsubl x0, w8, w9, x2 - ; CHECK-NEXT: ret - entry: - %ext64 = load i32, ptr %x0 -diff -ruN --strip-trailing-cr a/llvm/test/Transforms/LoopVectorize/X86/interleave-cost.ll b/llvm/test/Transforms/LoopVectorize/X86/interleave-cost.ll ---- a/llvm/test/Transforms/LoopVectorize/X86/interleave-cost.ll -+++ b/llvm/test/Transforms/LoopVectorize/X86/interleave-cost.ll -@@ -182,9 +182,432 @@ - exit: - ret void - } -+ -+define void @geps_feeding_interleave_groups_with_reuse(ptr %arg, i64 %arg1, ptr %arg2) #0 { -+; CHECK-LABEL: define void @geps_feeding_interleave_groups_with_reuse( -+; CHECK-SAME: ptr [[ARG:%.*]], i64 [[ARG1:%.*]], ptr [[ARG2:%.*]]) #[[ATTR0:[0-9]+]] { -+; CHECK-NEXT: [[ENTRY:.*]]: -+; CHECK-NEXT: [[TMP0:%.*]] = add i64 [[ARG1]], 1 -+; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[TMP0]], 30 -+; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label %[[SCALAR_PH:.*]], label %[[VECTOR_SCEVCHECK:.*]] -+; CHECK: [[VECTOR_SCEVCHECK]]: -+; CHECK-NEXT: [[SCEVGEP:%.*]] = getelementptr i8, ptr [[ARG2]], i64 8 -+; CHECK-NEXT: [[MUL:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 16, i64 [[ARG1]]) -+; CHECK-NEXT: [[MUL_RESULT:%.*]] = extractvalue { i64, i1 } [[MUL]], 0 -+; CHECK-NEXT: [[MUL_OVERFLOW:%.*]] = extractvalue { i64, i1 } [[MUL]], 1 -+; CHECK-NEXT: [[TMP1:%.*]] = sub i64 0, [[MUL_RESULT]] -+; CHECK-NEXT: [[TMP2:%.*]] = getelementptr i8, ptr [[SCEVGEP]], i64 [[MUL_RESULT]] -+; CHECK-NEXT: [[TMP3:%.*]] = icmp ult ptr [[TMP2]], [[SCEVGEP]] -+; CHECK-NEXT: [[TMP4:%.*]] = or i1 [[TMP3]], [[MUL_OVERFLOW]] -+; CHECK-NEXT: [[SCEVGEP1:%.*]] = getelementptr i8, ptr [[ARG2]], i64 12 -+; CHECK-NEXT: [[MUL2:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 16, i64 [[ARG1]]) -+; CHECK-NEXT: [[MUL_RESULT3:%.*]] = extractvalue { i64, i1 } [[MUL2]], 0 -+; CHECK-NEXT: [[MUL_OVERFLOW4:%.*]] = extractvalue { i64, i1 } [[MUL2]], 1 -+; CHECK-NEXT: [[TMP5:%.*]] = sub i64 0, [[MUL_RESULT3]] -+; CHECK-NEXT: [[TMP6:%.*]] = getelementptr i8, ptr [[SCEVGEP1]], i64 [[MUL_RESULT3]] -+; CHECK-NEXT: [[TMP7:%.*]] = icmp ult ptr [[TMP6]], [[SCEVGEP1]] -+; CHECK-NEXT: [[TMP8:%.*]] = or i1 [[TMP7]], [[MUL_OVERFLOW4]] -+; CHECK-NEXT: [[SCEVGEP5:%.*]] = getelementptr i8, ptr [[ARG2]], i64 4 -+; CHECK-NEXT: [[MUL6:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 16, i64 [[ARG1]]) -+; CHECK-NEXT: [[MUL_RESULT7:%.*]] = extractvalue { i64, i1 } [[MUL6]], 0 -+; CHECK-NEXT: [[MUL_OVERFLOW8:%.*]] = extractvalue { i64, i1 } [[MUL6]], 1 -+; CHECK-NEXT: [[TMP9:%.*]] = sub i64 0, [[MUL_RESULT7]] -+; CHECK-NEXT: [[TMP10:%.*]] = getelementptr i8, ptr [[SCEVGEP5]], i64 [[MUL_RESULT7]] -+; CHECK-NEXT: [[TMP11:%.*]] = icmp ult ptr [[TMP10]], [[SCEVGEP5]] -+; CHECK-NEXT: [[TMP12:%.*]] = or i1 [[TMP11]], [[MUL_OVERFLOW8]] -+; CHECK-NEXT: [[MUL9:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 16, i64 [[ARG1]]) -+; CHECK-NEXT: [[MUL_RESULT10:%.*]] = extractvalue { i64, i1 } [[MUL9]], 0 -+; CHECK-NEXT: [[MUL_OVERFLOW11:%.*]] = extractvalue { i64, i1 } [[MUL9]], 1 -+; CHECK-NEXT: [[TMP13:%.*]] = sub i64 0, [[MUL_RESULT10]] -+; CHECK-NEXT: [[TMP14:%.*]] = getelementptr i8, ptr [[ARG2]], i64 [[MUL_RESULT10]] -+; CHECK-NEXT: [[TMP15:%.*]] = icmp ult ptr [[TMP14]], [[ARG2]] -+; CHECK-NEXT: [[TMP16:%.*]] = or i1 [[TMP15]], [[MUL_OVERFLOW11]] -+; CHECK-NEXT: [[TMP17:%.*]] = or i1 [[TMP4]], [[TMP8]] -+; CHECK-NEXT: [[TMP18:%.*]] = or i1 [[TMP17]], [[TMP12]] -+; CHECK-NEXT: [[TMP19:%.*]] = or i1 [[TMP18]], [[TMP16]] -+; CHECK-NEXT: br i1 [[TMP19]], label %[[SCALAR_PH]], label %[[VECTOR_MEMCHECK:.*]] -+; CHECK: [[VECTOR_MEMCHECK]]: -+; CHECK-NEXT: [[TMP20:%.*]] = shl i64 [[ARG1]], 4 -+; CHECK-NEXT: [[TMP21:%.*]] = add i64 [[TMP20]], 16 -+; CHECK-NEXT: [[SCEVGEP12:%.*]] = getelementptr i8, ptr [[ARG2]], i64 [[TMP21]] -+; CHECK-NEXT: [[TMP22:%.*]] = shl i64 [[ARG1]], 5 -+; CHECK-NEXT: [[TMP23:%.*]] = add i64 [[TMP22]], 32 -+; CHECK-NEXT: [[SCEVGEP13:%.*]] = getelementptr i8, ptr [[ARG]], i64 [[TMP23]] -+; CHECK-NEXT: [[BOUND0:%.*]] = icmp ult ptr [[ARG2]], [[SCEVGEP13]] -+; CHECK-NEXT: [[BOUND1:%.*]] = icmp ult ptr [[ARG]], [[SCEVGEP12]] -+; CHECK-NEXT: [[FOUND_CONFLICT:%.*]] = and i1 [[BOUND0]], [[BOUND1]] -+; CHECK-NEXT: br i1 [[FOUND_CONFLICT]], label %[[SCALAR_PH]], label %[[VECTOR_PH:.*]] -+; CHECK: [[VECTOR_PH]]: -+; CHECK-NEXT: [[N_MOD_VF:%.*]] = urem i64 [[TMP0]], 2 -+; CHECK-NEXT: [[N_VEC:%.*]] = sub i64 [[TMP0]], [[N_MOD_VF]] -+; CHECK-NEXT: br label %[[VECTOR_BODY:.*]] -+; CHECK: [[VECTOR_BODY]]: -+; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ] -+; CHECK-NEXT: [[TMP24:%.*]] = add i64 [[INDEX]], 0 -+; CHECK-NEXT: [[TMP25:%.*]] = shl i64 [[TMP24]], 5 -+; CHECK-NEXT: [[TMP26:%.*]] = getelementptr i8, ptr [[ARG]], i64 [[TMP25]] -+; CHECK-NEXT: [[TMP27:%.*]] = shl i64 [[TMP24]], 4 -+; CHECK-NEXT: [[TMP28:%.*]] = getelementptr i8, ptr [[ARG2]], i64 [[TMP27]] -+; CHECK-NEXT: [[TMP29:%.*]] = getelementptr float, ptr [[TMP26]], i32 0 -+; CHECK-NEXT: [[WIDE_VEC:%.*]] = load <16 x float>, ptr [[TMP29]], align 4 -+; CHECK-NEXT: [[STRIDED_VEC:%.*]] = shufflevector <16 x float> [[WIDE_VEC]], <16 x float> poison, <2 x i32> -+; CHECK-NEXT: [[STRIDED_VEC14:%.*]] = shufflevector <16 x float> [[WIDE_VEC]], <16 x float> poison, <2 x i32> -+; CHECK-NEXT: [[STRIDED_VEC15:%.*]] = shufflevector <16 x float> [[WIDE_VEC]], <16 x float> poison, <2 x i32> -+; CHECK-NEXT: [[STRIDED_VEC16:%.*]] = shufflevector <16 x float> [[WIDE_VEC]], <16 x float> poison, <2 x i32> -+; CHECK-NEXT: [[STRIDED_VEC17:%.*]] = shufflevector <16 x float> [[WIDE_VEC]], <16 x float> poison, <2 x i32> -+; CHECK-NEXT: [[STRIDED_VEC18:%.*]] = shufflevector <16 x float> [[WIDE_VEC]], <16 x float> poison, <2 x i32> -+; CHECK-NEXT: [[STRIDED_VEC19:%.*]] = shufflevector <16 x float> [[WIDE_VEC]], <16 x float> poison, <2 x i32> -+; CHECK-NEXT: [[STRIDED_VEC20:%.*]] = shufflevector <16 x float> [[WIDE_VEC]], <16 x float> poison, <2 x i32> -+; CHECK-NEXT: [[TMP30:%.*]] = fadd <2 x float> [[STRIDED_VEC]], [[STRIDED_VEC17]] -+; CHECK-NEXT: [[TMP31:%.*]] = fmul <2 x float> [[TMP30]], zeroinitializer -+; CHECK-NEXT: [[TMP32:%.*]] = fadd <2 x float> [[STRIDED_VEC14]], [[STRIDED_VEC18]] -+; CHECK-NEXT: [[TMP33:%.*]] = fmul <2 x float> [[TMP32]], zeroinitializer -+; CHECK-NEXT: [[TMP34:%.*]] = fadd <2 x float> [[STRIDED_VEC15]], [[STRIDED_VEC19]] -+; CHECK-NEXT: [[TMP35:%.*]] = fmul <2 x float> [[TMP34]], zeroinitializer -+; CHECK-NEXT: [[TMP36:%.*]] = fadd <2 x float> [[STRIDED_VEC16]], [[STRIDED_VEC20]] -+; CHECK-NEXT: [[TMP37:%.*]] = fmul <2 x float> [[TMP36]], zeroinitializer -+; CHECK-NEXT: [[TMP38:%.*]] = getelementptr i8, ptr [[TMP28]], i64 12 -+; CHECK-NEXT: [[TMP39:%.*]] = getelementptr float, ptr [[TMP38]], i32 -3 -+; CHECK-NEXT: [[TMP40:%.*]] = shufflevector <2 x float> [[TMP31]], <2 x float> [[TMP33]], <4 x i32> -+; CHECK-NEXT: [[TMP41:%.*]] = shufflevector <2 x float> [[TMP35]], <2 x float> [[TMP37]], <4 x i32> -+; CHECK-NEXT: [[TMP42:%.*]] = shufflevector <4 x float> [[TMP40]], <4 x float> [[TMP41]], <8 x i32> -+; CHECK-NEXT: [[INTERLEAVED_VEC:%.*]] = shufflevector <8 x float> [[TMP42]], <8 x float> poison, <8 x i32> -+; CHECK-NEXT: store <8 x float> [[INTERLEAVED_VEC]], ptr [[TMP39]], align 4 -+; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 2 -+; CHECK-NEXT: [[TMP43:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]] -+; CHECK-NEXT: br i1 [[TMP43]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]] -+; CHECK: [[MIDDLE_BLOCK]]: -+; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[TMP0]], [[N_VEC]] -+; CHECK-NEXT: br i1 [[CMP_N]], label %[[EXIT:.*]], label %[[SCALAR_PH]] -+; CHECK: [[SCALAR_PH]]: -+; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ], [ 0, %[[VECTOR_SCEVCHECK]] ], [ 0, %[[VECTOR_MEMCHECK]] ] -+; CHECK-NEXT: br label %[[LOOP:.*]] -+; CHECK: [[LOOP]]: -+; CHECK-NEXT: [[IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ] -+; CHECK-NEXT: [[SHL_IV_5:%.*]] = shl i64 [[IV]], 5 -+; CHECK-NEXT: [[GEP_1:%.*]] = getelementptr i8, ptr [[ARG]], i64 [[SHL_IV_5]] -+; CHECK-NEXT: [[ADD_5:%.*]] = or disjoint i64 [[SHL_IV_5]], 16 -+; CHECK-NEXT: [[GEP_2:%.*]] = getelementptr i8, ptr [[ARG]], i64 [[ADD_5]] -+; CHECK-NEXT: [[SHL_IV_4:%.*]] = shl i64 [[IV]], 4 -+; CHECK-NEXT: [[GEP_3:%.*]] = getelementptr i8, ptr [[ARG2]], i64 [[SHL_IV_4]] -+; CHECK-NEXT: [[L_1:%.*]] = load float, ptr [[GEP_1]], align 4 -+; CHECK-NEXT: [[L_2:%.*]] = load float, ptr [[GEP_2]], align 4 -+; CHECK-NEXT: [[ADD_1:%.*]] = fadd float [[L_1]], [[L_2]] -+; CHECK-NEXT: [[MUL_1:%.*]] = fmul float [[ADD_1]], 0.000000e+00 -+; CHECK-NEXT: store float [[MUL_1]], ptr [[GEP_3]], align 4 -+; CHECK-NEXT: [[GEP_4:%.*]] = getelementptr i8, ptr [[GEP_1]], i64 4 -+; CHECK-NEXT: [[L_3:%.*]] = load float, ptr [[GEP_4]], align 4 -+; CHECK-NEXT: [[GEP_5:%.*]] = getelementptr i8, ptr [[GEP_2]], i64 4 -+; CHECK-NEXT: [[L_4:%.*]] = load float, ptr [[GEP_5]], align 4 -+; CHECK-NEXT: [[ADD_2:%.*]] = fadd float [[L_3]], [[L_4]] -+; CHECK-NEXT: [[MUL_2:%.*]] = fmul float [[ADD_2]], 0.000000e+00 -+; CHECK-NEXT: [[GEP_6:%.*]] = getelementptr i8, ptr [[GEP_3]], i64 4 -+; CHECK-NEXT: store float [[MUL_2]], ptr [[GEP_6]], align 4 -+; CHECK-NEXT: [[GEP_7:%.*]] = getelementptr i8, ptr [[GEP_1]], i64 8 -+; CHECK-NEXT: [[L_5:%.*]] = load float, ptr [[GEP_7]], align 4 -+; CHECK-NEXT: [[GEP_8:%.*]] = getelementptr i8, ptr [[GEP_2]], i64 8 -+; CHECK-NEXT: [[L_6:%.*]] = load float, ptr [[GEP_8]], align 4 -+; CHECK-NEXT: [[ADD_3:%.*]] = fadd float [[L_5]], [[L_6]] -+; CHECK-NEXT: [[MUL_3:%.*]] = fmul float [[ADD_3]], 0.000000e+00 -+; CHECK-NEXT: [[GEP_9:%.*]] = getelementptr i8, ptr [[GEP_3]], i64 8 -+; CHECK-NEXT: store float [[MUL_3]], ptr [[GEP_9]], align 4 -+; CHECK-NEXT: [[I27:%.*]] = getelementptr i8, ptr [[GEP_1]], i64 12 -+; CHECK-NEXT: [[L_7:%.*]] = load float, ptr [[I27]], align 4 -+; CHECK-NEXT: [[GEP_10:%.*]] = getelementptr i8, ptr [[GEP_2]], i64 12 -+; CHECK-NEXT: [[L_8:%.*]] = load float, ptr [[GEP_10]], align 4 -+; CHECK-NEXT: [[ADD_4:%.*]] = fadd float [[L_7]], [[L_8]] -+; CHECK-NEXT: [[MUL_4:%.*]] = fmul float [[ADD_4]], 0.000000e+00 -+; CHECK-NEXT: [[GEP_11:%.*]] = getelementptr i8, ptr [[GEP_3]], i64 12 -+; CHECK-NEXT: store float [[MUL_4]], ptr [[GEP_11]], align 4 -+; CHECK-NEXT: [[IV_NEXT]] = add i64 [[IV]], 1 -+; CHECK-NEXT: [[EC:%.*]] = icmp eq i64 [[IV]], [[ARG1]] -+; CHECK-NEXT: br i1 [[EC]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP5:![0-9]+]] -+; CHECK: [[EXIT]]: -+; CHECK-NEXT: ret void -+; -+entry: -+ br label %loop -+ -+loop: -+ %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ] -+ %shl.iv.5 = shl i64 %iv, 5 -+ %gep.1 = getelementptr i8, ptr %arg, i64 %shl.iv.5 -+ %add.5 = or disjoint i64 %shl.iv.5, 16 -+ %gep.2 = getelementptr i8, ptr %arg, i64 %add.5 -+ %shl.iv.4 = shl i64 %iv, 4 -+ %gep.3 = getelementptr i8, ptr %arg2, i64 %shl.iv.4 -+ %l.1 = load float, ptr %gep.1, align 4 -+ %l.2 = load float, ptr %gep.2, align 4 -+ %add.1 = fadd float %l.1, %l.2 -+ %mul.1 = fmul float %add.1, 0.000000e+00 -+ store float %mul.1, ptr %gep.3, align 4 -+ %gep.4 = getelementptr i8, ptr %gep.1, i64 4 -+ %l.3 = load float, ptr %gep.4, align 4 -+ %gep.5 = getelementptr i8, ptr %gep.2, i64 4 -+ %l.4 = load float, ptr %gep.5, align 4 -+ %add.2 = fadd float %l.3, %l.4 -+ %mul.2 = fmul float %add.2, 0.000000e+00 -+ %gep.6 = getelementptr i8, ptr %gep.3, i64 4 -+ store float %mul.2, ptr %gep.6, align 4 -+ %gep.7 = getelementptr i8, ptr %gep.1, i64 8 -+ %l.5 = load float, ptr %gep.7, align 4 -+ %gep.8 = getelementptr i8, ptr %gep.2, i64 8 -+ %l.6 = load float, ptr %gep.8, align 4 -+ %add.3 = fadd float %l.5, %l.6 -+ %mul.3 = fmul float %add.3, 0.000000e+00 -+ %gep.9 = getelementptr i8, ptr %gep.3, i64 8 -+ store float %mul.3, ptr %gep.9, align 4 -+ %i27 = getelementptr i8, ptr %gep.1, i64 12 -+ %l.7 = load float, ptr %i27, align 4 -+ %gep.10 = getelementptr i8, ptr %gep.2, i64 12 -+ %l.8 = load float, ptr %gep.10, align 4 -+ %add.4 = fadd float %l.7, %l.8 -+ %mul.4 = fmul float %add.4, 0.000000e+00 -+ %gep.11 = getelementptr i8, ptr %gep.3, i64 12 -+ store float %mul.4, ptr %gep.11, align 4 -+ %iv.next = add i64 %iv, 1 -+ %ec = icmp eq i64 %iv, %arg1 -+ br i1 %ec, label %exit, label %loop -+ -+exit: -+ ret void -+} -+ -+define void @geps_feeding_interleave_groups_with_reuse2(ptr %A, ptr %B, i64 %N) #1 { -+; CHECK-LABEL: define void @geps_feeding_interleave_groups_with_reuse2( -+; CHECK-SAME: ptr [[A:%.*]], ptr [[B:%.*]], i64 [[N:%.*]]) #[[ATTR1:[0-9]+]] { -+; CHECK-NEXT: [[ENTRY:.*]]: -+; CHECK-NEXT: [[TMP0:%.*]] = lshr i64 [[N]], 3 -+; CHECK-NEXT: [[TMP1:%.*]] = add nuw nsw i64 [[TMP0]], 1 -+; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ule i64 [[TMP1]], 28 -+; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label %[[SCALAR_PH:.*]], label %[[VECTOR_SCEVCHECK:.*]] -+; CHECK: [[VECTOR_SCEVCHECK]]: -+; CHECK-NEXT: [[TMP2:%.*]] = lshr i64 [[N]], 3 -+; CHECK-NEXT: [[SCEVGEP:%.*]] = getelementptr i8, ptr [[A]], i64 24 -+; CHECK-NEXT: [[MUL:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 32, i64 [[TMP2]]) -+; CHECK-NEXT: [[MUL_RESULT:%.*]] = extractvalue { i64, i1 } [[MUL]], 0 -+; CHECK-NEXT: [[MUL_OVERFLOW:%.*]] = extractvalue { i64, i1 } [[MUL]], 1 -+; CHECK-NEXT: [[TMP3:%.*]] = sub i64 0, [[MUL_RESULT]] -+; CHECK-NEXT: [[TMP4:%.*]] = getelementptr i8, ptr [[SCEVGEP]], i64 [[MUL_RESULT]] -+; CHECK-NEXT: [[TMP5:%.*]] = icmp ult ptr [[TMP4]], [[SCEVGEP]] -+; CHECK-NEXT: [[TMP6:%.*]] = or i1 [[TMP5]], [[MUL_OVERFLOW]] -+; CHECK-NEXT: [[SCEVGEP1:%.*]] = getelementptr i8, ptr [[A]], i64 28 -+; CHECK-NEXT: [[MUL2:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 32, i64 [[TMP2]]) -+; CHECK-NEXT: [[MUL_RESULT3:%.*]] = extractvalue { i64, i1 } [[MUL2]], 0 -+; CHECK-NEXT: [[MUL_OVERFLOW4:%.*]] = extractvalue { i64, i1 } [[MUL2]], 1 -+; CHECK-NEXT: [[TMP7:%.*]] = sub i64 0, [[MUL_RESULT3]] -+; CHECK-NEXT: [[TMP8:%.*]] = getelementptr i8, ptr [[SCEVGEP1]], i64 [[MUL_RESULT3]] -+; CHECK-NEXT: [[TMP9:%.*]] = icmp ult ptr [[TMP8]], [[SCEVGEP1]] -+; CHECK-NEXT: [[TMP10:%.*]] = or i1 [[TMP9]], [[MUL_OVERFLOW4]] -+; CHECK-NEXT: [[SCEVGEP5:%.*]] = getelementptr i8, ptr [[A]], i64 20 -+; CHECK-NEXT: [[MUL6:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 32, i64 [[TMP2]]) -+; CHECK-NEXT: [[MUL_RESULT7:%.*]] = extractvalue { i64, i1 } [[MUL6]], 0 -+; CHECK-NEXT: [[MUL_OVERFLOW8:%.*]] = extractvalue { i64, i1 } [[MUL6]], 1 -+; CHECK-NEXT: [[TMP11:%.*]] = sub i64 0, [[MUL_RESULT7]] -+; CHECK-NEXT: [[TMP12:%.*]] = getelementptr i8, ptr [[SCEVGEP5]], i64 [[MUL_RESULT7]] -+; CHECK-NEXT: [[TMP13:%.*]] = icmp ult ptr [[TMP12]], [[SCEVGEP5]] -+; CHECK-NEXT: [[TMP14:%.*]] = or i1 [[TMP13]], [[MUL_OVERFLOW8]] -+; CHECK-NEXT: [[SCEVGEP9:%.*]] = getelementptr i8, ptr [[A]], i64 16 -+; CHECK-NEXT: [[MUL10:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 32, i64 [[TMP2]]) -+; CHECK-NEXT: [[MUL_RESULT11:%.*]] = extractvalue { i64, i1 } [[MUL10]], 0 -+; CHECK-NEXT: [[MUL_OVERFLOW12:%.*]] = extractvalue { i64, i1 } [[MUL10]], 1 -+; CHECK-NEXT: [[TMP15:%.*]] = sub i64 0, [[MUL_RESULT11]] -+; CHECK-NEXT: [[TMP16:%.*]] = getelementptr i8, ptr [[SCEVGEP9]], i64 [[MUL_RESULT11]] -+; CHECK-NEXT: [[TMP17:%.*]] = icmp ult ptr [[TMP16]], [[SCEVGEP9]] -+; CHECK-NEXT: [[TMP18:%.*]] = or i1 [[TMP17]], [[MUL_OVERFLOW12]] -+; CHECK-NEXT: [[SCEVGEP13:%.*]] = getelementptr i8, ptr [[A]], i64 12 -+; CHECK-NEXT: [[MUL14:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 32, i64 [[TMP2]]) -+; CHECK-NEXT: [[MUL_RESULT15:%.*]] = extractvalue { i64, i1 } [[MUL14]], 0 -+; CHECK-NEXT: [[MUL_OVERFLOW16:%.*]] = extractvalue { i64, i1 } [[MUL14]], 1 -+; CHECK-NEXT: [[TMP19:%.*]] = sub i64 0, [[MUL_RESULT15]] -+; CHECK-NEXT: [[TMP20:%.*]] = getelementptr i8, ptr [[SCEVGEP13]], i64 [[MUL_RESULT15]] -+; CHECK-NEXT: [[TMP21:%.*]] = icmp ult ptr [[TMP20]], [[SCEVGEP13]] -+; CHECK-NEXT: [[TMP22:%.*]] = or i1 [[TMP21]], [[MUL_OVERFLOW16]] -+; CHECK-NEXT: [[SCEVGEP17:%.*]] = getelementptr i8, ptr [[A]], i64 8 -+; CHECK-NEXT: [[MUL18:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 32, i64 [[TMP2]]) -+; CHECK-NEXT: [[MUL_RESULT19:%.*]] = extractvalue { i64, i1 } [[MUL18]], 0 -+; CHECK-NEXT: [[MUL_OVERFLOW20:%.*]] = extractvalue { i64, i1 } [[MUL18]], 1 -+; CHECK-NEXT: [[TMP23:%.*]] = sub i64 0, [[MUL_RESULT19]] -+; CHECK-NEXT: [[TMP24:%.*]] = getelementptr i8, ptr [[SCEVGEP17]], i64 [[MUL_RESULT19]] -+; CHECK-NEXT: [[TMP25:%.*]] = icmp ult ptr [[TMP24]], [[SCEVGEP17]] -+; CHECK-NEXT: [[TMP26:%.*]] = or i1 [[TMP25]], [[MUL_OVERFLOW20]] -+; CHECK-NEXT: [[SCEVGEP21:%.*]] = getelementptr i8, ptr [[A]], i64 4 -+; CHECK-NEXT: [[MUL22:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 32, i64 [[TMP2]]) -+; CHECK-NEXT: [[MUL_RESULT23:%.*]] = extractvalue { i64, i1 } [[MUL22]], 0 -+; CHECK-NEXT: [[MUL_OVERFLOW24:%.*]] = extractvalue { i64, i1 } [[MUL22]], 1 -+; CHECK-NEXT: [[TMP27:%.*]] = sub i64 0, [[MUL_RESULT23]] -+; CHECK-NEXT: [[TMP28:%.*]] = getelementptr i8, ptr [[SCEVGEP21]], i64 [[MUL_RESULT23]] -+; CHECK-NEXT: [[TMP29:%.*]] = icmp ult ptr [[TMP28]], [[SCEVGEP21]] -+; CHECK-NEXT: [[TMP30:%.*]] = or i1 [[TMP29]], [[MUL_OVERFLOW24]] -+; CHECK-NEXT: [[MUL25:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 32, i64 [[TMP2]]) -+; CHECK-NEXT: [[MUL_RESULT26:%.*]] = extractvalue { i64, i1 } [[MUL25]], 0 -+; CHECK-NEXT: [[MUL_OVERFLOW27:%.*]] = extractvalue { i64, i1 } [[MUL25]], 1 -+; CHECK-NEXT: [[TMP31:%.*]] = sub i64 0, [[MUL_RESULT26]] -+; CHECK-NEXT: [[TMP32:%.*]] = getelementptr i8, ptr [[A]], i64 [[MUL_RESULT26]] -+; CHECK-NEXT: [[TMP33:%.*]] = icmp ult ptr [[TMP32]], [[A]] -+; CHECK-NEXT: [[TMP34:%.*]] = or i1 [[TMP33]], [[MUL_OVERFLOW27]] -+; CHECK-NEXT: [[TMP35:%.*]] = or i1 [[TMP6]], [[TMP10]] -+; CHECK-NEXT: [[TMP36:%.*]] = or i1 [[TMP35]], [[TMP14]] -+; CHECK-NEXT: [[TMP37:%.*]] = or i1 [[TMP36]], [[TMP18]] -+; CHECK-NEXT: [[TMP38:%.*]] = or i1 [[TMP37]], [[TMP22]] -+; CHECK-NEXT: [[TMP39:%.*]] = or i1 [[TMP38]], [[TMP26]] -+; CHECK-NEXT: [[TMP40:%.*]] = or i1 [[TMP39]], [[TMP30]] -+; CHECK-NEXT: [[TMP41:%.*]] = or i1 [[TMP40]], [[TMP34]] -+; CHECK-NEXT: br i1 [[TMP41]], label %[[SCALAR_PH]], label %[[VECTOR_MEMCHECK:.*]] -+; CHECK: [[VECTOR_MEMCHECK]]: -+; CHECK-NEXT: [[TMP42:%.*]] = lshr i64 [[N]], 3 -+; CHECK-NEXT: [[TMP43:%.*]] = shl i64 [[TMP42]], 5 -+; CHECK-NEXT: [[TMP44:%.*]] = add i64 [[TMP43]], 32 -+; CHECK-NEXT: [[SCEVGEP28:%.*]] = getelementptr i8, ptr [[A]], i64 [[TMP44]] -+; CHECK-NEXT: [[TMP45:%.*]] = add nuw nsw i64 [[TMP43]], 4 -+; CHECK-NEXT: [[SCEVGEP29:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP45]] -+; CHECK-NEXT: [[TMP46:%.*]] = shl i64 [[TMP42]], 4 -+; CHECK-NEXT: [[TMP47:%.*]] = add nuw nsw i64 [[TMP46]], 8 -+; CHECK-NEXT: [[SCEVGEP30:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP47]] -+; CHECK-NEXT: [[BOUND0:%.*]] = icmp ult ptr [[A]], [[SCEVGEP29]] -+; CHECK-NEXT: [[BOUND1:%.*]] = icmp ult ptr [[B]], [[SCEVGEP28]] -+; CHECK-NEXT: [[FOUND_CONFLICT:%.*]] = and i1 [[BOUND0]], [[BOUND1]] -+; CHECK-NEXT: [[BOUND031:%.*]] = icmp ult ptr [[A]], [[SCEVGEP30]] -+; CHECK-NEXT: [[BOUND132:%.*]] = icmp ult ptr [[B]], [[SCEVGEP28]] -+; CHECK-NEXT: [[FOUND_CONFLICT33:%.*]] = and i1 [[BOUND031]], [[BOUND132]] -+; CHECK-NEXT: [[CONFLICT_RDX:%.*]] = or i1 [[FOUND_CONFLICT]], [[FOUND_CONFLICT33]] -+; CHECK-NEXT: br i1 [[CONFLICT_RDX]], label %[[SCALAR_PH]], label %[[VECTOR_PH:.*]] -+; CHECK: [[VECTOR_PH]]: -+; CHECK-NEXT: [[N_MOD_VF:%.*]] = urem i64 [[TMP1]], 4 -+; CHECK-NEXT: [[TMP48:%.*]] = icmp eq i64 [[N_MOD_VF]], 0 -+; CHECK-NEXT: [[TMP49:%.*]] = select i1 [[TMP48]], i64 4, i64 [[N_MOD_VF]] -+; CHECK-NEXT: [[N_VEC:%.*]] = sub i64 [[TMP1]], [[TMP49]] -+; CHECK-NEXT: [[IND_END:%.*]] = mul i64 [[N_VEC]], 8 -+; CHECK-NEXT: br label %[[VECTOR_BODY:.*]] -+; CHECK: [[VECTOR_BODY]]: -+; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ] -+; CHECK-NEXT: [[VEC_IND:%.*]] = phi <4 x i64> [ , %[[VECTOR_PH]] ], [ [[VEC_IND_NEXT:%.*]], %[[VECTOR_BODY]] ] -+; CHECK-NEXT: [[OFFSET_IDX:%.*]] = mul i64 [[INDEX]], 8 -+; CHECK-NEXT: [[TMP50:%.*]] = add i64 [[OFFSET_IDX]], 0 -+; CHECK-NEXT: [[TMP51:%.*]] = lshr exact i64 [[TMP50]], 1 -+; CHECK-NEXT: [[TMP52:%.*]] = getelementptr i32, ptr [[B]], i64 [[TMP51]] -+; CHECK-NEXT: [[TMP53:%.*]] = getelementptr i32, ptr [[TMP52]], i32 0 -+; CHECK-NEXT: [[WIDE_VEC:%.*]] = load <16 x i32>, ptr [[TMP53]], align 4 -+; CHECK-NEXT: [[STRIDED_VEC:%.*]] = shufflevector <16 x i32> [[WIDE_VEC]], <16 x i32> poison, <4 x i32> -+; CHECK-NEXT: [[STRIDED_VEC34:%.*]] = shufflevector <16 x i32> [[WIDE_VEC]], <16 x i32> poison, <4 x i32> -+; CHECK-NEXT: [[TMP54:%.*]] = getelementptr i32, ptr [[B]], <4 x i64> [[VEC_IND]] -+; CHECK-NEXT: [[WIDE_MASKED_GATHER:%.*]] = call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> [[TMP54]], i32 4, <4 x i1> , <4 x i32> poison), !alias.scope [[META6:![0-9]+]] -+; CHECK-NEXT: [[TMP55:%.*]] = or disjoint i64 [[TMP50]], 7 -+; CHECK-NEXT: [[TMP56:%.*]] = getelementptr i32, ptr [[A]], i64 [[TMP55]] -+; CHECK-NEXT: [[TMP57:%.*]] = getelementptr i32, ptr [[TMP56]], i32 -7 -+; CHECK-NEXT: [[TMP58:%.*]] = shufflevector <4 x i32> [[STRIDED_VEC]], <4 x i32> zeroinitializer, <8 x i32> -+; CHECK-NEXT: [[TMP59:%.*]] = shufflevector <4 x i32> [[STRIDED_VEC34]], <4 x i32> zeroinitializer, <8 x i32> -+; CHECK-NEXT: [[TMP60:%.*]] = shufflevector <4 x i32> [[WIDE_MASKED_GATHER]], <4 x i32> zeroinitializer, <8 x i32> -+; CHECK-NEXT: [[TMP61:%.*]] = shufflevector <8 x i32> [[TMP58]], <8 x i32> [[TMP59]], <16 x i32> -+; CHECK-NEXT: [[TMP62:%.*]] = shufflevector <8 x i32> [[TMP60]], <8 x i32> zeroinitializer, <16 x i32> -+; CHECK-NEXT: [[TMP63:%.*]] = shufflevector <16 x i32> [[TMP61]], <16 x i32> [[TMP62]], <32 x i32> -+; CHECK-NEXT: [[INTERLEAVED_VEC:%.*]] = shufflevector <32 x i32> [[TMP63]], <32 x i32> poison, <32 x i32> -+; CHECK-NEXT: store <32 x i32> [[INTERLEAVED_VEC]], ptr [[TMP57]], align 4 -+; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4 -+; CHECK-NEXT: [[VEC_IND_NEXT]] = add <4 x i64> [[VEC_IND]], -+; CHECK-NEXT: [[TMP64:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]] -+; CHECK-NEXT: br i1 [[TMP64]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP9:![0-9]+]] -+; CHECK: [[MIDDLE_BLOCK]]: -+; CHECK-NEXT: br label %[[SCALAR_PH]] -+; CHECK: [[SCALAR_PH]]: -+; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i64 [ [[IND_END]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ], [ 0, %[[VECTOR_SCEVCHECK]] ], [ 0, %[[VECTOR_MEMCHECK]] ] -+; CHECK-NEXT: br label %[[LOOP:.*]] -+; CHECK: [[LOOP]]: -+; CHECK-NEXT: [[IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[IV_NEXT_7:%.*]], %[[LOOP]] ] -+; CHECK-NEXT: [[SHR_1:%.*]] = lshr exact i64 [[IV]], 1 -+; CHECK-NEXT: [[GEP_B:%.*]] = getelementptr nusw i32, ptr [[B]], i64 [[SHR_1]] -+; CHECK-NEXT: [[L:%.*]] = load i32, ptr [[GEP_B]], align 4 -+; CHECK-NEXT: [[GEP_A:%.*]] = getelementptr i32, ptr [[A]], i64 [[IV]] -+; CHECK-NEXT: store i32 [[L]], ptr [[GEP_A]], align 4 -+; CHECK-NEXT: [[IV_NEXT:%.*]] = or disjoint i64 [[IV]], 1 -+; CHECK-NEXT: [[GEP_A_1:%.*]] = getelementptr i32, ptr [[A]], i64 [[IV_NEXT]] -+; CHECK-NEXT: store i32 0, ptr [[GEP_A_1]], align 4 -+; CHECK-NEXT: [[IV_NEXT_1:%.*]] = or disjoint i64 [[IV]], 2 -+; CHECK-NEXT: [[SHR_2:%.*]] = lshr exact i64 [[IV_NEXT_1]], 1 -+; CHECK-NEXT: [[GEP_B_2:%.*]] = getelementptr i32, ptr [[B]], i64 [[SHR_2]] -+; CHECK-NEXT: [[TMP65:%.*]] = load i32, ptr [[GEP_B_2]], align 4 -+; CHECK-NEXT: [[GEP_A_2:%.*]] = getelementptr i32, ptr [[A]], i64 [[IV_NEXT_1]] -+; CHECK-NEXT: store i32 [[TMP65]], ptr [[GEP_A_2]], align 4 -+; CHECK-NEXT: [[IV_NEXT_2:%.*]] = or disjoint i64 [[IV]], 3 -+; CHECK-NEXT: [[GEP_A_3:%.*]] = getelementptr i32, ptr [[A]], i64 [[IV_NEXT_2]] -+; CHECK-NEXT: store i32 0, ptr [[GEP_A_3]], align 4 -+; CHECK-NEXT: [[IV_NEXT_3:%.*]] = or disjoint i64 [[IV]], 4 -+; CHECK-NEXT: [[GEP_B_4:%.*]] = getelementptr i32, ptr [[B]], i64 [[IV]] -+; CHECK-NEXT: [[TMP66:%.*]] = load i32, ptr [[GEP_B_4]], align 4 -+; CHECK-NEXT: [[GEP_A_4:%.*]] = getelementptr i32, ptr [[A]], i64 [[IV_NEXT_3]] -+; CHECK-NEXT: store i32 [[TMP66]], ptr [[GEP_A_4]], align 4 -+; CHECK-NEXT: [[IV_NEXT_4:%.*]] = or disjoint i64 [[IV]], 5 -+; CHECK-NEXT: [[GEP_A_5:%.*]] = getelementptr i32, ptr [[A]], i64 [[IV_NEXT_4]] -+; CHECK-NEXT: store i32 0, ptr [[GEP_A_5]], align 4 -+; CHECK-NEXT: [[IV_NEXT_5:%.*]] = or disjoint i64 [[IV]], 6 -+; CHECK-NEXT: [[GEP_A_6:%.*]] = getelementptr i32, ptr [[A]], i64 [[IV_NEXT_5]] -+; CHECK-NEXT: store i32 0, ptr [[GEP_A_6]], align 4 -+; CHECK-NEXT: [[IV_NEXT_6:%.*]] = or disjoint i64 [[IV]], 7 -+; CHECK-NEXT: [[GEP_A_7:%.*]] = getelementptr i32, ptr [[A]], i64 [[IV_NEXT_6]] -+; CHECK-NEXT: store i32 0, ptr [[GEP_A_7]], align 4 -+; CHECK-NEXT: [[IV_NEXT_7]] = add nuw nsw i64 [[IV]], 8 -+; CHECK-NEXT: [[EC:%.*]] = icmp eq i64 [[IV]], [[N]] -+; CHECK-NEXT: br i1 [[EC]], label %[[EXIT:.*]], label %[[LOOP]], !llvm.loop [[LOOP10:![0-9]+]] -+; CHECK: [[EXIT]]: -+; CHECK-NEXT: ret void -+; -+entry: -+ br label %loop -+ -+loop: -+ %iv = phi i64 [ 0, %entry ], [ %iv.next.7, %loop ] -+ %shr.1 = lshr exact i64 %iv, 1 -+ %gep.B = getelementptr nusw i32, ptr %B, i64 %shr.1 -+ %l = load i32, ptr %gep.B, align 4 -+ %gep.A = getelementptr i32, ptr %A, i64 %iv -+ store i32 %l, ptr %gep.A, align 4 -+ %iv.next = or disjoint i64 %iv, 1 -+ %gep.A.1 = getelementptr i32, ptr %A, i64 %iv.next -+ store i32 0, ptr %gep.A.1, align 4 -+ %iv.next.1 = or disjoint i64 %iv, 2 -+ %shr.2 = lshr exact i64 %iv.next.1, 1 -+ %gep.B.2 = getelementptr i32, ptr %B, i64 %shr.2 -+ %1 = load i32, ptr %gep.B.2, align 4 -+ %gep.A.2 = getelementptr i32, ptr %A, i64 %iv.next.1 -+ store i32 %1, ptr %gep.A.2, align 4 -+ %iv.next.2 = or disjoint i64 %iv, 3 -+ %gep.A.3 = getelementptr i32, ptr %A, i64 %iv.next.2 -+ store i32 0, ptr %gep.A.3, align 4 -+ %iv.next.3 = or disjoint i64 %iv, 4 -+ %gep.B.4 = getelementptr i32, ptr %B, i64 %iv -+ %2 = load i32, ptr %gep.B.4, align 4 -+ %gep.A.4 = getelementptr i32, ptr %A, i64 %iv.next.3 -+ store i32 %2, ptr %gep.A.4, align 4 -+ %iv.next.4 = or disjoint i64 %iv, 5 -+ %gep.A.5 = getelementptr i32, ptr %A, i64 %iv.next.4 -+ store i32 0, ptr %gep.A.5, align 4 -+ %iv.next.5 = or disjoint i64 %iv, 6 -+ %gep.A.6 = getelementptr i32, ptr %A, i64 %iv.next.5 -+ store i32 0, ptr %gep.A.6, align 4 -+ %iv.next.6 = or disjoint i64 %iv, 7 -+ %gep.A.7 = getelementptr i32, ptr %A, i64 %iv.next.6 -+ store i32 0, ptr %gep.A.7, align 4 -+ %iv.next.7 = add nuw nsw i64 %iv, 8 -+ %ec = icmp eq i64 %iv, %N -+ br i1 %ec, label %exit, label %loop -+ -+exit: -+ ret void -+} -+ -+attributes #0 = { "target-features"="+sse4.2" } -+attributes #1 = { "min-legal-vector-width"="0" "target-cpu"="cascadelake" } -+ - ;. - ; CHECK: [[LOOP0]] = distinct !{[[LOOP0]], [[META1:![0-9]+]], [[META2:![0-9]+]]} - ; CHECK: [[META1]] = !{!"llvm.loop.isvectorized", i32 1} - ; CHECK: [[META2]] = !{!"llvm.loop.unroll.runtime.disable"} - ; CHECK: [[LOOP3]] = distinct !{[[LOOP3]], [[META1]]} -+; CHECK: [[LOOP4]] = distinct !{[[LOOP4]], [[META1]], [[META2]]} -+; CHECK: [[LOOP5]] = distinct !{[[LOOP5]], [[META1]]} -+; CHECK: [[META6]] = !{[[META7:![0-9]+]]} -+; CHECK: [[META7]] = distinct !{[[META7]], [[META8:![0-9]+]]} -+; CHECK: [[META8]] = distinct !{[[META8]], !"LVerDomain"} -+; CHECK: [[LOOP9]] = distinct !{[[LOOP9]], [[META1]], [[META2]]} -+; CHECK: [[LOOP10]] = distinct !{[[LOOP10]], [[META1]]} - ;. -diff -ruN --strip-trailing-cr a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp ---- a/llvm/unittests/IR/PatternMatch.cpp -+++ b/llvm/unittests/IR/PatternMatch.cpp -@@ -2235,7 +2235,7 @@ - MutableConstTestTypes; - TYPED_TEST_SUITE(MutableConstTest, MutableConstTestTypes, ); +@@ -123,15 +113,6 @@ + return Stack.back().get(); + } --TYPED_TEST(MutableConstTest, ICmp) { -+TYPED_TEST(MutableConstTest, /* FIXME: UAR bug */ DISABLED_ICmp) { - auto &IRB = PatternMatchTest::IRB; +- TimeTraceProfilerEntry * +- begin(std::string Name, llvm::function_ref Metadata, +- bool AsyncEvent = false) { +- Stack.emplace_back(std::make_unique( +- ClockType::now(), TimePointType(), std::move(Name), Metadata(), +- AsyncEvent)); +- return Stack.back().get(); +- } +- + void end() { + assert(!Stack.empty() && "Must call begin() first"); + end(*Stack.back()); +@@ -203,15 +184,8 @@ + J.attribute("dur", DurUs); + } + J.attribute("name", E.Name); +- if (!E.Metadata.isEmpty()) { +- J.attributeObject("args", [&] { +- if (!E.Metadata.Detail.empty()) +- J.attribute("detail", E.Metadata.Detail); +- if (!E.Metadata.File.empty()) +- J.attribute("file", E.Metadata.File); +- if (E.Metadata.Line > 0) +- J.attribute("line", E.Metadata.Line); +- }); ++ if (!E.Detail.empty()) { ++ J.attributeObject("args", [&] { J.attribute("detail", E.Detail); }); + } + }); + +@@ -333,25 +307,14 @@ + + // Minimum time granularity (in microseconds) + const unsigned TimeTraceGranularity; +- +- // Make time trace capture verbose event details (e.g. source filenames). This +- // can increase the size of the output by 2-3 times. +- const bool TimeTraceVerbose; + }; - typedef std::tuple_element_t<0, TypeParam> ValueType; -@@ -2319,7 +2319,7 @@ - .match((InstructionType)IRB.CreateICmp(Pred, L, R))); +-bool llvm::isTimeTraceVerbose() { +- return getTimeTraceProfilerInstance() && +- getTimeTraceProfilerInstance()->TimeTraceVerbose; +-} +- + void llvm::timeTraceProfilerInitialize(unsigned TimeTraceGranularity, +- StringRef ProcName, +- bool TimeTraceVerbose) { ++ StringRef ProcName) { + assert(TimeTraceProfilerInstance == nullptr && + "Profiler should not be initialized"); + TimeTraceProfilerInstance = new TimeTraceProfiler( +- TimeTraceGranularity, llvm::sys::path::filename(ProcName), +- TimeTraceVerbose); ++ TimeTraceGranularity, llvm::sys::path::filename(ProcName)); } --TYPED_TEST(MutableConstTest, FCmp) { -+TYPED_TEST(MutableConstTest, /* FIXME: UAR bug */ DISABLED_FCmp) { - auto &IRB = PatternMatchTest::IRB; - - typedef std::tuple_element_t<0, TypeParam> ValueType; -diff -ruN --strip-trailing-cr a/llvm/utils/gn/secondary/clang/lib/Basic/BUILD.gn b/llvm/utils/gn/secondary/clang/lib/Basic/BUILD.gn ---- a/llvm/utils/gn/secondary/clang/lib/Basic/BUILD.gn -+++ b/llvm/utils/gn/secondary/clang/lib/Basic/BUILD.gn -@@ -108,6 +108,7 @@ - "Targets/DirectX.cpp", - "Targets/Hexagon.cpp", - "Targets/Lanai.cpp", -+ "Targets/Le64.cpp", - "Targets/LoongArch.cpp", - "Targets/M68k.cpp", - "Targets/MSP430.cpp", -diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel ---- a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel -+++ b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel -@@ -254,6 +254,7 @@ - hdrs = ["src/__support/macros/optimization.h"], - deps = [ - ":__support_macros_attributes", -+ ":__support_macros_config", - ":__support_macros_properties_compiler", - ], - ) -@@ -261,6 +262,9 @@ - libc_support_library( - name = "__support_macros_sanitizer", - hdrs = ["src/__support/macros/sanitizer.h"], -+ deps = [ -+ ":__support_macros_config", -+ ], - ) - - libc_support_library( -@@ -271,6 +275,7 @@ - ], - deps = [ - ":__support_macros_attributes", -+ ":__support_macros_config", - ":__support_macros_properties_architectures", - ], - ) -@@ -280,6 +285,7 @@ - hdrs = ["src/__support/CPP/algorithm.h"], - deps = [ - ":__support_macros_attributes", -+ ":__support_macros_config", - ], - ) - -@@ -317,6 +323,7 @@ - hdrs = ["src/__support/CPP/bitset.h"], - deps = [ - ":__support_macros_attributes", -+ ":__support_macros_config", - ], - ) - -@@ -334,6 +341,7 @@ - hdrs = ["src/__support/CPP/expected.h"], - deps = [ - ":__support_macros_attributes", -+ ":__support_macros_config", - ], - ) - -@@ -424,6 +432,7 @@ - ], - deps = [ - ":__support_macros_attributes", -+ ":__support_macros_config", - ":__support_macros_properties_types", - ":llvm_libc_macros_stdfix_macros", - ], -@@ -573,7 +582,10 @@ - libc_support_library( - name = "__support_str_to_num_result", - hdrs = ["src/__support/str_to_num_result.h"], -- deps = [":__support_macros_attributes"], -+ deps = [ -+ ":__support_macros_attributes", -+ ":__support_macros_config", -+ ], - ) - - libc_support_library( -@@ -612,7 +624,10 @@ - libc_support_library( - name = "__support_ctype_utils", - hdrs = ["src/__support/ctype_utils.h"], -- deps = [":__support_macros_attributes"], -+ deps = [ -+ ":__support_macros_attributes", -+ ":__support_macros_config", -+ ], - ) - - libc_support_library( -@@ -785,6 +800,7 @@ - hdrs = ["src/__support/FPUtil/rounding_mode.h"], - deps = [ - ":__support_macros_attributes", -+ ":__support_macros_config", - ":hdr_fenv_macros", - ], - ) -@@ -1126,6 +1142,7 @@ - hdrs = ["src/__support/threads/sleep.h"], - deps = [ - ":__support_macros_attributes", -+ ":__support_macros_config", - ], - ) - -@@ -3408,9 +3425,9 @@ - ":__support_arg_list", - ":__support_file_file", - ":__support_macros_attributes", -- ":types_FILE", - ":printf_main", - ":printf_writer", -+ ":types_FILE", - ], - ) - -diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/libc_build_rules.bzl b/utils/bazel/llvm-project-overlay/libc/libc_build_rules.bzl ---- a/utils/bazel/llvm-project-overlay/libc/libc_build_rules.bzl -+++ b/utils/bazel/llvm-project-overlay/libc/libc_build_rules.bzl -@@ -43,7 +43,7 @@ - name = name, - copts = copts + libc_common_copts(), - local_defines = local_defines + LIBC_CONFIGURE_OPTIONS, -- deps = deps + ["//libc:__support_macros_config"], -+ deps = deps, - linkstatic = 1, - **kwargs - ) -diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/test/src/math/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/test/src/math/BUILD.bazel ---- a/utils/bazel/llvm-project-overlay/libc/test/src/math/BUILD.bazel -+++ b/utils/bazel/llvm-project-overlay/libc/test/src/math/BUILD.bazel -@@ -298,8 +298,8 @@ - "//libc:__support_fputil_fp_bits", - "//libc:__support_fputil_manipulation_functions", - "//libc:hdr_math_macros", -- "//libc/test/UnitTest:fp_test_helpers", - "//libc/test/UnitTest:LibcUnitTest", -+ "//libc/test/UnitTest:fp_test_helpers", - ], - ) - -@@ -559,7 +559,10 @@ - libc_support_library( - name = "sdcomp26094", - hdrs = ["sdcomp26094.h"], -- deps = ["//libc:__support_cpp_array"], -+ deps = [ -+ "//libc:__support_cpp_array", -+ "//libc:__support_macros_config", -+ ], - ) - - math_test( -diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/test/src/string/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/test/src/string/BUILD.bazel ---- a/utils/bazel/llvm-project-overlay/libc/test/src/string/BUILD.bazel -+++ b/utils/bazel/llvm-project-overlay/libc/test/src/string/BUILD.bazel -@@ -121,6 +121,7 @@ - deps = [ - "//libc:__support_cpp_span", - "//libc:__support_libc_assert", -+ "//libc:__support_macros_config", - "//libc:__support_macros_sanitizer", - "//libc:string_memory_utils", - ], -diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/test/UnitTest/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/test/UnitTest/BUILD.bazel ---- a/utils/bazel/llvm-project-overlay/libc/test/UnitTest/BUILD.bazel -+++ b/utils/bazel/llvm-project-overlay/libc/test/UnitTest/BUILD.bazel -@@ -18,6 +18,7 @@ - "//libc:__support_big_int", - "//libc:__support_cpp_string", - "//libc:__support_cpp_string_view", -+ "//libc:__support_macros_config", - "//libc:__support_macros_properties_types", - "//libc:__support_osutil_io", - "//libc:__support_uint128", -@@ -52,6 +53,7 @@ - "//libc:__support_fputil_fp_bits", - "//libc:__support_fputil_fpbits_str", - "//libc:__support_fputil_rounding_mode", -+ "//libc:__support_macros_config", - "//libc:__support_macros_properties_architectures", - "//libc:__support_macros_properties_types", - "//libc:__support_stringutil", -@@ -89,10 +91,11 @@ - "//libc:__support_fputil_fp_bits", - "//libc:__support_fputil_fpbits_str", - "//libc:__support_fputil_rounding_mode", -+ "//libc:__support_macros_config", - "//libc:__support_macros_properties_architectures", -+ "//libc:hdr_fenv_macros", - "//libc:hdr_math_macros", -- "//libc:hdr_fenv_macros", -- "//libc:types_fenv_t", -+ "//libc:types_fenv_t", - ], - ) - -@@ -110,6 +113,7 @@ - "//libc:__support_cpp_bitset", - "//libc:__support_cpp_span", - "//libc:__support_cpp_type_traits", -+ "//libc:__support_macros_config", - ], - ) + // Removes all TimeTraceProfilerInstances. +@@ -418,14 +381,6 @@ + return nullptr; + } -@@ -125,6 +129,7 @@ - ":LibcUnitTest", - ":string_utils", - "//libc:__support_fputil_fp_bits", -+ "//libc:__support_macros_config", - "//libc:printf_core_structs", - ], - ) -@@ -138,5 +143,6 @@ - "//libc:__support_big_int", - "//libc:__support_cpp_string", - "//libc:__support_cpp_type_traits", -+ "//libc:__support_macros_config", - ], - ) -diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/utils/MPFRWrapper/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/utils/MPFRWrapper/BUILD.bazel ---- a/utils/bazel/llvm-project-overlay/libc/utils/MPFRWrapper/BUILD.bazel -+++ b/utils/bazel/llvm-project-overlay/libc/utils/MPFRWrapper/BUILD.bazel -@@ -48,6 +48,7 @@ - "//libc:__support_cpp_type_traits", - "//libc:__support_fputil_fp_bits", - "//libc:__support_fputil_fpbits_str", -+ "//libc:__support_macros_config", - "//libc:__support_macros_properties_types", - "//libc:hdr_math_macros", - "//libc/test/UnitTest:LibcUnitTest", -diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel ---- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel -+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel -@@ -2900,6 +2900,7 @@ - ":IR", - ":LoopLikeInterface", - ":SCFDialect", -+ ":SCFToControlFlow", - ":SCFTransformOpsIncGen", - ":SCFTransforms", - ":SCFUtils", +-TimeTraceProfilerEntry * +-llvm::timeTraceProfilerBegin(StringRef Name, +- llvm::function_ref Metadata) { +- if (TimeTraceProfilerInstance != nullptr) +- return TimeTraceProfilerInstance->begin(std::string(Name), Metadata, false); +- return nullptr; +-} +- + TimeTraceProfilerEntry *llvm::timeTraceAsyncProfilerBegin(StringRef Name, + StringRef Detail) { + if (TimeTraceProfilerInstance != nullptr) diff --git a/third_party/tsl/third_party/llvm/workspace.bzl b/third_party/tsl/third_party/llvm/workspace.bzl index a108e965dd0086..2949b73a155af1 100644 --- a/third_party/tsl/third_party/llvm/workspace.bzl +++ b/third_party/tsl/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "acc159aea1e641e3694ab8fe5faa231788077011" - LLVM_SHA256 = "ff2d0c2d9dd22eb39b3d135bcf0cf91008b395de797f543e32790df372945d13" + LLVM_COMMIT = "84658fb82b67fc22ecba1560d0cddd09f9104178" + LLVM_SHA256 = "b4a50d36a8ab0284f7022f61bbf07a2fb3ea25c6bb2cc422d2418c23b61366da" tf_http_archive( name = name, From 84175078b2486bfb55f1a27947d3067f29dad651 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 24 Jul 2024 08:27:44 -0700 Subject: [PATCH 117/376] [XLA] Replace Reduce(Broadcast(Scalar)) with Broadcast(Multiply(Scalar)) when the reduction operation is addition PiperOrigin-RevId: 655576695 --- xla/service/algebraic_simplifier.cc | 23 ++++++++++++++++++++++ xla/service/algebraic_simplifier.h | 2 +- xla/service/algebraic_simplifier_test.cc | 25 ++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 1 deletion(-) diff --git a/xla/service/algebraic_simplifier.cc b/xla/service/algebraic_simplifier.cc index a60e3576a67944..fad9dcacaa4ab9 100644 --- a/xla/service/algebraic_simplifier.cc +++ b/xla/service/algebraic_simplifier.cc @@ -7881,6 +7881,29 @@ absl::Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { } } + // Replace Reduce(Broadcast(Scalar)) with Broadcast(Multiply(Scalar)) when the + // reduction operation is addition + if (arg->opcode() == HloOpcode::kBroadcast && + ShapeUtil::IsScalar(arg->operand(0)->shape())) { + if (Match(reduce->to_apply()->root_instruction(), + m::AddAnyOrder(m::Parameter(0), m::Parameter(1))) && + IsScalarConstantZero(init_value)) { + int64_t reduction_dims_prod = 1; + for (auto i : reduce->dimensions()) { + reduction_dims_prod *= arg->shape().dimensions(i); + } + + HloInstruction* multiplier = + MakeScalarLike(arg->mutable_operand(0), reduction_dims_prod); + TF_ASSIGN_OR_RETURN(HloInstruction * multiplied_scalar, + MakeBinaryHlo(HloOpcode::kMultiply, + arg->mutable_operand(0), multiplier)); + return ReplaceWithNewInstruction( + reduce, HloInstruction::CreateBroadcast(reduce->shape(), + multiplied_scalar, {})); + } + } + return absl::OkStatus(); } diff --git a/xla/service/algebraic_simplifier.h b/xla/service/algebraic_simplifier.h index 7261872d40ceda..bdd4f915a54ec9 100644 --- a/xla/service/algebraic_simplifier.h +++ b/xla/service/algebraic_simplifier.h @@ -289,7 +289,7 @@ class AlgebraicSimplifierOptions { private: // Metadata struct can be used to store any metadata information encapsulated - // with the AlgebraicSimplierOptions that can be later used in an + // with the AlgebraicSimplifierOptions that can be later used in an // AlgebraicSimplifier pass. For example, // cudnn_batchnorm_forward_training_metadata can be used to store the name of // a custom call. If the custom call is diff --git a/xla/service/algebraic_simplifier_test.cc b/xla/service/algebraic_simplifier_test.cc index 03bde1910e4892..5b8e4db491c13b 100644 --- a/xla/service/algebraic_simplifier_test.cc +++ b/xla/service/algebraic_simplifier_test.cc @@ -11754,5 +11754,30 @@ TEST_F(AlgebraicSimplifierTest, KeepInt4ConvertConstant) { ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); } +TEST_F(AlgebraicSimplifierTest, ReduceBroadcastScalarToBroadcastMultiply) { + const std::string hlo_string = R"( + HloModule module + add_bf16 { + x = bf16[] parameter(0) + y = bf16[] parameter(1) + ROOT sum = bf16[] add(x, y) + } + + ENTRY test { + a = bf16[] parameter(0) + negate = bf16[] negate(a) + broadcast = bf16[2,5,11,17,19] broadcast(negate), dimensions={} + zero = bf16[] constant(0) + ROOT reduce = bf16[2] reduce(broadcast, zero), + dimensions={1,2,3,4}, to_apply=add_bf16 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + EXPECT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + HloInstruction* root = m->entry_computation()->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); + EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kMultiply); +} + } // namespace } // namespace xla From 6d9074d45e532607ab0ff23f7f800e469b5bd5e3 Mon Sep 17 00:00:00 2001 From: Goran Flegar Date: Wed, 24 Jul 2024 09:18:31 -0700 Subject: [PATCH 118/376] Move the LLVM integration patch to sparsity patches It is only impacting sparsity, and not Triton itself, so we should not be upstreaming it to Triton. Rather, we should just figure out how to properly fix it in our sparsity extension. PiperOrigin-RevId: 655591896 --- .../triton/llvm_integration/cl654795065.patch | 15 ------ .../triton/llvm_integration/series.bzl | 1 - third_party/triton/xla_extensions/series.bzl | 1 + .../xla_extensions/sparsity_layout.patch | 51 +++++++++++++++++++ 4 files changed, 52 insertions(+), 16 deletions(-) delete mode 100644 third_party/triton/llvm_integration/cl654795065.patch create mode 100644 third_party/triton/xla_extensions/sparsity_layout.patch diff --git a/third_party/triton/llvm_integration/cl654795065.patch b/third_party/triton/llvm_integration/cl654795065.patch deleted file mode 100644 index 19ac00d2cdb637..00000000000000 --- a/third_party/triton/llvm_integration/cl654795065.patch +++ /dev/null @@ -1,15 +0,0 @@ -diff --git a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp ---- a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp -+++ b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp -@@ -57,8 +57,9 @@ TritonGPUTypeConverter::TritonGPUTypeCon - addArgumentMaterialization([&](OpBuilder &builder, - RankedTensorType tensorType, ValueRange inputs, - Location loc) -> std::optional { -- llvm_unreachable("Argument rematerialization should not happen in Triton " -- "-> TritonGPU conversion"); -+ // TODO(b/354860562): reenable or remove. -+ // llvm_unreachable("Argument rematerialization should not happen in Triton " -+ // "-> TritonGPU conversion"); - return std::nullopt; - }); - diff --git a/third_party/triton/llvm_integration/series.bzl b/third_party/triton/llvm_integration/series.bzl index 7b438990166a30..656b9c894904d8 100644 --- a/third_party/triton/llvm_integration/series.bzl +++ b/third_party/triton/llvm_integration/series.bzl @@ -8,6 +8,5 @@ LLVM nor MLIR integrator, please do not add any patches to this list. """ llvm_patch_list = [ - "//third_party/triton/llvm_integration:cl654795065.patch", # Add new patches just above this line ] diff --git a/third_party/triton/xla_extensions/series.bzl b/third_party/triton/xla_extensions/series.bzl index 757c2b95a1be4a..19ba85b57b3672 100644 --- a/third_party/triton/xla_extensions/series.bzl +++ b/third_party/triton/xla_extensions/series.bzl @@ -8,5 +8,6 @@ IMPORTANT: This is a temporary hack while we are figuring out the proper way to extensions_files_patch_list = [ "//third_party/triton/xla_extensions:sparse_dot.patch", # Sparsity internal patch + "//third_party/triton/xla_extensions:sparsity_layout.patch", # Sparsity internal patch # Add new patches just above this line ] diff --git a/third_party/triton/xla_extensions/sparsity_layout.patch b/third_party/triton/xla_extensions/sparsity_layout.patch new file mode 100644 index 00000000000000..b64ddbdbdab683 --- /dev/null +++ b/third_party/triton/xla_extensions/sparsity_layout.patch @@ -0,0 +1,51 @@ +diff --git a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp +index 34fb89954..a0172e107 100644 +--- a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp ++++ b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp +@@ -57,8 +57,9 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, + addArgumentMaterialization([&](OpBuilder &builder, + RankedTensorType tensorType, ValueRange inputs, + Location loc) -> std::optional { +- llvm_unreachable("Argument rematerialization should not happen in Triton " +- "-> TritonGPU conversion"); ++ // TODO(b/354860562): reenable or remove. ++ // llvm_unreachable("Argument rematerialization should not happen in Triton " ++ // "-> TritonGPU conversion"); + return std::nullopt; + }); + +@@ -67,6 +68,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, + addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, + ValueRange inputs, + Location loc) -> std::optional { ++ // Allows partial TTIR to TTGIR conversion by materializing a conversion for ++ // remaining uses of values that have been converted to a new type. ++ // We use this to rewrite triton_gpu.sparse_dot in a separate pass after ++ // 'convert-triton-to-tritongpu'. ++ return builder.create(loc, tensorType, ++ inputs); + llvm_unreachable("Source rematerialization should not happen in Triton -> " + "TritonGPU Conversion"); + return std::nullopt; +diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp +index df3d3b042..e38c184f6 100644 +--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp ++++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp +@@ -2867,13 +2867,13 @@ struct CanonicalizeConvertFromConvert + // heuristic to accommodate fused attention. + auto srcType = op.getSrc().getType(); + auto dstType = op.getType(); +- if (mlir::isa(dstType.getEncoding()) && +- mlir::isa(srcType.getEncoding())) ++ if (mlir::isa_and_nonnull(dstType.getEncoding()) && ++ mlir::isa_and_nonnull(srcType.getEncoding())) + return failure(); + + // for hopper MMAv3 +- if (mlir::isa(dstType.getEncoding()) && +- mlir::isa(srcType.getEncoding()) && ++ if (mlir::isa_and_nonnull(dstType.getEncoding()) && ++ mlir::isa_and_nonnull(srcType.getEncoding()) && + llvm::any_of(op.getResult().getUsers(), [](Operation *dot) { + return dot->hasTrait(); + })) { From b7727532fbd08bd198902bd909cce9f67bc64cb8 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Wed, 24 Jul 2024 10:55:12 -0700 Subject: [PATCH 119/376] Update `@rules_python` to 0.34.0 after LLVM integrate Should fix XLA/JAX CI PiperOrigin-RevId: 655626888 --- third_party/py/python_init_rules.bzl | 6 +++--- third_party/tsl/third_party/py/python_init_rules.bzl | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/third_party/py/python_init_rules.bzl b/third_party/py/python_init_rules.bzl index 98a7b8bc3c315a..4e1473e5342c4a 100644 --- a/third_party/py/python_init_rules.bzl +++ b/third_party/py/python_init_rules.bzl @@ -5,7 +5,7 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") def python_init_rules(): http_archive( name = "rules_python", - sha256 = "9d04041ac92a0985e344235f5d946f71ac543f1b1565f2cdbc9a2aaee8adf55b", - strip_prefix = "rules_python-0.26.0", - url = "https://github.com/bazelbuild/rules_python/releases/download/0.26.0/rules_python-0.26.0.tar.gz", + sha256 = "778aaeab3e6cfd56d681c89f5c10d7ad6bf8d2f1a72de9de55b23081b2d31618", + strip_prefix = "rules_python-0.34.0", + url = "https://github.com/bazelbuild/rules_python/releases/download/0.34.0/rules_python-0.34.0.tar.gz", ) diff --git a/third_party/tsl/third_party/py/python_init_rules.bzl b/third_party/tsl/third_party/py/python_init_rules.bzl index 98a7b8bc3c315a..4e1473e5342c4a 100644 --- a/third_party/tsl/third_party/py/python_init_rules.bzl +++ b/third_party/tsl/third_party/py/python_init_rules.bzl @@ -5,7 +5,7 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") def python_init_rules(): http_archive( name = "rules_python", - sha256 = "9d04041ac92a0985e344235f5d946f71ac543f1b1565f2cdbc9a2aaee8adf55b", - strip_prefix = "rules_python-0.26.0", - url = "https://github.com/bazelbuild/rules_python/releases/download/0.26.0/rules_python-0.26.0.tar.gz", + sha256 = "778aaeab3e6cfd56d681c89f5c10d7ad6bf8d2f1a72de9de55b23081b2d31618", + strip_prefix = "rules_python-0.34.0", + url = "https://github.com/bazelbuild/rules_python/releases/download/0.34.0/rules_python-0.34.0.tar.gz", ) From 099348bee8da99bf6570bdd9793f02e13a4bdc59 Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Wed, 24 Jul 2024 11:40:55 -0700 Subject: [PATCH 120/376] [xla] Add use_shardy_partitioner as a field in ExecutableOptions. Add a test. PiperOrigin-RevId: 655644662 --- xla/client/executable_build_options.cc | 4 ++++ xla/client/executable_build_options.h | 8 ++++++++ xla/hlo/ir/hlo_module.cc | 2 ++ xla/pjrt/compile_options.proto | 7 ++++++- xla/service/hlo_module_config.cc | 3 +++ xla/service/hlo_module_config.h | 7 +++++++ xla/service/hlo_module_util.cc | 2 ++ xla/tests/local_client_execute_test.cc | 23 +++++++++++++++++++++++ xla/xla.proto | 10 ++++++++-- 9 files changed, 63 insertions(+), 3 deletions(-) diff --git a/xla/client/executable_build_options.cc b/xla/client/executable_build_options.cc index 77c7791d151bbe..46b810d5537a3a 100644 --- a/xla/client/executable_build_options.cc +++ b/xla/client/executable_build_options.cc @@ -196,6 +196,7 @@ absl::StatusOr ExecutableBuildOptions::ToProto() for (int64_t s : auto_spmd_partitioning_mesh_ids()) { output.mutable_auto_spmd_partitioning_mesh_ids()->Add(s); } + output.set_use_shardy_partitioner(use_shardy_partitioner()); return output; } @@ -242,6 +243,7 @@ absl::StatusOr ExecutableBuildOptionsFromProto( output.set_auto_spmd_partitioning_mesh_ids( std::vector(input.auto_spmd_partitioning_mesh_ids().begin(), input.auto_spmd_partitioning_mesh_ids().end())); + output.set_use_shardy_partitioner(input.use_shardy_partitioner()); return output; } @@ -300,6 +302,8 @@ ExecutionOptions CreateExecutionOptions( execution_options.set_fdo_profile(build_options.fdo_profile().data(), build_options.fdo_profile().size()); execution_options.set_device_memory_size(build_options.device_memory_size()); + execution_options.set_use_shardy_partitioner( + build_options.use_shardy_partitioner()); return execution_options; } diff --git a/xla/client/executable_build_options.h b/xla/client/executable_build_options.h index e9c3bdec4c6694..c849230ecad082 100644 --- a/xla/client/executable_build_options.h +++ b/xla/client/executable_build_options.h @@ -231,6 +231,13 @@ class ExecutableBuildOptions { return *this; } + bool use_shardy_partitioner() const { return use_shardy_partitioner_; } + ExecutableBuildOptions& set_use_shardy_partitioner( + bool use_shardy_partitioner) { + use_shardy_partitioner_ = use_shardy_partitioner; + return *this; + } + // Returns a string representation of the build options, suitable for // debugging. std::string ToString() const; @@ -279,6 +286,7 @@ class ExecutableBuildOptions { LayoutCanonicalizationCallback layout_canonicalization_callback_; std::string fdo_profile_; int64_t device_memory_size_ = 0; + bool use_shardy_partitioner_ = false; int process_index_ = 0; int process_count_ = 1; std::shared_ptr key_value_store_; diff --git a/xla/hlo/ir/hlo_module.cc b/xla/hlo/ir/hlo_module.cc index efa2ac6c5fa0f4..0711d49ef63e16 100644 --- a/xla/hlo/ir/hlo_module.cc +++ b/xla/hlo/ir/hlo_module.cc @@ -739,6 +739,8 @@ absl::StatusOr HloModule::CreateModuleConfigFromShape( execution_options->allow_separate_sharding_programs()); HloModuleConfig::AssignStructShardableValueUpdatePairs( module_config, execution_options->shardable_value_update_pairs()); + module_config.set_use_shardy_partitioner( + execution_options->use_shardy_partitioner()); } // The module config is constructed with default layouts regardless of what is diff --git a/xla/pjrt/compile_options.proto b/xla/pjrt/compile_options.proto index 4ea4af933e9367..bd23ca73f6244c 100644 --- a/xla/pjrt/compile_options.proto +++ b/xla/pjrt/compile_options.proto @@ -7,7 +7,7 @@ import "xla/xla.proto"; import "xla/xla_data.proto"; // A serialization of xla::ExecutableBuildOptions. -// Next id: 19. +// Next id: 20. message ExecutableBuildOptionsProto { // If set, this is the device to build the computation for. Valid // device_ordinal values are: 0 to # of devices - 1. These values are @@ -102,6 +102,11 @@ message ExecutableBuildOptionsProto { // Mesh ids in auto sharding options. repeated int64 auto_spmd_partitioning_mesh_ids = 17; + + // Use Shardy, a new partitioner, to replace the existing + // ShardingPropagation and SpmdPartitioner. See go/xla-sdy-pipeline for + // details. + bool use_shardy_partitioner = 19; } message OptionOverrideProto { diff --git a/xla/service/hlo_module_config.cc b/xla/service/hlo_module_config.cc index 1e970e0e907071..a5400c866c63d0 100644 --- a/xla/service/hlo_module_config.cc +++ b/xla/service/hlo_module_config.cc @@ -97,6 +97,7 @@ std::string HloModuleConfig::compilation_cache_key() const { if (device_memory_size() != 0) { StrAppend(&key, "::device_memory_size=", device_memory_size()); } + StrAppend(&key, "::use_shardy_partitioner=", use_shardy_partitioner()); return key; } @@ -321,6 +322,7 @@ HloModuleConfigProto HloModuleConfig::ToProto() const { proto.set_allow_separate_sharding_programs(allow_separate_sharding_programs_); proto.set_fdo_profile(fdo_profile_); proto.set_device_memory_size(device_memory_size_); + proto.set_use_shardy_partitioner(use_shardy_partitioner_); return proto; } @@ -390,6 +392,7 @@ HloModuleConfig::CreateFromProto(const HloModuleConfigProto& proto) { proto.allow_separate_sharding_programs(); config->fdo_profile_ = proto.fdo_profile(); config->device_memory_size_ = proto.device_memory_size(); + config->use_shardy_partitioner_ = proto.use_shardy_partitioner(); return std::move(config); } diff --git a/xla/service/hlo_module_config.h b/xla/service/hlo_module_config.h index 05ce603c36b034..a428c9bccba7a0 100644 --- a/xla/service/hlo_module_config.h +++ b/xla/service/hlo_module_config.h @@ -375,6 +375,11 @@ class HloModuleConfig { device_memory_size_ = device_memory_size; } + bool use_shardy_partitioner() const { return use_shardy_partitioner_; } + void set_use_shardy_partitioner(bool use_shardy_partitioner) { + use_shardy_partitioner_ = use_shardy_partitioner; + } + private: // If you add new members, be sure to update compilation_cache_key and the // HloModuleConfigProto. @@ -501,6 +506,8 @@ class HloModuleConfig { std::string fdo_profile_; int64_t device_memory_size_ = 0; + + bool use_shardy_partitioner_ = false; // LINT.ThenChange(//tensorflow/compiler/xla/xla.proto) }; diff --git a/xla/service/hlo_module_util.cc b/xla/service/hlo_module_util.cc index 87bcdd63ada57e..ca67634cd1e14d 100644 --- a/xla/service/hlo_module_util.cc +++ b/xla/service/hlo_module_util.cc @@ -132,6 +132,8 @@ absl::StatusOr> CreateModuleConfig( execution_options->alias_passthrough_params()); *config->mutable_fdo_profile() = execution_options->fdo_profile(); config->set_device_memory_size(execution_options->device_memory_size()); + config->set_use_shardy_partitioner( + execution_options->use_shardy_partitioner()); } else { config->set_replica_count(default_num_replicas); config->set_debug_options(GetDebugOptionsFromFlags()); diff --git a/xla/tests/local_client_execute_test.cc b/xla/tests/local_client_execute_test.cc index 259d2ef7363279..b4e4a167a4d07d 100644 --- a/xla/tests/local_client_execute_test.cc +++ b/xla/tests/local_client_execute_test.cc @@ -995,6 +995,29 @@ XLA_TEST_F(LocalClientExecuteTest, ValidateDeviceMemorySize) { EXPECT_EQ(proto.config().device_memory_size(), kDeviceMemorySize); } +XLA_TEST_F(LocalClientExecuteTest, ValidateUseShardyPartitioner) { + XlaBuilder builder(TestName()); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x"); + auto y = ConstantR1(&builder, {2.0f, 3.0f, 4.0f}); + Add(x, y); + Shape argument_layout = + local_client_->backend().compiler()->DefaultDeviceShapeRepresentation( + ShapeUtil::MakeShapeWithDenseLayout(F32, /*dimensions=*/{3}, {0})); + + ExecutableBuildOptions build_options; + build_options.set_use_shardy_partitioner(true); + TF_ASSERT_OK_AND_ASSIGN( + auto executables, + local_client_->Compile(builder.Build().value(), {&argument_layout}, + build_options)); + EXPECT_EQ(1, executables.size()); + const HloModule& compiled_module = + executables.front()->executable()->module(); + EXPECT_EQ(compiled_module.config().use_shardy_partitioner(), true); + auto proto = compiled_module.ToProtoWithConfig(); + EXPECT_EQ(proto.config().use_shardy_partitioner(), true); +} + BENCHMARK(BM_LocalClientOverhead); } // namespace diff --git a/xla/xla.proto b/xla/xla.proto index dc232e2941edf2..ad53966df160da 100644 --- a/xla/xla.proto +++ b/xla/xla.proto @@ -911,7 +911,7 @@ message ShardableValueUpdatePairProto { // will have an effect on every platform. // // When adding new fields, keep in mind that boolean fields default to false. -// Next id: 24. +// Next id: 25. message ExecutionOptions { // This optional field's layout is used as a hint when storing the output of // this computation. Subsequent transfers of this output array to the client @@ -1018,12 +1018,17 @@ message ExecutionOptions { // Amount of device memory available for the executable to use. int64 device_memory_size = 22; + + // Use Shardy, a new partitioner, to replace the existing + // ShardingPropagation and SpmdPartitioner. See go/xla-sdy-pipeline for + // details. + bool use_shardy_partitioner = 24; } // Serialization of HloModuleConfig. See the C++ class definition for // descriptions of each field. // There are no guarantees of backwards or forwards compatibility. -// Next id: 34. +// Next id: 35. message HloModuleConfigProto { enum FusionConfigCollection { OFF = 0; // Do not collect configuration. @@ -1077,6 +1082,7 @@ message HloModuleConfigProto { xla.PrecisionConfig.Precision matrix_unit_operand_precision = 29; bytes fdo_profile = 31; int64 device_memory_size = 32; + bool use_shardy_partitioner = 34; } message HloModuleProtoWithConfig { From f81c01e8d8d063c94b03c91c0fac7ce86bab5e25 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 24 Jul 2024 12:22:11 -0700 Subject: [PATCH 121/376] [XLA:MSA] Implement an auxiliary function (SimulateComputeInstruction) to simulate processing outstanding async copy instructions. When executing a computation instruction, there may be a time windows when default memory is not required (e.g., computation intensive instructions). This default memory idle time window can be used to process outstanding async copy instructions. We simulate this process in the ProcessAsyncCopyInTimeWindow function. This function tries to process outstanding async copy instructions that stored in default_read queue and default_write queue. To provide a more general interface, I wrap this function with a high-level function (SimulateComputeInstruction), which accepts a compute instruction as input. PiperOrigin-RevId: 655659156 --- xla/service/memory_space_assignment/BUILD | 1 + .../memory_space_assignment/simulator.cc | 59 ++++++++ .../memory_space_assignment/simulator.h | 32 ++++- .../memory_space_assignment/simulator_test.cc | 136 +++++++++++++++++- 4 files changed, 220 insertions(+), 8 deletions(-) diff --git a/xla/service/memory_space_assignment/BUILD b/xla/service/memory_space_assignment/BUILD index b886ce775abf9e..cd7a0c0f1163af 100644 --- a/xla/service/memory_space_assignment/BUILD +++ b/xla/service/memory_space_assignment/BUILD @@ -336,6 +336,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/types:span", ], ) diff --git a/xla/service/memory_space_assignment/simulator.cc b/xla/service/memory_space_assignment/simulator.cc index 761dae2983366f..7cd50834e16d29 100644 --- a/xla/service/memory_space_assignment/simulator.cc +++ b/xla/service/memory_space_assignment/simulator.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_live_range.h" @@ -210,5 +211,63 @@ float RuntimeSimulator::SimulateAsyncCopyDone( return elapsed_time; }; +float RuntimeSimulator::SimulateComputeInstruction( + const HloInstruction* instruction, + absl::Span> + operands_in_alternate_memory, + absl::Span outputs_in_alternate_memory) { + // Calculate the time in which the instruction does not access the default + // memory. + float default_memory_idle_time = + cost_analysis_->GetDefaultMemoryBandwidthIdleTime( + *instruction, operands_in_alternate_memory, + outputs_in_alternate_memory); + + // Execute the outstanding async copy in the idle time. + ProcessAsyncCopiesInIdleTime(default_memory_idle_time); + + float inst_elapsed = cost_analysis_->GetInstructionElapsedInAlternateMemory( + *instruction, operands_in_alternate_memory, outputs_in_alternate_memory); + return inst_elapsed; +} + +void RuntimeSimulator::ProcessAsyncCopiesInIdleTime(float time) { + if (time <= 0.0) { + return; + } + float remaining_simulation_time = time; + // This loop simulates the execution of the front memory requests in the + // read and/or write queues. The loop terminates when the remaining time is + // exhausted or there are no more outstanding async copies. + while ((!outstanding_read_default_queue_.empty() || + !outstanding_write_default_queue_.empty()) && + remaining_simulation_time > 0.0) { + float available_bandwidth = cost_analysis_->base_costs().BytesPerSecond(); + if (!outstanding_read_default_queue_.empty() && + !outstanding_write_default_queue_.empty()) { + // Need to share the bandwidth + available_bandwidth *= 0.5; + } + float bytes_to_process = available_bandwidth * remaining_simulation_time; + if (!outstanding_read_default_queue_.empty()) { + bytes_to_process = std::min( + bytes_to_process, + outstanding_read_default_queue_.front().remaining_bytes_to_transfer); + } + if (!outstanding_write_default_queue_.empty()) { + bytes_to_process = std::min( + bytes_to_process, + outstanding_write_default_queue_.front().remaining_bytes_to_transfer); + } + + float real_elapsed_time = bytes_to_process / available_bandwidth; + remaining_simulation_time -= real_elapsed_time; + RemoveBytesFromQueueIfNotEmpty(outstanding_read_default_queue_, + bytes_to_process); + RemoveBytesFromQueueIfNotEmpty(outstanding_write_default_queue_, + bytes_to_process); + } +} + } // namespace memory_space_assignment } // namespace xla diff --git a/xla/service/memory_space_assignment/simulator.h b/xla/service/memory_space_assignment/simulator.h index 146d569fbe66f0..900a2e0593c741 100644 --- a/xla/service/memory_space_assignment/simulator.h +++ b/xla/service/memory_space_assignment/simulator.h @@ -18,11 +18,14 @@ limitations under the License. #include #include +#include +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/service/memory_space_assignment/allocation.h" #include "xla/service/memory_space_assignment/cost_analysis.h" +#include "xla/shape_util.h" namespace xla { namespace memory_space_assignment { @@ -97,20 +100,41 @@ class RuntimeSimulator { const std::list& GetOutstandingWriteDefaultQueue() const; + // This is an auxiliary function for simulating the execution + // time for a compute instruction. It returns the elapsed time (in seconds) + // for executing the compute instruction. + // + // Aside from returning the elapsed time, this function also updates the + // outstanding memory request queues, by draining them when the compute + // instruction is not occupying bandwidth. + float SimulateComputeInstruction( + const HloInstruction* compute_instruction, + absl::Span> + operands_in_alternate_memory, + absl::Span outputs_in_alternate_memory); + private: const CostAnalysis* cost_analysis_; CostAnalysis::Cache cost_analysis_cache_; - // Members used for memory model simulation - int64_t alternate_memory_space_; - std::list outstanding_read_default_queue_; - std::list outstanding_write_default_queue_; + // This function updates the queue by updating the front request with the // processed bytes. If the request is completed (no remaining bytes to // process), the function returns the instruction and pop it from the queue. // Otherwise, it returns nullptr. const HloInstruction* RemoveBytesFromQueueIfNotEmpty( std::list& async_copy_queue, float processed_bytes); + + // This is an auxiliary function which simulates the process of draining + // the memory access queues in a given amount of time (seconds). If both + // outstanding_*_default_queues are non-empty, they share bandwidth. If one of + // the queues is empty and the other is not, it gets the full bandwdith. + void ProcessAsyncCopiesInIdleTime(float time); + // Members used for memory model simulation + int64_t alternate_memory_space_; + std::list outstanding_read_default_queue_; + std::list outstanding_write_default_queue_; }; + } // namespace memory_space_assignment } // namespace xla #endif // XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_SIMULATOR_H_ diff --git a/xla/service/memory_space_assignment/simulator_test.cc b/xla/service/memory_space_assignment/simulator_test.cc index 57f152af12a105..ad588538d7bf27 100644 --- a/xla/service/memory_space_assignment/simulator_test.cc +++ b/xla/service/memory_space_assignment/simulator_test.cc @@ -74,6 +74,8 @@ class MemorySpaceAssignmentSimulatorTest : public HloTestBase { std::make_unique( *hlo_cost_analysis_); CostAnalysisOptions _options; + // Assume 2 byte per second for testing. + _options.alternate_mem_bandwidth_bytes_per_second = 2; TF_ASSIGN_OR_RETURN( cost_analysis_, CostAnalysis::Create(*hlo_cost_analysis_costs_, _options, *module_)); @@ -126,12 +128,15 @@ TEST_F(MemorySpaceAssignmentSimulatorTest, SingleLayerNestedLoop) { // Since the HLO does not contain memory access, pass an empty allocation // sequence for test. memory_space_assignment::AllocationSequence allocations; - // The while loop has 42 iterations, and each iteration has 2 FLOP (for - // %increment and %greater). Thus, the total FLOPs are 84 FLOPs. - float expected_elapsed_time = 84; + // The total elapsed time is the summation of the elapsed time of each + // instruction. Here are the overhead of each instruction (secs): + // %increment: 12 * 42 + // tuple(%constant.0): 8 * 1 + // %greater: 9 * 42 + // %loop_result: 8 * 42 EXPECT_EQ(runtime_simulator_->ComputeEstimatedElapsedTime(*hlo_live_range, allocations), - expected_elapsed_time); + 1226); } class SimulateAsyncCopyDoneTest : public MemorySpaceAssignmentSimulatorTest { @@ -293,5 +298,128 @@ TEST_F(SimulateAsyncCopyDoneTest, AsyncCopyTransferPartialProcess) { EXPECT_THAT(runtime_simulator_->GetOutstandingWriteDefaultQueue(), IsEmpty()); } +TEST_F(SimulateAsyncCopyDoneTest, ProcessAsyncCopiesWithComputeInstruction) { + absl::string_view hlo_string = + R"(HloModule module, is_scheduled=true + ENTRY Entry { + param_0 = f32[128] parameter(0) + param_1 = f32[32] parameter(1) + copy-start.1 = (f32[128]{0:S(1)}, f32[128], u32[]) copy-start(param_0) + neg = f32[32] negate(param_1) + ROOT copy-done.1 = f32[128]{0:S(1)} copy-done(copy-start.1) + } + )"; + + TF_ASSERT_OK(Initialize(hlo_string)); + const HloInstruction* copy_start_1_inst = instruction_map_["copy-start.1"]; + const HloInstruction* neg_inst = instruction_map_["neg"]; + + float compute_elapsed_time = runtime_simulator_->SimulateComputeInstruction( + neg_inst, /*operands_in_alternate_memory=*/{}, + /*outputs_in_alternate_memory=*/{}); + + // The compute operand requires 32 FLOPs and 32 * 4 * 2 bytes access, which + // requires 32 and 256 secs respectively. Thus, it is default memory access + // dominated, which does not have idle time to process the async copy. + EXPECT_EQ(compute_elapsed_time, 256); + EXPECT_THAT(runtime_simulator_->GetOutstandingReadDefaultQueue(), + ElementsAreArray({memory_space_assignment::OutstandingAsyncCopy{ + copy_start_1_inst, 512}})); + + EXPECT_THAT(runtime_simulator_->GetOutstandingWriteDefaultQueue(), IsEmpty()); +} + +TEST_F(SimulateAsyncCopyDoneTest, ProcessAsyncCopiesInTimeWithSharedBandwidth) { + absl::string_view hlo_string = + R"(HloModule module, is_scheduled=true + ENTRY Entry { + param_0 = f32[128] parameter(0) + param_1 = f32[32]{0:S(1)} parameter(1) + copy-start.1 = (f32[128]{0:S(1)}, f32[128], u32[]) copy-start(param_0) + copy-start.2 = (f32[32], f32[32]{0:S(1)}, u32[]) copy-start(param_1) + neg = f32[32] negate(param_1) + copy-done.2 = f32[32] copy-done(copy-start.2) + ROOT copy-done.1 = f32[128]{0:S(1)} copy-done(copy-start.1) + } + )"; + + TF_ASSERT_OK(Initialize(hlo_string)); + + const HloInstruction* copy_start_1_inst = instruction_map_["copy-start.1"]; + const HloInstruction* copy_start_2_inst = instruction_map_["copy-start.2"]; + + // The instruction reads 32 * 4 bytes from alternate memory, which takes 64 + // secs. In this 64 secs, it does not access default memory. Thus, we can + // process the async copies in this time. Both queues are not empty, so the + // bandwidth is shared. Each of the request at the front of the queue process + // 64 sec * 0.5 bytes/sec = 32 bytes. + float compute_elapsed_time = runtime_simulator_->SimulateComputeInstruction( + instruction_map_["neg"], /*operands_in_alternate_memory=*/{{0, {}}}, + /*outputs_in_alternate_memory=*/{}); + // 64 secs for alternate memory access + 128 secs for default memory access + EXPECT_EQ(compute_elapsed_time, 192); + + EXPECT_THAT(runtime_simulator_->GetOutstandingReadDefaultQueue(), + ElementsAreArray({memory_space_assignment::OutstandingAsyncCopy{ + copy_start_1_inst, 480}})); + + EXPECT_THAT(runtime_simulator_->GetOutstandingWriteDefaultQueue(), + ElementsAreArray({memory_space_assignment::OutstandingAsyncCopy{ + copy_start_2_inst, 96}})); +} + +TEST_F(SimulateAsyncCopyDoneTest, ProcessAsyncCopiesInTimeWithFullBandwidth) { + absl::string_view hlo_string = + R"(HloModule module, is_scheduled=true + ENTRY Entry { + param_0 = f32[128] parameter(0) + param_1 = f32[32]{0:S(1)} parameter(1) + copy-start.1 = (f32[128]{0:S(1)}, f32[128], u32[]) copy-start(param_0) + neg = f32[32] negate(param_1) + ROOT copy-done.1 = f32[128]{0:S(1)} copy-done(copy-start.1) + } + )"; + + TF_ASSERT_OK(Initialize(hlo_string)); + + const HloInstruction* copy_start_1_inst = instruction_map_["copy-start.1"]; + + // Same as the 'ProcessAsyncCopiesInTimeWithSharedBandwidth' test, there are + // 64 secs idle time to process async copies. Since only the read queue is not + // empty, we can use the full bandwidth and process 64 sec * 1 bytes/sec = 64 + // bytes. + float compute_elapsed_time = runtime_simulator_->SimulateComputeInstruction( + instruction_map_["neg"], /*operands_in_alternate_memory=*/{{0, {}}}, + /*outputs_in_alternate_memory=*/{}); + // 64 secs for alternate memory access + 128 secs for default memory access + EXPECT_EQ(compute_elapsed_time, 192); + + EXPECT_THAT(runtime_simulator_->GetOutstandingReadDefaultQueue(), + ElementsAreArray({memory_space_assignment::OutstandingAsyncCopy{ + copy_start_1_inst, 448}})); + EXPECT_THAT(runtime_simulator_->GetOutstandingWriteDefaultQueue(), IsEmpty()); +} + +TEST_F(SimulateAsyncCopyDoneTest, ProcessAsyncCopyInTimeWithEmptyQueues) { + absl::string_view hlo_string = + R"(HloModule module, is_scheduled=true + ENTRY Entry { + param_0 = f32[128] parameter(0) + ROOT neg = f32[128] negate(param_0) + } + )"; + + TF_ASSERT_OK(Initialize(hlo_string)); + + float compute_elapsed_time = runtime_simulator_->SimulateComputeInstruction( + instruction_map_["neg"], /*operands_in_alternate_memory=*/{}, + /*outputs_in_alternate_memory=*/{}); + // Execution time: 128 * 4 * 2 / 1 for default access + EXPECT_EQ(compute_elapsed_time, 1024); + // The queues should remain empty. + EXPECT_THAT(runtime_simulator_->GetOutstandingReadDefaultQueue(), IsEmpty()); + EXPECT_THAT(runtime_simulator_->GetOutstandingWriteDefaultQueue(), IsEmpty()); +} + } // namespace } // namespace xla From a5a2c0d5288fc18f4095e61fa2299ce361b94d0e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 24 Jul 2024 12:27:58 -0700 Subject: [PATCH 122/376] Replace uses of gtl::linked_hash_{map/set} to absl::btree_{map/set} in third_party auto-sharding code with absl::btree_{map/set}. PiperOrigin-RevId: 655660614 --- xla/hlo/experimental/auto_sharding/BUILD | 1 + .../auto_sharding/auto_sharding.cc | 36 ++++++++-------- .../auto_sharding/auto_sharding.h | 16 ++++---- .../auto_sharding/auto_sharding_cost_graph.cc | 2 +- .../auto_sharding/auto_sharding_cost_graph.h | 9 ++-- .../auto_sharding_dot_handler.cc | 2 +- .../auto_sharding_solver_test.cc | 12 ++---- .../auto_sharding/auto_sharding_strategy.cc | 2 +- .../auto_sharding/auto_sharding_strategy.h | 41 +++++++++++++------ .../auto_sharding/auto_sharding_util.cc | 22 +++++----- .../auto_sharding/auto_sharding_util.h | 36 +++++++++------- .../auto_sharding/profiling_result.h | 9 ++-- 12 files changed, 102 insertions(+), 86 deletions(-) diff --git a/xla/hlo/experimental/auto_sharding/BUILD b/xla/hlo/experimental/auto_sharding/BUILD index 681016f0897004..258bc53fc2d2ca 100644 --- a/xla/hlo/experimental/auto_sharding/BUILD +++ b/xla/hlo/experimental/auto_sharding/BUILD @@ -152,6 +152,7 @@ cc_library( "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/service:hlo_value", + "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/xla/hlo/experimental/auto_sharding/auto_sharding.cc index dae6bebb6be88c..b0f17e9ffbddbf 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -305,7 +305,7 @@ GenerateReshardingCostsAndShardingsForAllOperands( void FollowArrayOrTokenStrategyGroup( const StrategyGroup& src_strategy_group, const Shape& shape, const size_t instruction_id, const ClusterEnvironment& cluster_env, - const StableHashMap>& + const StableMap>& pretrimmed_strategy_map, StrategyGroup& strategy_group) { CHECK(shape.IsArray() || shape.IsToken()); @@ -427,7 +427,7 @@ std::unique_ptr MaybeFollowInsStrategyGroup( const StrategyGroup* src_strategy_group, const Shape& shape, const size_t instruction_id, StrategyGroups& strategy_groups, const ClusterEnvironment& cluster_env, - const StableHashMap>& + const StableMap>& pretrimmed_strategy_map) { std::unique_ptr strategy_group; if (src_strategy_group->is_tuple) { @@ -1477,8 +1477,7 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( const StrategyMap& strategy_map, const std::vector& instructions, const HloSharding& existing_sharding, const ClusterEnvironment& cluster_env, - StableHashMap>& - pretrimmed_strategy_map, + StableMap>& pretrimmed_strategy_map, const CallGraph& call_graph, const bool strict) { if (strategy_group->is_tuple) { for (size_t i = 0; i < strategy_group->childs.size(); ++i) { @@ -1760,7 +1759,7 @@ std::unique_ptr CreateElementwiseOperatorStrategies( const size_t instruction_id, const HloInstruction* ins, const StrategyMap& strategy_map, const ClusterEnvironment& cluster_env, const InstructionDepthMap& depth_map, const AliasMap& alias_map, - const StableHashMap>& + const StableMap>& pretrimmed_strategy_map, const int64_t max_depth, StrategyGroups& strategy_groups, AssociativeDotPairs& associative_dot_pairs) { @@ -2975,14 +2974,12 @@ void FindReplicateSet( absl::Span s_val, const StrategyMap& strategy_map, const ShardingStrategy& strategy, const HloInstruction* output, const bool do_all_gather_after_backward, HloInstruction*& transpose_inst, - StableHashSet& replicated_set, - StableHashSet& boundary_set, - StableHashSet& consumer_set, - StableHashSet& visited) { + InstructionSet& replicated_set, InstructionSet& boundary_set, + InstructionSet& consumer_set, ConstInstructionSet& visited) { visited.insert(cur); // Check whether the node is a boundary node. - StableHashSet users = UsersWithAlias(cur, alias_map, output); + InstructionSet users = UsersWithAlias(cur, alias_map, output); for (HloInstruction* consumer : users) { const HloInstruction* shape_inst = cur; @@ -3066,7 +3063,7 @@ absl::Status GenerateReduceScatter( bool use_all_reduce_for_grad_acc = option.reduce_scatter_grad_acc_friendly; std::vector insert_all_gather; - StableHashSet modified; + ConstInstructionSet modified; for (HloInstruction* inst : instructions) { if (!HasReduceScatterOpportunity(inst, strategy_map, cost_graph, s_val, @@ -3079,10 +3076,10 @@ absl::Status GenerateReduceScatter( continue; } - StableHashSet replicated_set; - StableHashSet boundary_set; - StableHashSet consumer_set; - StableHashSet visited; + InstructionSet replicated_set; + InstructionSet boundary_set; + InstructionSet consumer_set; + ConstInstructionSet visited; // We allow at most one transpose in the path of replication analysis. HloInstruction* transpose_inst = nullptr; @@ -3591,10 +3588,11 @@ HloSharding GetReduceScatterOutput(const HloInstruction* ins, } // Return whether an instruction has the opportunity to generate reduce-scatter. -bool HasReduceScatterOpportunity( - const HloInstruction* inst, const StrategyMap& strategy_map, - const CostGraph& cost_graph, absl::Span s_val, - const StableHashSet& modified) { +bool HasReduceScatterOpportunity(const HloInstruction* inst, + const StrategyMap& strategy_map, + const CostGraph& cost_graph, + absl::Span s_val, + const ConstInstructionSet& modified) { // If the operand is already modified by other ops, skip this instruction to // avoid conflicts. for (const HloInstruction* operand : inst->operands()) { diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding.h b/xla/hlo/experimental/auto_sharding/auto_sharding.h index 557483ec89d9ce..4695efc60d0dea 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding.h +++ b/xla/hlo/experimental/auto_sharding/auto_sharding.h @@ -221,10 +221,11 @@ absl::Status GenerateReduceScatter( const CostGraph& cost_graph, absl::Span s_val, const ClusterEnvironment& cluster_env, const AutoShardingOption& option); -bool HasReduceScatterOpportunity( - const HloInstruction* inst, const StrategyMap& strategy_map, - const CostGraph& cost_graph, absl::Span s_val, - const StableHashSet& modified); +bool HasReduceScatterOpportunity(const HloInstruction* inst, + const StrategyMap& strategy_map, + const CostGraph& cost_graph, + absl::Span s_val, + const ConstInstructionSet& modified); HloSharding GetReduceScatterOutput(const HloInstruction* ins, const ShardingStrategy& strategy, @@ -285,7 +286,7 @@ std::unique_ptr CreateElementwiseOperatorStrategies( size_t instruction_id, const HloInstruction* ins, const StrategyMap& strategy_map, const ClusterEnvironment& cluster_env, const InstructionDepthMap& depth_map, const AliasMap& alias_map, - const StableHashMap>& + const StableMap>& pretrimmed_strategy_map, int64_t max_depth, StrategyGroups& strategy_groups, AssociativeDotPairs& associative_dot_pairs); @@ -362,7 +363,7 @@ std::unique_ptr MaybeFollowInsStrategyGroup( const StrategyGroup* src_strategy_group, const Shape& shape, size_t instruction_id, StrategyGroups& strategy_groups, const ClusterEnvironment& cluster_env, - const StableHashMap>& + const StableMap>& pretrimmed_strategy_map); void RemoveShardingsWhereSmallDimsShardedAcrossManyDevices( @@ -379,8 +380,7 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( const StrategyMap& strategy_map, const std::vector& instructions, const HloSharding& existing_sharding, const ClusterEnvironment& cluster_env, - StableHashMap>& - pretrimmed_strategy_map, + StableMap>& pretrimmed_strategy_map, const CallGraph& call_graph, bool strict); // Build possible sharding strategies and their costs for all instructions. diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc index 1156e0b80c3027..85127883e21937 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc @@ -38,7 +38,7 @@ CostGraph::CostGraph(const StrategyGroups& strategy_groups, const AssociativeDotPairs& associative_dot_pairs) { node_lens_.reserve(strategy_groups.size()); extra_node_costs_.reserve(strategy_groups.size()); - adjacency_.assign(strategy_groups.size(), StableHashSet()); + adjacency_.assign(strategy_groups.size(), StableSet()); // Build the cost graph. for (StrategyGroup* strategy_group : strategy_groups) { diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h b/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h index 08b0bd968b6d4c..fda06ee8ec1e7b 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h @@ -104,23 +104,22 @@ class CostGraph { // The number of strategies of each node. std::vector node_lens_; // The adjacency list of each node. - std::vector> adjacency_; + std::vector> adjacency_; // The cost matrix between two nodes. - StableHashMap, EdgeReshardingCostMatrix> - edge_costs_; + StableMap, EdgeReshardingCostMatrix> edge_costs_; // The extra node costs introduced by merging nodes. std::vector> extra_node_costs_; // The reindexing vector of the node. // A reindexing vector maps a strategy index from the node being followed // to a strategy index of the current node. - StableHashMap> reindexing_vector_; + StableMap> reindexing_vector_; // Maps a node id to the node id that is being followed by this node. // The value is -1 if the current node does not follow any node. std::vector follow_idx_; // Save the destination of merged nodes. - StableHashMap merged_to_; + StableMap merged_to_; // Save pairs that need to be merged. std::vector> to_merge_pairs_; }; diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc index 9f76195abd2607..36d03717e9ca00 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc @@ -55,7 +55,7 @@ namespace xla { namespace spmd { namespace { -using DimMap = StableHashMap; +using DimMap = StableMap; using MeshDims = absl::Span; struct Enumeration { diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc index 16a62c0123c771..5a237b2f979a67 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc @@ -790,8 +790,8 @@ TEST(ScaleRequest, SkipsScaling) { EXPECT_THAT(request, ::testing::EqualsProto(expected_request)); } -TEST(StableHashMap, IterationOrderDeterminism){ - StableHashMap map; +TEST(StableMap, IterationOrderDeterminism){ + StableMap map; std::vector insertion_order = {6, 3, 1, 2, 4, 5, 10, 0, 7, 9, 8}; for (int key : insertion_order) { map[key] = key; @@ -801,12 +801,8 @@ TEST(StableHashMap, IterationOrderDeterminism){ for (const auto& [key, value] : map) { iteration_order.push_back(key); } - if (tsl::kIsOpenSource) { - EXPECT_THAT(iteration_order, - ::testing::ElementsAre(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)); - } else { - EXPECT_EQ(iteration_order, insertion_order); - } + EXPECT_THAT(iteration_order, + ::testing::ElementsAre(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)); } TEST(ValidateRequest, AcceptsAutoShardingSolverRequest) { diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc index ae5cc8cd58bf08..6c4ae8251033b9 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc @@ -142,7 +142,7 @@ BuildStrategyAndCost( // is useful when the operand is forced to use a user sharding, and the op // doesn't need to strictly follow it. We restore the trimmed strategies in // this situation. - StableHashMap> pretrimmed_strategy_map; + StableMap> pretrimmed_strategy_map; StrategyGroups strategy_groups; AssociativeDotPairs associative_dot_pairs; diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h index 697355d7562c96..d9ae855a098a3b 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h @@ -24,6 +24,8 @@ limitations under the License. #include #include +#include "absl/container/btree_map.h" +#include "absl/container/btree_set.h" #include "absl/log/check.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -31,8 +33,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_sharding.h" #include "xla/service/hlo_value.h" #include "xla/shape_util.h" -#include "absl/container/btree_map.h" -#include "absl/container/flat_hash_set.h" namespace xla { namespace spmd { @@ -42,20 +42,36 @@ constexpr double kInfinityCost = 1e20; // Type alias template -using StableHashMap = ::absl::btree_map; +using StableMap = absl::btree_map; template -using StableHashSet = ::absl::flat_hash_set; +using StableSet = absl::btree_set; + +struct CompareHloInstruction { + bool operator()(const HloInstruction* a, const HloInstruction* b) const { + return a->name() < b->name(); + } +}; + +template +using ConstInstructionMap = + absl::btree_map; +template +using InstructionMap = + absl::btree_map; + +using ConstInstructionSet = + absl::btree_set; +using InstructionSet = absl::btree_set; // Map an instruction to its depth. -using InstructionDepthMap = StableHashMap; +using InstructionDepthMap = ConstInstructionMap; // Map an instruction to its batch dimension. -using InstructionBatchDimMap = StableHashMap; +using InstructionBatchDimMap = StableMap; // Map an instruction to its alias source parameter. -using AliasMap = StableHashMap; +using AliasMap = ConstInstructionMap; // Map an instruction to its resharding cache. using ReshardingCache = - StableHashMap>>; + ConstInstructionMap>>; // Resharding costs for each operand using ReshardingCosts = std::vector>; @@ -149,7 +165,7 @@ using AliasIdx = int64_t; // An index into the alias vector. // Various classes needed to support strategy shaving. using NodeStrategy = std::pair; -using NodeStrategies = StableHashSet; +using NodeStrategies = StableSet; // A group of strategy choices (along with details like index values) // for each instruction. @@ -226,8 +242,7 @@ using LivenessNodeSet = std::vector>; // A liveness set using edge indices instead of HLO values. using LivenessEdgeSet = std::vector>; // Map an instruction to its strategy group. -using StrategyMap = - StableHashMap>; +using StrategyMap = ConstInstructionMap>; // The list of all strategy groups. using StrategyGroups = std::vector; // The list of all dot instruction pairs that can be optimized by @@ -235,7 +250,7 @@ using StrategyGroups = std::vector; using AssociativeDotPairs = std::vector>; // The set of all alias pairs -using AliasSet = StableHashSet>; +using AliasSet = StableSet>; } // namespace spmd } // namespace xla diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc index 92f93f0b5f0bae..4b86f967ab7da2 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc @@ -184,7 +184,7 @@ InstructionDepthMap BuildInstructionDepthMap( const std::vector& instructions = sequence.instructions(); InstructionDepthMap depth_map; - StableHashMap degree_dict; + ConstInstructionMap degree_dict; // Init frontier size_t collected = 0; @@ -929,7 +929,7 @@ bool IsDivisible(const HloInstruction* ins, const Array& device_mesh, void SetSharding(HloInstruction* to_split, const HloSharding& output_spec, const HloInstruction* ref_inst, const HloInstruction* shape_inst, - StableHashSet& modified) { + ConstInstructionSet& modified) { modified.insert(to_split); if (DimensionsEqual(to_split->shape(), ref_inst->shape())) { to_split->set_sharding(output_spec); @@ -955,16 +955,16 @@ bool IsAlwaysReplicated(const HloInstruction* inst) { } // Try to reduce the boundary set to its common ancestor -void TryReduceWithCommonAncestor(StableHashSet& replicated_set, - StableHashSet& boundary_set, - StableHashSet& consumer_set, +void TryReduceWithCommonAncestor(InstructionSet& replicated_set, + InstructionSet& boundary_set, + InstructionSet& consumer_set, const AliasMap& alias_map) { if (boundary_set.size() != 2) { return; } HloInstruction* ancestor = nullptr; - StableHashSet path; + InstructionSet path; for (HloInstruction* node : boundary_set) { HloInstruction* cur = node; while (cur->operand_count() == 1) { @@ -999,7 +999,7 @@ void TryReduceWithCommonAncestor(StableHashSet& replicated_set, consumer_set.insert(ancestor); } -void UseAllReduceForGradAcc(StableHashSet& replicated_set, +void UseAllReduceForGradAcc(InstructionSet& replicated_set, const HloInstruction* inst) { if (inst->users().size() != 1) { return; @@ -1007,7 +1007,7 @@ void UseAllReduceForGradAcc(StableHashSet& replicated_set, // Find the add instruction for grad accumulation, skip the identity marker // for remat and other elementwise ops. - const HloInstruction* add = + HloInstruction* add = PassThroughCustomCallMarkerUser(inst->users().front(), inst); if (add->opcode() == HloOpcode::kGetTupleElement || add->opcode() == HloOpcode::kTranspose) { @@ -1029,9 +1029,9 @@ void UseAllReduceForGradAcc(StableHashSet& replicated_set, // Do not partition the dot, add and parameter, so we can generate // all-reduce for grad accumulation. - std::function dfs_remove; - dfs_remove = [&](const HloInstruction* cur) { - if (!replicated_set.contains(cur)) { + std::function dfs_remove; + dfs_remove = [&](HloInstruction* cur) { + if (replicated_set.count(cur) == 0) { return; } diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_util.h b/xla/hlo/experimental/auto_sharding/auto_sharding_util.h index 5e0e6a4191fefa..a4ea23c922fc06 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_util.h +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_util.h @@ -133,7 +133,7 @@ std::string ToString(const std::vector& vector) { } template -std::string ToString(const StableHashMap& map) { +std::string ToString(const StableMap& map) { std::string result; for (const auto& [k, v] : map) { result = absl::StrCat(result, " [", k, "->", v, "]"); @@ -284,23 +284,31 @@ inline HloInstruction* PassThroughCustomCallMarkerUser( // Return the users of an instruction and its alias, // excluding the final output tuple. -inline StableHashSet UsersWithAlias( - const HloInstruction* inst, const AliasMap& alias_map, - const HloInstruction* output) { - StableHashSet users; - +inline InstructionSet UsersWithAlias(const HloInstruction* inst, + const AliasMap& alias_map, + const HloInstruction* output) { + InstructionSet users; for (HloInstruction* user : inst->users()) { - users.insert(PassThroughCustomCallMarkerUser(user, inst)); + HloInstruction* pass_through_user = + PassThroughCustomCallMarkerUser(user, inst); + if (pass_through_user == output) { + continue; + } + users.insert(pass_through_user); } auto iter = alias_map.find(inst); if (iter != alias_map.end()) { for (HloInstruction* user : iter->second->users()) { - users.insert(PassThroughCustomCallMarkerUser(user, iter->second)); + HloInstruction* pass_through_user = + PassThroughCustomCallMarkerUser(user, iter->second); + if (pass_through_user == output) { + continue; + } + users.insert(pass_through_user); } } - users.erase(output); return users; } @@ -312,21 +320,21 @@ bool IsParameterConvert(const HloInstruction* inst); bool IsAlwaysReplicated(const HloInstruction* inst); // Try to reduce the boundary set to its common ancestor -void TryReduceWithCommonAncestor(StableHashSet& replicated_set, - StableHashSet& boundary_set, - StableHashSet& consumer_set, +void TryReduceWithCommonAncestor(InstructionSet& replicated_set, + InstructionSet& boundary_set, + InstructionSet& consumer_set, const AliasMap& alias_map); // Return whether all users of an instruction is reduce. bool AllUsersAreReduce(const HloInstruction* inst); -void UseAllReduceForGradAcc(StableHashSet& replicated_set, +void UseAllReduceForGradAcc(InstructionSet& replicated_set, const HloInstruction* inst); void SetSharding(HloInstruction* to_split, const HloSharding& output_spec, const HloInstruction* ref_inst, const HloInstruction* shape_inst, - StableHashSet& modified); + ConstInstructionSet& modified); template inline std::vector Argsort(const std::vector& scores) { diff --git a/xla/hlo/experimental/auto_sharding/profiling_result.h b/xla/hlo/experimental/auto_sharding/profiling_result.h index 873aabf786388d..1fe95c47ae5941 100644 --- a/xla/hlo/experimental/auto_sharding/profiling_result.h +++ b/xla/hlo/experimental/auto_sharding/profiling_result.h @@ -104,8 +104,7 @@ class ProfilingResult { // Estimate the cost by linear interpolation between the two closest points. double EstimateInternal( const std::vector>& replica_groups, int64_t size, - const std::string& dtype, - const StableHashMap& cost_dict) const { + const std::string& dtype, const StableMap& cost_dict) const { Key key(Group2Str(replica_groups), dtype); Value cost_list = cost_dict.at(key); @@ -147,9 +146,9 @@ class ProfilingResult { } bool enabled_; - StableHashMap all_reduce_cost_dict_; - StableHashMap all_gather_cost_dict_; - StableHashMap reduce_scatter_cost_dict_; + StableMap all_reduce_cost_dict_; + StableMap all_gather_cost_dict_; + StableMap reduce_scatter_cost_dict_; }; } // namespace spmd From 14e053e7bf2003492db783512dc69ab51454e9a6 Mon Sep 17 00:00:00 2001 From: Ionel Gog Date: Wed, 24 Jul 2024 13:36:17 -0700 Subject: [PATCH 123/376] [IFRT] Include diagnostic info in the error message when parsing IFRT IR module fails. PiperOrigin-RevId: 655683822 --- xla/python/ifrt/ir/tests/BUILD | 3 ++ .../ir/tests/executable_impl_test_base.cc | 31 +++++++++++++++++-- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/xla/python/ifrt/ir/tests/BUILD b/xla/python/ifrt/ir/tests/BUILD index f68d8b09142df1..b872c2baa16c35 100644 --- a/xla/python/ifrt/ir/tests/BUILD +++ b/xla/python/ifrt/ir/tests/BUILD @@ -60,6 +60,7 @@ cc_library( visibility = ["//xla/python/ifrt:friends"], deps = [ "//xla:status_macros", + "//xla/mlir/utils:error_util", "//xla/mlir_hlo:hlo_dialect_registration", "//xla/python/ifrt", "//xla/python/ifrt:test_util", @@ -67,8 +68,10 @@ cc_library( "//xla/python/ifrt/ir:sharding_param", "//xla/python/ifrt/ir/transforms:built_in_spmd_expansions", "//xla/tsl/concurrency:ref_count", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:IR", diff --git a/xla/python/ifrt/ir/tests/executable_impl_test_base.cc b/xla/python/ifrt/ir/tests/executable_impl_test_base.cc index 993e4d4479834f..71f4c62f061f95 100644 --- a/xla/python/ifrt/ir/tests/executable_impl_test_base.cc +++ b/xla/python/ifrt/ir/tests/executable_impl_test_base.cc @@ -15,18 +15,35 @@ limitations under the License. #include "xla/python/ifrt/ir/tests/executable_impl_test_base.h" +#include +#include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" #include "mlir/InitAllDialects.h" #include "mlir/Parser/Parser.h" +#include "xla/mlir/utils/error_util.h" #include "xla/mlir_hlo/mhlo/IR/register.h" +#include "xla/python/ifrt/array.h" #include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/dtype.h" #include "xla/python/ifrt/ir/ifrt_dialect.h" #include "xla/python/ifrt/ir/sharding_param.h" #include "xla/python/ifrt/ir/transforms/built_in_spmd_expansions.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" #include "xla/python/ifrt/test_util.h" #include "xla/status_macros.h" +#include "xla/tsl/concurrency/ref_count.h" #include "tsl/platform/statusor.h" namespace xla { @@ -50,16 +67,26 @@ void IfrtIrExecutableImplTestBase::SetUp() { absl::StatusOr> IfrtIrExecutableImplTestBase::LoadFromSource(absl::string_view source) { + mlir::BaseScopedDiagnosticHandler diagnostic_handler(&mlir_context_); auto op_ref = mlir::parseSourceString(source, &mlir_context_); - TF_RET_CHECK(op_ref) << "Failed to parse MLIR source"; + if (!op_ref) { + return absl::InvalidArgumentError( + absl::StrFormat("Failed to parse IFRT IR module string: %s", + diagnostic_handler.ConsumeStatus().message())); + } return op_ref; } absl::StatusOr> IfrtIrExecutableImplTestBase::LoadFromFile(absl::string_view file_path) { + mlir::BaseScopedDiagnosticHandler diagnostic_handler(&mlir_context_); auto op_ref = mlir::parseSourceFile(file_path, &mlir_context_); - TF_RET_CHECK(op_ref) << "Failed to parse MLIR file"; + if (!op_ref) { + return absl::InvalidArgumentError( + absl::StrFormat("Failed to parse IFRT IR module file: %s", + diagnostic_handler.ConsumeStatus().message())); + } return op_ref; } From a4667a93c263a9c6d741909a0644d3fbb0512aff Mon Sep 17 00:00:00 2001 From: Greg Olechwierowicz Date: Wed, 24 Jul 2024 13:38:55 -0700 Subject: [PATCH 124/376] [XLA:GPU] Clean up post-scheduling passes. Create a single pipeline and move relevant pipelines into subpipelines. PiperOrigin-RevId: 655684699 --- xla/service/gpu/gpu_compiler.cc | 160 +++++++++++++++++++------------- 1 file changed, 94 insertions(+), 66 deletions(-) diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index 7780e9b315a1b5..38751937725929 100644 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -2354,13 +2354,81 @@ absl::Status GpuCompiler::RunPreSchedulingPasses( return pipeline.Run(module).status(); } +HloCostAnalysis::Options CreateHloAnalysisOpts( + const HloModule& module, const se::DeviceDescription& gpu_device_info, + ShapeSizeFn shape_size_fn) { + HloCostAnalysis::Options hlo_cost_analysis_options; + hlo_cost_analysis_options.shape_size = shape_size_fn; + std::optional + offloading_config = std::nullopt; + if (module.config().debug_options().xla_gpu_enable_host_memory_offloading()) { + constexpr float kGiga = 1e+9; + // Fused multiply-add means that these two instructions are computed as + // one, so for this case the maximum flops is doubled. + constexpr float kFma = 2; + float flops_per_sec = gpu_device_info.core_count() * + gpu_device_info.fpus_per_core() * + gpu_device_info.clock_rate_ghz() * kGiga * kFma; + int64_t host_memory_space_color = + static_cast(se::MemoryType::kHost); + hlo_cost_analysis_options.set_flops_per_second(flops_per_sec); + hlo_cost_analysis_options.set_transcendentals_per_second(flops_per_sec); + offloading_config = + std::make_optional( + /*host_memory_space=*/host_memory_space_color, + /*bandwidth_to_host_bytes_per_second=*/ + gpu_device_info.memory_bandwidth(), + /*bandwidth_from_host_bytes_per_second=*/ + gpu_device_info.memory_bandwidth()); + } + return hlo_cost_analysis_options; +} + +HloRematerialization::Options CreateRematOpts( + const HloModule& module, const se::DeviceDescription& gpu_device_info, + HloCostAnalysis& hlo_cost_analysis, int64_t scheduler_mem_limit) { + bool enable_offloading = + module.config().debug_options().xla_gpu_enable_host_memory_offloading(); + std::optional + offloading_config = std::nullopt; + if (enable_offloading) { + int64_t host_memory_space_color = + static_cast(se::MemoryType::kHost); + offloading_config = + std::make_optional( + /*host_memory_space=*/host_memory_space_color, + /*bandwidth_to_host_bytes_per_second=*/ + gpu_device_info.memory_bandwidth(), + /*bandwidth_from_host_bytes_per_second=*/ + gpu_device_info.memory_bandwidth()); + } + HloRematerialization::RematerializationModeConfig + rematerialization_mode_config(/*recompute=*/true, /*compress=*/true, + /*host_offload=*/enable_offloading); + HloRematerialization::Options options( + hlo_cost_analysis, rematerialization_mode_config, + // Assume 75% of the total device memory is available for XLA. + /*memory_limit_bytes=*/scheduler_mem_limit, + /*block_size_limit=*/1, /*block_rematerialization_factor=*/1, + /*min_remat_size=*/0, /*compact_shape_function=*/nullptr, + /*host_memory_offload_config=*/offloading_config); + return options; +} + absl::Status GpuCompiler::RunPostSchedulingPipelines( HloModule* module, int64_t scheduler_mem_limit, const se::DeviceDescription& gpu_device_info) const { TF_RETURN_IF_ERROR( RunPostSchedulingCopyInsertion(module, GetCanShareBuffer())); + HloPassPipeline main_pipeline("post-scheduling-passes"); + + // Pipeline for async -> sync conversion on for non-overlapped async ops. + HloPredicate is_nop = + HloPredicateIsOp; { - HloPassPipeline pipeline("post-scheduling-passes"); + HloPassPipeline& pipeline = + main_pipeline.AddPass("async-to-sync-converter"); if (module->config() .debug_options() @@ -2368,90 +2436,50 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines( module->config().debug_options().xla_gpu_enable_pipelined_p2p()) { pipeline.AddPass(); } - HloPredicate is_nop = - HloPredicateIsOp; pipeline.AddPass(is_nop); - - TF_RETURN_IF_ERROR(pipeline.Run(module).status()); } + // Pipeline rematerialization passes with optional host offloading. + HloRematerialization::RematerializationSizes sizes; + // `HloCostAnalysis` initialization. + HloCostAnalysis::Options hlo_cost_analysis_opts = + CreateHloAnalysisOpts(*module, gpu_device_info, ShapeSizeBytesFunction()); + HloCostAnalysis hlo_cost_analysis(hlo_cost_analysis_opts); + // `HloRematerialization` options initialization. + HloRematerialization::Options remat_opts = CreateRematOpts( + *module, gpu_device_info, hlo_cost_analysis, scheduler_mem_limit); { - HloPassPipeline pipeline("remat-pipeline"); - - const bool enable_offloading = module->config() - .debug_options() - .xla_gpu_enable_host_memory_offloading(); - HloRematerialization::RematerializationModeConfig - rematerialization_mode_config(/*recompute=*/true, /*compress=*/true, - /*host_offload=*/enable_offloading); - HloCostAnalysis::Options hlo_cost_analysis_options; - hlo_cost_analysis_options.shape_size = ShapeSizeBytesFunction(); - std::optional - offloading_config = std::nullopt; - if (enable_offloading) { - constexpr float kGiga = 1e+9; - // Fused multiply-add means that these two instructions are computed as - // one, so for this case the maximum flops is doubled. - constexpr float kFma = 2; - float flops_per_sec = gpu_device_info.core_count() * - gpu_device_info.fpus_per_core() * - gpu_device_info.clock_rate_ghz() * kGiga * kFma; - int64_t host_memory_space_color = - static_cast(se::MemoryType::kHost); - hlo_cost_analysis_options.set_flops_per_second(flops_per_sec); - hlo_cost_analysis_options.set_transcendentals_per_second(flops_per_sec); - offloading_config = - std::make_optional( - /*host_memory_space=*/host_memory_space_color, - /*bandwidth_to_host_bytes_per_second=*/ - gpu_device_info.memory_bandwidth(), - /*bandwidth_from_host_bytes_per_second=*/ - gpu_device_info.memory_bandwidth()); - } - HloCostAnalysis hlo_cost_analysis(hlo_cost_analysis_options); - HloRematerialization::Options options( - hlo_cost_analysis, rematerialization_mode_config, - // Assume 75% of the total device memory is available for XLA. - /*memory_limit_bytes=*/scheduler_mem_limit, - /*block_size_limit=*/1, /*block_rematerialization_factor=*/1, - /*min_remat_size=*/0, /*compact_shape_function=*/nullptr, - /*host_memory_offload_config=*/offloading_config); - HloRematerialization::RematerializationSizes sizes; - pipeline.AddPass(options, sizes); + HloPassPipeline& pipeline = + main_pipeline.AddPass("remat-pipeline"); + + pipeline.AddPass(remat_opts, sizes); pipeline.AddPass(); pipeline.AddPass(); - - TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module)); - if (changed) { - VLOG(1) << "HloRematerialization saved " - << sizes.before_bytes - sizes.after_bytes << " bytes"; - } } + // Wrap remaining unfused ops that have no LHLO equivalent in single-op + // fusions. This needs to happen after rematerialization, because that + // will insert additional copies. { - HloPassPipeline pipeline("fusion-wrapper"); + HloPassPipeline& pipeline = + main_pipeline.AddPass("fusion-wrapper"); pipeline.AddPass(); - // Wrap remaining unfused ops that have no LHLO equivalent in single-op - // fusions. This needs to happen after rematerialization, because that - // will insert additional copies. - TF_RETURN_IF_ERROR(pipeline.Run(module).status()); } - // After we have a scheduled module and all operations wrapped into fusions - // we can decide how to wrap them into command buffers. + // Pipeline with passes which wrap a scheduled module into command buffers. { - HloPassPipeline pipeline("command-buffer-scheduling"); - auto driver_version = se::gpu::GpuDriver::GetDriverVersion(); - const int32_t toolkit_version = GetToolkitVersion(); + absl::StatusOr driver_version = + se::gpu::GpuDriver::GetDriverVersion(); + int32_t toolkit_version = GetToolkitVersion(); + HloPassPipeline& pipeline = + main_pipeline.AddPass("command-buffer-scheduling"); pipeline.AddPass( gpu_device_info, toolkit_version, driver_version.value_or(toolkit_version)); pipeline.AddPass(); - TF_RETURN_IF_ERROR(pipeline.Run(module).status()); } - return absl::OkStatus(); + return main_pipeline.Run(module).status(); } absl::Status GpuCompiler::LoadAutotuneResultsFromFile( From 7ec84c52b0d8b80ba4bd48467ce56ee5f4322d0e Mon Sep 17 00:00:00 2001 From: zoranjovanovic-ns <126815388+zoranjovanovic-ns@users.noreply.github.com> Date: Wed, 24 Jul 2024 13:44:04 -0700 Subject: [PATCH 125/376] PR #15267: [ROCm] Fixed compilation issues. Imported from GitHub PR https://github.com/openxla/xla/pull/15267 Issues were caused by following commits: https://github.com/openxla/xla/commit/d17181b49de71b0fb0ff6236745d43d630c39401 https://github.com/openxla/xla/commit/429da0c5cca821b75a75f95e1c4256f883b0bae5 Copybara import of the project: -- 66fbafe91986ebb02b8236175f27d7ebf8a23989 by Zoran Jovanovic : [ROCm] Fixed compilation issues. Merging this change closes #15267 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/15267 from ROCm:ci_hotfix_20240724 66fbafe91986ebb02b8236175f27d7ebf8a23989 PiperOrigin-RevId: 655686460 --- xla/stream_executor/gpu/BUILD | 4 ++-- xla/stream_executor/rocm/rocm_executor.cc | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/xla/stream_executor/gpu/BUILD b/xla/stream_executor/gpu/BUILD index c65157f3117123..8e918262abfc55 100644 --- a/xla/stream_executor/gpu/BUILD +++ b/xla/stream_executor/gpu/BUILD @@ -576,8 +576,8 @@ xla_test( cc_library( name = "gpu_cudamallocasync_allocator", - srcs = ["gpu_cudamallocasync_allocator.cc"], - hdrs = ["gpu_cudamallocasync_allocator.h"], + srcs = if_cuda_is_configured(["gpu_cudamallocasync_allocator.cc"]), + hdrs = if_cuda_is_configured(["gpu_cudamallocasync_allocator.h"]), tags = ["gpu"], deps = [ ":gpu_init_impl", diff --git a/xla/stream_executor/rocm/rocm_executor.cc b/xla/stream_executor/rocm/rocm_executor.cc index 3ebe531e1ba556..49fc9c646868ae 100644 --- a/xla/stream_executor/rocm/rocm_executor.cc +++ b/xla/stream_executor/rocm/rocm_executor.cc @@ -190,8 +190,7 @@ GpuExecutor::CreateEventBasedTimer(GpuStream* stream, bool use_delay_kernel) { TF_ASSIGN_OR_RETURN(auto stop_event, CreateGpuEvent(/*allow_timing=*/true)); TF_RETURN_IF_ERROR(start_event->Record(stream->gpu_stream())); return std::make_unique(gpu_context(), std::move(start_event), - std::move(stop_event), stream, - std::move(semaphore)); + std::move(stop_event), stream); } bool GpuExecutor::UnloadGpuBinary(const void* gpu_binary) { From 3088a8b34c94a24f98b425a8f7ef180f8b7bcc01 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Wed, 24 Jul 2024 15:35:24 -0700 Subject: [PATCH 126/376] Make both ROCm and CUDA GpuContexts operate in terms of device_ordinals rather than using "next_id_" sometimes. PiperOrigin-RevId: 655723012 --- xla/stream_executor/cuda/cuda_driver.cc | 42 +++++++------------------ xla/stream_executor/cuda/cuda_driver.h | 10 +++--- xla/stream_executor/rocm/rocm_driver.cc | 1 - xla/stream_executor/rocm/rocm_driver.h | 7 ++--- 4 files changed, 19 insertions(+), 41 deletions(-) diff --git a/xla/stream_executor/cuda/cuda_driver.cc b/xla/stream_executor/cuda/cuda_driver.cc index cfcf01533239ce..e498468f5e6753 100644 --- a/xla/stream_executor/cuda/cuda_driver.cc +++ b/xla/stream_executor/cuda/cuda_driver.cc @@ -77,15 +77,10 @@ limitations under the License. } \ } while (0) -// Debugging: on each push and pop of a cuda context, verify the current context -// matches the expected one. -constexpr bool kVerifyGpuContext = false; - namespace stream_executor { namespace gpu { /* static */ absl::Mutex CreatedContexts::mu_{absl::kConstInit}; -/* static */ int64_t CreatedContexts::next_id_ = 1; // 0 means "no context" namespace { @@ -119,15 +114,9 @@ tsl::thread::ThreadPool* GetDriverExecutor() { namespace { -// Call cuCtxtSynchronize and crash if it doesn't succeed. -void SynchronizeOrDie() { - FAIL_IF_CUDA_RES_ERROR(cuCtxSynchronize(), - "Synchronize fail: ", tsl::CurrentStackTrace()); -} - thread_local struct ThreadLocalData { - int64_t id; - GpuContext* context; // Only valid if id == a known good context. + GpuContext* context; + int device_ordinal; int depth; } tls_data = {}; @@ -140,57 +129,48 @@ ScopedActivateContext::ScopedActivateContext(GpuContext* cuda_context) { // been left in the same state we left it. Other code may have run on this // thread and altered the context. if (tls->depth == 0) { - VLOG(3) << "ScopedActivateContext switching to " << cuda_context->id(); + VLOG(3) << "ScopedActivateContext switching to " + << cuda_context->device_ordinal(); FAIL_IF_CUDA_RES_ERROR(cuCtxSetCurrent(cuda_context->context()), "Failed setting context"); tls->depth = 1; - tls->id = cuda_context->id(); + tls->device_ordinal = cuda_context->device_ordinal(); tls->context = cuda_context; to_restore_ = nullptr; return; } tls->depth++; - if (tls->id == cuda_context->id()) { - if (kVerifyGpuContext) { - CHECK_EQ(CurrentContext(), cuda_context->context()); - } + if (tls->device_ordinal == cuda_context->device_ordinal()) { DCHECK_EQ(CurrentContext(), cuda_context->context()); return; } - VLOG(3) << "ScopedActivateContext switching context from " << tls->id - << " to " << cuda_context->id(); + VLOG(3) << "ScopedActivateContext switching context from " + << tls->device_ordinal << " to " << cuda_context->device_ordinal(); to_restore_ = tls->context; // Set the context and update thread local. FAIL_IF_CUDA_RES_ERROR(cuCtxSetCurrent(cuda_context->context()), "Failed setting context"); - tls->id = cuda_context->id(); + tls->device_ordinal = cuda_context->device_ordinal(); tls->context = cuda_context; } ScopedActivateContext::~ScopedActivateContext() { auto* tls = &tls_data; - if (kVerifyGpuContext) { - // Note that if kVerifyGpuContext is used, and contexts are deleted, it's - // possible this could fail in the CurrentContext() call. - CHECK_EQ(CurrentContext(), - tls->context == nullptr ? nullptr : tls->context->context()); - } - tls->depth--; DCHECK_GE(tls->depth, 0); if (to_restore_ == nullptr) { - // Leave context, tls->id, and tls->context set. + // Leave context, tls->device_ordinal, and tls->context set. return; } // Set context and update thread local. FAIL_IF_CUDA_RES_ERROR(cuCtxSetCurrent(to_restore_->context()), "Failed setting context"); - tls->id = to_restore_->id(); + tls->device_ordinal = to_restore_->device_ordinal(); tls->context = to_restore_; } diff --git a/xla/stream_executor/cuda/cuda_driver.h b/xla/stream_executor/cuda/cuda_driver.h index 1a0df21a8349b1..96b3428e2c9a94 100644 --- a/xla/stream_executor/cuda/cuda_driver.h +++ b/xla/stream_executor/cuda/cuda_driver.h @@ -57,10 +57,11 @@ absl::StatusOr QueryEvent(GpuContext* context, CUevent event); // unique id is positive, and ids are not repeated within the process. class GpuContext { public: - GpuContext(CUcontext context, int64_t id) : context_(context), id_(id) {} + GpuContext(CUcontext context, int device_ordinal) + : context_(context), device_ordinal_(device_ordinal) {} CUcontext context() const { return context_; } - int64_t id() const { return id_; } + int device_ordinal() const { return device_ordinal_; } // Disallow copying and moving. GpuContext(GpuContext&&) = delete; @@ -70,7 +71,7 @@ class GpuContext { private: CUcontext const context_; - const int64_t id_; + const int device_ordinal_; }; // Manages the singleton map of contexts that we've created, mapping @@ -98,7 +99,7 @@ class CreatedContexts { auto it = insert_result.first; if (insert_result.second) { // context was not present in the map. Add it. - it->second = std::make_unique(context, next_id_++); + it->second = std::make_unique(context, device_ordinal); (*LiveOrdinal())[device_ordinal].push_back(context); } return it->second.get(); @@ -161,7 +162,6 @@ class CreatedContexts { // Lock that guards access-to/mutation-of the live set. static absl::Mutex mu_; - static int64_t next_id_; }; } // namespace gpu diff --git a/xla/stream_executor/rocm/rocm_driver.cc b/xla/stream_executor/rocm/rocm_driver.cc index 4d040a7f6f49f5..bb8982fc3d654c 100644 --- a/xla/stream_executor/rocm/rocm_driver.cc +++ b/xla/stream_executor/rocm/rocm_driver.cc @@ -75,7 +75,6 @@ namespace stream_executor { namespace gpu { /* static */ absl::Mutex CreatedContexts::mu_{absl::kConstInit}; -/* static */ int64_t CreatedContexts::next_id_ = 1; // 0 means "no context" // Formats hipError_t to output prettified values into a log stream. // Error summaries taken from: diff --git a/xla/stream_executor/rocm/rocm_driver.h b/xla/stream_executor/rocm/rocm_driver.h index 6668bd9c157f39..1a6c98761888de 100644 --- a/xla/stream_executor/rocm/rocm_driver.h +++ b/xla/stream_executor/rocm/rocm_driver.h @@ -37,8 +37,8 @@ absl::StatusOr QueryEvent(GpuContext* context, hipEvent_t event); // GpuContext wraps the device_ordinal and hipCtx_t handle. class GpuContext { public: - GpuContext(hipCtx_t context, const int v) - : context_(context), device_ordinal_(v) {} + GpuContext(hipCtx_t context, const int ordinal) + : context_(context), device_ordinal_(ordinal) {} hipCtx_t context() const { return context_; } int device_ordinal() const { return device_ordinal_; } @@ -79,7 +79,7 @@ class CreatedContexts { auto it = insert_result.first; if (insert_result.second) { // context was not present in the map. Add it. - it->second = std::make_unique(context, next_id_++); + it->second = std::make_unique(context, device_ordinal); (*LiveOrdinal())[device_ordinal].push_back(context); } return it->second.get(); @@ -136,7 +136,6 @@ class CreatedContexts { // Lock that guards access-to/mutation-of the live set. static absl::Mutex mu_; - static int64_t next_id_; }; } // namespace gpu From d9bc4f5030dc85eadb7ce5d90ee4f86d5d5f4375 Mon Sep 17 00:00:00 2001 From: Vladyslav Tsilytskyi Date: Wed, 24 Jul 2024 16:11:25 -0700 Subject: [PATCH 127/376] [xla:cpu] Refactor EmitTargetElementLoop PiperOrigin-RevId: 655733953 --- xla/service/cpu/ir_emitter.cc | 46 ++++++++++++++++++++--------------- xla/service/cpu/ir_emitter.h | 14 +++++------ 2 files changed, 33 insertions(+), 27 deletions(-) diff --git a/xla/service/cpu/ir_emitter.cc b/xla/service/cpu/ir_emitter.cc index 8b4c3a204eaf52..75b1f3ccce2254 100644 --- a/xla/service/cpu/ir_emitter.cc +++ b/xla/service/cpu/ir_emitter.cc @@ -24,8 +24,9 @@ limitations under the License. #include #include #include -#include +#include #include +#include #include #include #include @@ -808,7 +809,8 @@ absl::Status IrEmitter::HandleSelectAndScatter( [this, init_value](const llvm_ir::IrArray::Index& target_index) { llvm::Value* init_value_addr = GetEmittedValueFor(init_value); return Load(IrShapeType(init_value->shape()), init_value_addr); - })); + }, + std::optional(output_array))); // Create a loop to iterate over the source array to scatter to the output. llvm_ir::ForLoopNest source_loops(IrName(select_and_scatter), b()); @@ -1566,8 +1568,7 @@ absl::Status IrEmitter::HandleCollectivePermute(HloInstruction* crs) { /*input_buffer=*/input_buffer, /*output_buffer=*/output_buffer, /*source_target_pairs=*/source_target_pairs_v, - /*source_target_pairs_size=*/ - b()->getInt32(source_target_pairs.size())}, + /*source_target_pairs_size=*/b()->getInt32(source_target_pairs.size())}, b()->getVoidTy()); return absl::OkStatus(); @@ -2318,7 +2319,8 @@ absl::Status IrEmitter::HandlePad(HloInstruction* pad) { const HloInstruction* padding_value = pad->operand(1); llvm::Value* padding_value_addr = GetEmittedValueFor(padding_value); return Load(IrShapeType(padding_value->shape()), padding_value_addr); - })); + }, + std::nullopt)); // Create a loop to iterate over the operand elements and update the output // locations where the operand elements should be stored. @@ -2376,7 +2378,8 @@ absl::Status IrEmitter::HandleFusion(HloInstruction* fusion) { BindFusionArguments(fusion, &fused_emitter); TF_ASSIGN_OR_RETURN(auto generator, fused_emitter.GetGenerator( *fusion->fused_expression_root())); - return EmitTargetElementLoop(fusion, generator); + return EmitTargetElementLoop(fusion, "kLoop_fusion", generator, + std::nullopt); } else if (fusion->IsOutputFusion()) { VLOG(3) << "HandleFusion kOutput"; int64_t dot_op_index = @@ -4099,19 +4102,20 @@ absl::Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { } absl::Status IrEmitter::EmitTargetElementLoop( - HloInstruction* target_op, - const llvm_ir::ElementGenerator& element_generator) { - return EmitTargetElementLoop(target_op, /*desc=*/"", element_generator); -} - -absl::Status IrEmitter::EmitTargetElementLoop( - HloInstruction* target_op, absl::string_view desc, - const llvm_ir::ElementGenerator& element_generator) { + const HloInstruction* target_op, absl::string_view desc, + const llvm_ir::ElementGenerator& element_generator, + std::optional result_array_opt) { VLOG(2) << "EmitTargetElementLoop: " << target_op->ToString(); + llvm_ir::IrArray target_array; + if (result_array_opt.has_value()) { + target_array = result_array_opt.value(); + } else { + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(target_op)); + target_array = GetIrArrayFor(target_op); + } + const Shape& target_shape = target_op->shape(); - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(target_op)); - llvm_ir::IrArray target_array = GetIrArrayFor(target_op); if (target_shape.IsTuple() && (target_op->opcode() == HloOpcode::kFusion || @@ -4131,7 +4135,7 @@ absl::Status IrEmitter::EmitTargetElementLoop( } TF_RETURN_IF_ERROR( llvm_ir::LoopEmitter(element_generator, output_arrays, b()) - .EmitLoop(IrName(target_op))); + .EmitLoop(IrName(target_op, desc))); std::vector tuple_operand_ptrs; for (int64_t i = 0; i < output_arrays.size(); ++i) { @@ -4147,11 +4151,11 @@ absl::Status IrEmitter::EmitTargetElementLoop( // Emit parallel loop with dynamic loop bounds for most-major dimensions. TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, target_array, &dynamic_loop_bounds, b()) - .EmitLoop(IrName(target_op))); + .EmitLoop(IrName(target_op, desc))); } else { TF_RETURN_IF_ERROR( llvm_ir::LoopEmitter(element_generator, target_array, b()) - .EmitLoop(IrName(target_op))); + .EmitLoop(IrName(target_op, desc))); } } return absl::OkStatus(); @@ -4196,7 +4200,9 @@ absl::Status IrEmitter::DefaultAction(HloInstruction* hlo) { } CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); return EmitTargetElementLoop( - hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator)); + hlo, "elemental_loop", + elemental_emitter.MakeElementGenerator(hlo, operand_to_generator), + std::nullopt); } llvm::Value* IrEmitter::EmitScalarReturningThreadLocalCall( diff --git a/xla/service/cpu/ir_emitter.h b/xla/service/cpu/ir_emitter.h index 4566a7ccdf2f35..45a4c5e22af1f3 100644 --- a/xla/service/cpu/ir_emitter.h +++ b/xla/service/cpu/ir_emitter.h @@ -18,9 +18,11 @@ limitations under the License. #include +#include #include #include #include +#include #include #include #include @@ -33,8 +35,10 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" #include "llvm/TargetParser/Triple.h" @@ -432,14 +436,10 @@ class IrEmitter : public DfsHloVisitorWithDefault, // desc is an optional human-readable string that's added to the loop name in // IR. Regardless of whether desc is provided, target_op->name() is included // in the loop name. - // - // TODO(jingyue): target_op should be a `const HloInstruction*`. - absl::Status EmitTargetElementLoop( - HloInstruction* target_op, - const llvm_ir::ElementGenerator& element_generator); absl::Status EmitTargetElementLoop( - HloInstruction* target_op, absl::string_view desc, - const llvm_ir::ElementGenerator& element_generator); + const HloInstruction* target_op, absl::string_view desc, + const llvm_ir::ElementGenerator& element_generator, + std::optional result_array_opt); // Emits a memcpy from the source instruction's result value to the // destination's. Both source and destination must have an entry in the From 4af9150c364f6a05b660af0701835db6b0b8b05d Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Wed, 24 Jul 2024 16:50:14 -0700 Subject: [PATCH 128/376] [IFRT] Harden XLA executable compilation in xla_executable_impl_test_lib xla_executable_impl_test_lib `CompileOnDevices` missed a few steps to compile an XLA computation in a portable way. This changes fixes them: * num_replicas, num_partitions, use_spmd_partitioning are set correctly. * device_assignment is set correctly for SPMD. * device_assignment uses device IDs correctly instead of device indices; there is no guarante that device ids and indices match. PiperOrigin-RevId: 655744675 --- xla/python/pjrt_ifrt/xla_executable_impl_test_lib.cc | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/xla/python/pjrt_ifrt/xla_executable_impl_test_lib.cc b/xla/python/pjrt_ifrt/xla_executable_impl_test_lib.cc index fb148d9e803917..04da5007591f4f 100644 --- a/xla/python/pjrt_ifrt/xla_executable_impl_test_lib.cc +++ b/xla/python/pjrt_ifrt/xla_executable_impl_test_lib.cc @@ -80,18 +80,24 @@ absl::StatusOr> CompileOnDevices( } else { build_options.set_device_ordinal(devices.front()->Id().value()); if (replicated) { + build_options.set_num_replicas(devices.size()); + build_options.set_num_partitions(1); + build_options.set_use_spmd_partitioning(false); DeviceAssignment device_assignment(/*replica_count=*/devices.size(), /*computation_count=*/1); for (int i = 0; i < devices.size(); ++i) { - device_assignment(i, 0) = i; + device_assignment(i, 0) = devices[i]->Id().value(); } build_options.set_device_assignment(device_assignment); } else { + build_options.set_num_replicas(1); + build_options.set_num_partitions(devices.size()); + build_options.set_use_spmd_partitioning(true); DeviceAssignment device_assignment( /*replica_count=*/1, /*computation_count=*/devices.size()); for (int i = 0; i < devices.size(); ++i) { - device_assignment(i, 0) = i; + device_assignment(0, i) = devices[i]->Id().value(); } build_options.set_device_assignment(device_assignment); } From 7fdafc1fe89b35c5f8489a5ea7135a45ad6b2928 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 24 Jul 2024 18:06:25 -0700 Subject: [PATCH 129/376] [XLA:CPU] Turn off thunks runtime test for conv_depthwise_test. PiperOrigin-RevId: 655763373 --- xla/tests/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/xla/tests/BUILD b/xla/tests/BUILD index 048c47594dbc3f..46873be7c9e2fc 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -424,7 +424,6 @@ xla_test( "conv_depthwise_test.cc", ], shard_count = 50, - tags = ["test_xla_cpu_thunks"], deps = [ ":client_library_test_base", ":conv_depthwise_common", From 108e40763ef74ad8afd7a7342b35904f256242f5 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 24 Jul 2024 19:24:09 -0700 Subject: [PATCH 130/376] [tsl] Remove dependency on platform:types from platform:mutex PiperOrigin-RevId: 655781413 --- third_party/tsl/tsl/platform/default/mutex.cc | 4 +++- third_party/tsl/tsl/platform/mutex.h | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/third_party/tsl/tsl/platform/default/mutex.cc b/third_party/tsl/tsl/platform/default/mutex.cc index d2f669f7bbd2cc..6f21fb64578e3a 100644 --- a/third_party/tsl/tsl/platform/default/mutex.cc +++ b/third_party/tsl/tsl/platform/default/mutex.cc @@ -17,6 +17,8 @@ limitations under the License. #include +#include + #include "nsync_cv.h" // NOLINT #include "nsync_mu.h" // NOLINT #include "nsync_mu_wait.h" // NOLINT @@ -72,7 +74,7 @@ void mutex::Await(const Condition &cond) { nsync::nsync_mu_wait(mu_cast(&mu_), &EvaluateCondition, &cond, nullptr); } -bool mutex::AwaitWithDeadline(const Condition &cond, uint64 abs_deadline_ns) { +bool mutex::AwaitWithDeadline(const Condition &cond, uint64_t abs_deadline_ns) { time_t seconds = abs_deadline_ns / (1000 * 1000 * 1000); nsync::nsync_time abs_time = nsync::nsync_time_s_ns( seconds, abs_deadline_ns - seconds * (1000 * 1000 * 1000)); diff --git a/third_party/tsl/tsl/platform/mutex.h b/third_party/tsl/tsl/platform/mutex.h index 0576f94741b0ea..c5effd9e9ef641 100644 --- a/third_party/tsl/tsl/platform/mutex.h +++ b/third_party/tsl/tsl/platform/mutex.h @@ -16,14 +16,14 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_MUTEX_H_ #define TENSORFLOW_TSL_PLATFORM_MUTEX_H_ -#include // NOLINT +#include // NOLINT +#include // NOLINT // for std::try_to_lock_t and std::cv_status #include // NOLINT #include // NOLINT #include "tsl/platform/platform.h" #include "tsl/platform/thread_annotations.h" -#include "tsl/platform/types.h" // Include appropriate platform-dependent implementation details of mutex etc. #if defined(PLATFORM_GOOGLE) @@ -107,7 +107,7 @@ class TF_LOCKABLE mutex { // has been reached, then atomically reacquire *this in the same mode in // which it was previously held, and return whether cond.Eval() is true. // See tsl/tsl/platform/env_time.h for the time interface. - bool AwaitWithDeadline(const Condition& cond, uint64 abs_deadline_ns); + bool AwaitWithDeadline(const Condition& cond, uint64_t abs_deadline_ns); // ------- private: From 3fe3d5326708ca3e8fb5c169ae1eb3254aa85c82 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 24 Jul 2024 19:39:35 -0700 Subject: [PATCH 131/376] [xla:cpu] Fix tsan error in Thunk::ExecuteState PiperOrigin-RevId: 655785106 --- xla/service/cpu/runtime/BUILD | 1 + xla/service/cpu/runtime/thunk.cc | 15 +++++++++++---- xla/service/cpu/runtime/thunk.h | 3 ++- xla/service/cpu/runtime/thunk_test.cc | 1 + 4 files changed, 15 insertions(+), 5 deletions(-) diff --git a/xla/service/cpu/runtime/BUILD b/xla/service/cpu/runtime/BUILD index 1787141d0874ee..f34570a81b7517 100644 --- a/xla/service/cpu/runtime/BUILD +++ b/xla/service/cpu/runtime/BUILD @@ -80,6 +80,7 @@ cc_library( "//xla/stream_executor", "//xla/stream_executor/host:host_kernel_c_api", "//xla/tsl/concurrency:async_value", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status:statusor", diff --git a/xla/service/cpu/runtime/thunk.cc b/xla/service/cpu/runtime/thunk.cc index 7810401df56900..455c940e264f3b 100644 --- a/xla/service/cpu/runtime/thunk.cc +++ b/xla/service/cpu/runtime/thunk.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include +#include "absl/base/optimization.h" #include "xla/executable_run_options.h" #include "xla/service/cpu/collectives_interface.h" #include "xla/service/cpu/cpu_executable_run_options.h" @@ -158,13 +159,19 @@ tsl::AsyncValueRef Thunk::OkExecuteEvent() { return event->AsRef(); } -Thunk::ExecuteState::ExecuteState(int64_t parallel_tasks) - : pending_tasks(parallel_tasks), +Thunk::ExecuteState::ExecuteState(int64_t num_tasks) + : pending_tasks(num_tasks), event(tsl::MakeConstructedAsyncValueRef()) {} +Thunk::ExecuteState::~ExecuteState() { + auto cnt = pending_tasks.load(std::memory_order_acquire); + DCHECK_EQ(cnt, 0) + << "ExecuteState is destroyed before all tasks are completed"; +} + void Thunk::ExecuteState::Notify() { - if (pending_tasks.load(std::memory_order_relaxed) == 1 || - pending_tasks.fetch_sub(1, std::memory_order_relaxed) == 1) { + bool is_done = pending_tasks.fetch_sub(1, std::memory_order_acq_rel) == 1; + if (ABSL_PREDICT_FALSE(is_done)) { event.SetStateConcrete(); } } diff --git a/xla/service/cpu/runtime/thunk.h b/xla/service/cpu/runtime/thunk.h index 59d2106edf8625..210d19937b2173 100644 --- a/xla/service/cpu/runtime/thunk.h +++ b/xla/service/cpu/runtime/thunk.h @@ -304,7 +304,8 @@ class Thunk { // multiple tasks and need to signal completion when all tasks are done (see // ConvolutionThunk and DotThunk for examples). struct ExecuteState { - explicit ExecuteState(int64_t parallel_tasks); + explicit ExecuteState(int64_t num_tasks); + ~ExecuteState(); void Notify(); diff --git a/xla/service/cpu/runtime/thunk_test.cc b/xla/service/cpu/runtime/thunk_test.cc index b761c509a31373..510d2c2f44025a 100644 --- a/xla/service/cpu/runtime/thunk_test.cc +++ b/xla/service/cpu/runtime/thunk_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/service/cpu/runtime/thunk.h" #include +#include #include "xla/executable_run_options.h" #include "xla/service/cpu/collectives_interface.h" From f98bb99bdf8a6c7cbd7363d263e2c838083beea9 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 24 Jul 2024 21:41:12 -0700 Subject: [PATCH 132/376] [xla:cpu] Do not test while loop temp aliasing with thunks runtime PiperOrigin-RevId: 655815997 --- xla/service/llvm_ir/BUILD | 2 ++ xla/service/llvm_ir/alias_analysis_test.cc | 19 ++++++++++++++----- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/xla/service/llvm_ir/BUILD b/xla/service/llvm_ir/BUILD index 8dc0694712eb4a..c39abe9e8f5adc 100644 --- a/xla/service/llvm_ir/BUILD +++ b/xla/service/llvm_ir/BUILD @@ -54,9 +54,11 @@ xla_cc_test( name = "alias_analysis_test", srcs = ["alias_analysis_test.cc"], deps = [ + "//xla:xla_proto_cc", "//xla/ffi", "//xla/ffi:ffi_api", "//xla/service/cpu/tests:cpu_codegen_test", + "//xla/tests:hlo_test_base", "@com_google_absl//absl/status", "@tsl//tsl/platform:test", ], diff --git a/xla/service/llvm_ir/alias_analysis_test.cc b/xla/service/llvm_ir/alias_analysis_test.cc index 3547bae2109aa6..cb91226ecc4bac 100644 --- a/xla/service/llvm_ir/alias_analysis_test.cc +++ b/xla/service/llvm_ir/alias_analysis_test.cc @@ -17,12 +17,22 @@ limitations under the License. #include "xla/ffi/ffi.h" #include "xla/ffi/ffi_api.h" #include "xla/service/cpu/tests/cpu_codegen_test.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/xla.pb.h" #include "tsl/platform/test.h" -namespace xla { -namespace cpu { +namespace xla::cpu { namespace { -class AliasAnalysisTest : public CpuCodegenTest {}; + +class AliasAnalysisTest : public CpuCodegenTest { + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + // We do not generate IR for while loops with thunks runtime, so we + // explicitly disable it for this test. + debug_options.set_xla_cpu_use_thunk_runtime(false); + return debug_options; + } +}; static absl::Status FakeCustomCallTarget(ffi::AnyBuffer, ffi::Result) { @@ -86,5 +96,4 @@ ENTRY while3 { } } // namespace -} // namespace cpu -} // namespace xla +} // namespace xla::cpu From 4175224f3457a82043a1eb143420cb67d8833c30 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 24 Jul 2024 21:43:17 -0700 Subject: [PATCH 133/376] [xla:cpu] Embed LLVM IR into executable when running thunks PiperOrigin-RevId: 655816447 --- xla/service/cpu/cpu_compiler.cc | 9 +++++++++ xla/tools/hlo_opt/cpu_llvm.hlo | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/xla/service/cpu/cpu_compiler.cc b/xla/service/cpu/cpu_compiler.cc index 1454fbc93cdc37..e19b90d7c8f96e 100644 --- a/xla/service/cpu/cpu_compiler.cc +++ b/xla/service/cpu/cpu_compiler.cc @@ -1250,6 +1250,11 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { TF_ASSIGN_OR_RETURN(ThunkSequence thunks, thunk_emitter.EmitEntryComputation(*module)); + std::string ir_module_string; + if (embed_ir_in_executable) { + ir_module_string = llvm_ir::DumpToString(llvm_module.get()); + } + // JIT compile the LLVM IR module to in-memory machine code. TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module)); cantFail((*jit)->AddModule(llvm::orc::ThreadSafeModule( @@ -1289,6 +1294,10 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { // Save object files to be able to export them to AOT compilation result. cpu_executable->set_obj_files(std::move(obj_files)); + if (embed_ir_in_executable) { + cpu_executable->set_ir_module_string(ir_module_string); + } + return with_hlo_proto(std::move(cpu_executable)); } diff --git a/xla/tools/hlo_opt/cpu_llvm.hlo b/xla/tools/hlo_opt/cpu_llvm.hlo index fbb033e53b07eb..ea0cd33be13342 100644 --- a/xla/tools/hlo_opt/cpu_llvm.hlo +++ b/xla/tools/hlo_opt/cpu_llvm.hlo @@ -8,7 +8,7 @@ add { ROOT out = s8[] add(a, b) } -// CHECK: i8 +// CHECK: reduce ENTRY e { p1 = s8[1048576] parameter(0) i = s8[] constant(0) From 801909e81a12303b83e5b928e4ce8f5eb6680302 Mon Sep 17 00:00:00 2001 From: Penporn Koanantakool Date: Thu, 25 Jul 2024 00:42:40 -0700 Subject: [PATCH 134/376] [xla:cpu] Move oneDNN tests to the CPU folder since they are specific to CPU. + Add missing header include and build dependencies. + Temporarily disable onednn_matmul_test. PiperOrigin-RevId: 655856660 --- xla/service/cpu/tests/BUILD | 83 +++++++++++++++++ .../cpu}/tests/onednn_convolution_test.cc | 0 .../cpu}/tests/onednn_layer_norm_test.cc | 0 .../cpu}/tests/onednn_matmul_test.cc | 0 .../cpu}/tests/onednn_softmax_test.cc | 1 + xla/tests/BUILD | 89 +------------------ 6 files changed, 86 insertions(+), 87 deletions(-) rename xla/{ => service/cpu}/tests/onednn_convolution_test.cc (100%) rename xla/{ => service/cpu}/tests/onednn_layer_norm_test.cc (100%) rename xla/{ => service/cpu}/tests/onednn_matmul_test.cc (100%) rename xla/{ => service/cpu}/tests/onednn_softmax_test.cc (99%) diff --git a/xla/service/cpu/tests/BUILD b/xla/service/cpu/tests/BUILD index 7ddaaec2a004a6..a0921e4344ea03 100644 --- a/xla/service/cpu/tests/BUILD +++ b/xla/service/cpu/tests/BUILD @@ -356,3 +356,86 @@ xla_cc_test( "@tsl//tsl/platform:test_main", ], ) + +xla_cc_test( + name = "onednn_matmul_test", + srcs = ["onednn_matmul_test.cc"], + copts = tsl_copts(), + tags = [ + "no_oss", + "notap", + ], + deps = [ + "//xla:literal", + "//xla:shape_util", + "//xla:test", + "//xla:test_helpers", + "//xla/hlo/utils:hlo_matchers", + "//xla/service:cpu_plugin", + "//xla/service/cpu:onednn_matmul_rewriter", + "//xla/service/cpu:onednn_util", + "//xla/tests:filecheck", + "//xla/tests:hlo_test_base", + "//xla/tests:test_macros_header", + "//xla/tests:xla_internal_test_main", + "@tsl//tsl/platform:platform_port", + ], +) + +xla_cc_test( + name = "onednn_convolution_test", + srcs = ["onednn_convolution_test.cc"], + copts = tsl_copts(), + deps = [ + "//xla:literal", + "//xla:shape_util", + "//xla:test", + "//xla:test_helpers", + "//xla/hlo/utils:hlo_matchers", + "//xla/service:cpu_plugin", + "//xla/service/cpu:onednn_matmul_rewriter", + "//xla/service/cpu:onednn_util", + "//xla/tests:filecheck", + "//xla/tests:hlo_test_base", + "//xla/tests:test_macros_header", + "//xla/tests:xla_internal_test_main", + "@tsl//tsl/platform:platform_port", + ], +) + +xla_cc_test( + name = "onednn_layer_norm_test", + srcs = ["onednn_layer_norm_test.cc"], + copts = tsl_copts(), + deps = [ + "//xla:test", + "//xla/service:cpu_plugin", + "//xla/service/cpu:onednn_util", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + ], +) + +xla_cc_test( + name = "onednn_softmax_test", + srcs = ["onednn_softmax_test.cc"], + copts = tsl_copts(), + shard_count = 4, + deps = [ + "//xla:literal", + "//xla:shape_util", + "//xla:test", + "//xla:test_helpers", + "//xla/service:cpu_plugin", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service/cpu:backend_config_proto_cc", + "//xla/service/cpu:onednn_config_proto_cc", + "//xla/service/cpu:onednn_ops_rewriter", + "//xla/service/cpu:onednn_util", + "//xla/tests:hlo_test_base", + "//xla/tests:test_macros_header", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings", + ], +) diff --git a/xla/tests/onednn_convolution_test.cc b/xla/service/cpu/tests/onednn_convolution_test.cc similarity index 100% rename from xla/tests/onednn_convolution_test.cc rename to xla/service/cpu/tests/onednn_convolution_test.cc diff --git a/xla/tests/onednn_layer_norm_test.cc b/xla/service/cpu/tests/onednn_layer_norm_test.cc similarity index 100% rename from xla/tests/onednn_layer_norm_test.cc rename to xla/service/cpu/tests/onednn_layer_norm_test.cc diff --git a/xla/tests/onednn_matmul_test.cc b/xla/service/cpu/tests/onednn_matmul_test.cc similarity index 100% rename from xla/tests/onednn_matmul_test.cc rename to xla/service/cpu/tests/onednn_matmul_test.cc diff --git a/xla/tests/onednn_softmax_test.cc b/xla/service/cpu/tests/onednn_softmax_test.cc similarity index 99% rename from xla/tests/onednn_softmax_test.cc rename to xla/service/cpu/tests/onednn_softmax_test.cc index 8c28f5d005dad5..124b4472024c17 100644 --- a/xla/tests/onednn_softmax_test.cc +++ b/xla/service/cpu/tests/onednn_softmax_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/strings/substitute.h" #include "xla/literal.h" #include "xla/service/cpu/backend_config.pb.h" +#include "xla/service/cpu/onednn_config.pb.h" #include "xla/service/cpu/onednn_ops_rewriter.h" #include "xla/service/cpu/onednn_util.h" #include "xla/service/pattern_matcher.h" diff --git a/xla/tests/BUILD b/xla/tests/BUILD index 46873be7c9e2fc..094f3d1ddf4e9e 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -14,7 +14,7 @@ load( load("//xla:package_groups.bzl", "xla_tests_package_groups") load("//xla:xla.bzl", "tests_build_defs_bzl_deps", "xla_cc_binary", "xla_cc_test") load("//xla/tests:build_defs.bzl", "generate_backend_suites", "generate_backend_test_macros", "xla_test", "xla_test_library") -load("//xla/tsl:tsl.bzl", "internal_visibility", "tsl_copts") +load("//xla/tsl:tsl.bzl", "internal_visibility") load("//xla/tsl:tsl.default.bzl", "filegroup") package( @@ -424,6 +424,7 @@ xla_test( "conv_depthwise_test.cc", ], shard_count = 50, + tags = ["test_xla_cpu_thunks"], deps = [ ":client_library_test_base", ":conv_depthwise_common", @@ -3159,92 +3160,6 @@ xla_cc_test( ], ) -xla_test( - name = "onednn_matmul_test", - srcs = ["onednn_matmul_test.cc"], - backends = [ - "cpu", - ], - copts = tsl_copts(), - tags = ["no_oss"], - deps = [ - ":hlo_test_base", - ":test_macros_header", - ":xla_internal_test_main", - "//xla:literal", - "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", - "//xla/hlo/utils:hlo_matchers", - "//xla/service/cpu:onednn_util", - "@tsl//tsl/platform:platform_port", - ], -) - -xla_test( - name = "onednn_convolution_test", - srcs = ["onednn_convolution_test.cc"], - backends = [ - "cpu", - ], - copts = tsl_copts(), - deps = [ - ":hlo_test_base", - ":test_macros_header", - ":xla_internal_test_main", - "//xla:literal", - "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", - "//xla/hlo/utils:hlo_matchers", - "@tsl//tsl/platform:platform_port", - ], -) - -xla_test( - name = "onednn_layer_norm_test", - srcs = ["onednn_layer_norm_test.cc"], - backends = [ - "cpu", - ], - copts = tsl_copts(), - deps = [ - ":hlo_test_base", - ":test_macros_header", - ":xla_internal_test_main", - "//xla:literal", - "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", - ], -) - -xla_test( - name = "onednn_softmax_test", - srcs = ["onednn_softmax_test.cc"], - backends = [ - "cpu", - ], - copts = tsl_copts(), - shard_count = 4, - deps = [ - ":hlo_test_base", - ":test_macros_header", - ":xla_internal_test_main", - "//xla:literal", - "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/service/cpu:backend_config_proto_cc", - "//xla/service/cpu:onednn_ops_rewriter", - "//xla/service/cpu:onednn_util", - "@com_google_absl//absl/strings", - "@tsl//tsl/platform:platform_port", - ], -) - xla_test( name = "numerics_test", srcs = ["numerics_test.cc"], From c8b38fbf47e437ec50fc9443dfbf45c67b41c4bf Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Thu, 25 Jul 2024 01:00:21 -0700 Subject: [PATCH 135/376] [XLA:GPU] Return flops cost of common instruction directly. Calling `GpuHloCostAnalysis` to get FLOPs is too expensive. For instructions that only do indexing, the cost is always 0. Elementwise instruction can be extracted from hlo profile data. For more complicated and rare instructions, still fall back to GpuHloCostAnalysis. PiperOrigin-RevId: 655859942 --- xla/service/gpu/model/BUILD | 1 + .../model/gpu_indexing_performance_model.cc | 70 +++++++--- .../model/gpu_indexing_performance_model.h | 11 +- .../gpu_indexing_performance_model_test.cc | 121 +++++++++++++++++- 4 files changed, 184 insertions(+), 19 deletions(-) diff --git a/xla/service/gpu/model/BUILD b/xla/service/gpu/model/BUILD index 0a884ae230160c..57b619801fbf5c 100644 --- a/xla/service/gpu/model/BUILD +++ b/xla/service/gpu/model/BUILD @@ -385,6 +385,7 @@ xla_cc_test( ":gpu_indexing_performance_model", ":gpu_performance_model_base", "//xla:shape_util", + "//xla:test_helpers", "//xla/hlo/ir:hlo", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:gpu_device_info_for_tests", diff --git a/xla/service/gpu/model/gpu_indexing_performance_model.cc b/xla/service/gpu/model/gpu_indexing_performance_model.cc index 1d955486f731c0..49b914eb19cc17 100644 --- a/xla/service/gpu/model/gpu_indexing_performance_model.cc +++ b/xla/service/gpu/model/gpu_indexing_performance_model.cc @@ -55,25 +55,63 @@ namespace xla { namespace gpu { int64_t GpuPerformanceModelWithIndexingAnalysis::FlopsPerElement( - const HloInstruction* instr) const { - // TODO(shyshkov): Replace dependency on GpuHloCostAnalysis with independent - // flops calculation. - GpuHloCostAnalysis::Options cost_analysis_options{ - shape_size_, - /*per_second_rates=*/{}, - /*count_multiple_input_accesses=*/true}; - GpuHloCostAnalysis cost_analysis(cost_analysis_options, *device_info_); - TF_CHECK_OK( - cost_analysis.RevisitInstruction(const_cast(instr))); + const HloInstruction* instr) { + // Instruction that are only used for indexing are not counted for FLOPs. + switch (instr->opcode()) { + case HloOpcode::kBitcast: + case HloOpcode::kBroadcast: + case HloOpcode::kConstant: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kGather: + case HloOpcode::kIota: + case HloOpcode::kPad: + case HloOpcode::kParameter: + case HloOpcode::kSlice: + case HloOpcode::kTranspose: + case HloOpcode::kTuple: + return 0; + default: + break; + }; + + // Get the FLOPs per element for elementwise operations that only depend on + // the element type. + if (instr->IsElementwise()) { + return cost_analysis_.GetFlopsPerElementwiseOpElement( + instr->shape().element_type(), instr->opcode()); + } - int64_t num_elements = [&] { - if (instr->opcode() == HloOpcode::kReduce && instr->shape().IsTuple()) { - return ShapeUtil::ElementsInRecursive(instr->shape().tuple_shapes(0)); + if (instr->opcode() == HloOpcode::kReduce) { + int64_t flops_per_reduce_computation = 0; + for (const HloInstruction* reducer_instr : + instr->called_computations()[0]->instructions()) { + flops_per_reduce_computation += FlopsPerElement(reducer_instr); } - return ShapeUtil::ElementsInRecursive(instr->shape()); - }(); - return cost_analysis.flop_count(*instr) / num_elements; + auto operand_shape = instr->operand(0)->shape(); + auto output_shape = instr->shape().IsArray() + ? instr->shape() + : instr->shape().tuple_shapes(0); + + // Size of reduction dimensions. + int64_t reduction_factor = ShapeUtil::ElementsIn(operand_shape) / + ShapeUtil::ElementsIn(output_shape); + + // The Cost Model assumes that the reduction computation is applied N-1 + // times to reduce N elements. This is not true, because emitters will + // generate a loop with N iterations. We don't fix it here to keep this + // estimate consistent with GpuHloCostAnalysis. This is like doesn't matter + // much for the application of the Cost Model. + return (reduction_factor - 1) * flops_per_reduce_computation; + } + + // Encountered unexpected instruction, call to GpuHloCostAnalysis. + TF_CHECK_OK( + cost_analysis_.RevisitInstruction(const_cast(instr))); + + return cost_analysis_.flop_count(*instr) / + ShapeUtil::ElementsInRecursive(instr->shape()); } int64_t GpuPerformanceModelWithIndexingAnalysis::GetShapeSizeRecursive( diff --git a/xla/service/gpu/model/gpu_indexing_performance_model.h b/xla/service/gpu/model/gpu_indexing_performance_model.h index 149a64c9908456..daff68ef161339 100644 --- a/xla/service/gpu/model/gpu_indexing_performance_model.h +++ b/xla/service/gpu/model/gpu_indexing_performance_model.h @@ -29,6 +29,7 @@ limitations under the License. #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/model/fusion_analysis_cache.h" +#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/service/gpu/model/gpu_performance_model_base.h" #include "xla/service/gpu/model/hlo_op_profiles.h" #include "xla/service/gpu/model/symbolic_tile_analysis.h" @@ -62,6 +63,11 @@ class GpuPerformanceModelWithIndexingAnalysis : public GpuPerformanceModelBase { device_info_(device_info), fusion_analysis_cache_(fusion_analysis_cache), shape_size_(shape_size), + cost_analysis_( + GpuHloCostAnalysis::Options{shape_size_, + /*per_second_rates=*/{}, + /*count_multiple_input_accesses=*/true}, + *device_info_), mlir_context_(mlir_context) {} EstimateRunTimeData EstimateRunTimeForFusion( @@ -110,17 +116,18 @@ class GpuPerformanceModelWithIndexingAnalysis : public GpuPerformanceModelBase { absl::StatusOr TryFindBestTilingForFusion( const HloFusionAdaptor& fusion_adaptor); - private: // Returns an estimate how many FLOPs will be used to produce one element of // the output. - int64_t FlopsPerElement(const HloInstruction* instr) const; + int64_t FlopsPerElement(const HloInstruction* instr); + private: int64_t GetShapeSizeRecursive(const Shape& shape) const; const HloOpProfiles::HloOpProfile* hlo_op_profile_; const se::DeviceDescription* device_info_; HloFusionAnalysisCache* fusion_analysis_cache_; HloCostAnalysis::ShapeSizeFunction shape_size_; + GpuHloCostAnalysis cost_analysis_; mlir::MLIRContext* mlir_context_; }; diff --git a/xla/service/gpu/model/gpu_indexing_performance_model_test.cc b/xla/service/gpu/model/gpu_indexing_performance_model_test.cc index 7dee26d4d4d396..a278af8487ccaa 100644 --- a/xla/service/gpu/model/gpu_indexing_performance_model_test.cc +++ b/xla/service/gpu/model/gpu_indexing_performance_model_test.cc @@ -37,6 +37,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" +#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" @@ -50,6 +51,7 @@ using ::testing::HasSubstr; using ::tsl::testing::StatusIs; class GpuIndexingPerformanceModelTest : public HloTestBase { + public: GpuHloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const { return [&](const Shape& shape) { constexpr int64_t kPointerSize = 8; @@ -57,7 +59,6 @@ class GpuIndexingPerformanceModelTest : public HloTestBase { }; } - public: mlir::MLIRContext mlir_context_; // The reference times in the test cases below are measured // on A6000 by profiling the execution of the HLOs. @@ -414,6 +415,124 @@ ENTRY main { HasSubstr("SymbolicTileAnalysis failed"))); } +class FlopsPerElementTest : public GpuIndexingPerformanceModelTest { + public: + void CompareFlopsModels(absl::string_view hlo_module_string) { + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_module_string)); + + GpuHloCostAnalysis cost_analysis( + GpuHloCostAnalysis::Options{ShapeSizeBytesFunction(), + /*per_second_rates=*/{}, + /*count_multiple_input_accesses=*/true}, + device_info_); + + ASSERT_IS_OK(module->entry_computation()->Accept(&cost_analysis)); + auto instr = module->entry_computation()->root_instruction(); + + int64_t flops_per_element = indexing_cost_model_.FlopsPerElement(instr); + const Shape& output_shape = instr->shape().IsArray() + ? instr->shape() + : instr->shape().tuple_shapes(0); + int64_t total_flops = + ShapeUtil::ElementsIn(output_shape) * flops_per_element; + + EXPECT_EQ(total_flops, cost_analysis.flop_count(*instr)); + } +}; + +TEST_F(FlopsPerElementTest, MatchesGpuHloCostAnalysis_Reduce) { + CompareFlopsModels(R"( +HloModule m + +add { + param_0 = f32[] parameter(0) + param_1 = f32[] parameter(1) + ROOT add.0 = f32[] add(param_0, param_1) +} + +ENTRY entry_computation { + param_0.3 = f32[32,40] parameter(0) + constant = f32[] constant(0) + ROOT reduce = f32[32] reduce(param_0.3, constant), dimensions={1}, to_apply=add +} +)"); +} + +TEST_F(FlopsPerElementTest, MatchesGpuHloCostAnalysis_VariadicReduce) { + CompareFlopsModels(R"( +HloModule m + +add_multiply { + param_0 = f32[] parameter(0) + param_1 = f32[] parameter(1) + param_2 = f32[] parameter(2) + param_3 = f32[] parameter(3) + add = f32[] add(param_0, param_2) + multiply = f32[] multiply(param_1, param_3) + ROOT t = (f32[], f32[]) tuple(add, multiply) +} + +ENTRY entry_computation { + param_0 = f32[32,40] parameter(0) + c0 = f32[] constant(0) + ROOT reduce = (f32[32], f32[32]) reduce(param_0, param_0, c0, c0), dimensions={1}, to_apply=add_multiply +} +)"); +} + +TEST_F(FlopsPerElementTest, MatchesGpuHloCostAnalysis_Elementwise_Cosine) { + CompareFlopsModels(R"( +HloModule m + +ENTRY entry_computation { + param_0 = f32[32] parameter(0) + ROOT cosine = f32[32] cosine(param_0) +} +)"); +} + +TEST_F(FlopsPerElementTest, MatchesGpuHloCostAnalysis_Elementwise_Clamp) { + CompareFlopsModels(R"( +HloModule m + +ENTRY entry_computation { + param_0 = f32[32] parameter(0) + param_1 = f32[32] parameter(1) + param_2 = f32[32] parameter(2) + ROOT clamp = clamp(param_0, param_1, param_2) +} +)"); +} + +TEST_F(FlopsPerElementTest, MatchesGpuHloCostAnalysis_Gather) { + CompareFlopsModels(R"( +HloModule module +entry { + operand = f32[33, 76, 70] parameter(0) + indices = s32[1806, 2] parameter(1) + ROOT gather = f32[1806, 7, 8, 4] gather(operand, indices), + offset_dims={1,2,3}, collapsed_slice_dims={}, start_index_map={0,1}, + index_vector_dim=1, slice_sizes={7,8,4} +})"); +} + +TEST_F(FlopsPerElementTest, MatchesGpuHloCostAnalysis_ReduceWindow) { + CompareFlopsModels(R"( + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY entry { + param_0 = f32[13,12,8,15] parameter(0) + c0 = f32[] constant(0) + ROOT reduce-window = f32[13,3,8,15] reduce-window(param_0, c0), window={size=1x1x7x1 stride=1x4x1x1 pad=0_0x0_0x3_3x0_0}, to_apply=add +})"); +} + } // namespace } // namespace gpu } // namespace xla From d5cec3f187a3b0f36554d4bc00dbb88b94ce3fd7 Mon Sep 17 00:00:00 2001 From: Shraiysh Date: Thu, 25 Jul 2024 01:07:51 -0700 Subject: [PATCH 136/376] PR #15285: Ensure only one device is visible in pjrt_c_api_gpu_test Imported from GitHub PR https://github.com/openxla/xla/pull/15285 The test fails when the number of available devices is more than 1. This patch fixes that by ensuring that only one device is visible to the test. Copybara import of the project: -- 587bebe70c7d298008eff0c65dfcfa901e1fe21a by Shraiysh Vaishay : Ensure only one device is visible in pjrt_c_api_gpu_test The test fails when the number of available devices is more than 1. This patch fixes that by ensuring that only one device is visible to the test. Merging this change closes #15285 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/15285 from shraiysh:fix_pjrt_c_api_gpu_test_gpu 587bebe70c7d298008eff0c65dfcfa901e1fe21a PiperOrigin-RevId: 655861635 --- xla/pjrt/c/BUILD | 1 + xla/pjrt/c/pjrt_c_api_gpu_test.cc | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/xla/pjrt/c/BUILD b/xla/pjrt/c/BUILD index f4a72d3d97b5f1..dad4ceb9887634 100644 --- a/xla/pjrt/c/BUILD +++ b/xla/pjrt/c/BUILD @@ -376,6 +376,7 @@ xla_test( "//xla:literal", "//xla:literal_util", "//xla:shape_util", + "//xla/client:client_library", "//xla/ffi:execution_context", "//xla/ffi:ffi_api", "//xla/ffi:type_id_registry", diff --git a/xla/pjrt/c/pjrt_c_api_gpu_test.cc b/xla/pjrt/c/pjrt_c_api_gpu_test.cc index 0874a90d0845fa..eee88adae5a78d 100644 --- a/xla/pjrt/c/pjrt_c_api_gpu_test.cc +++ b/xla/pjrt/c/pjrt_c_api_gpu_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "xla/client/client_library.h" #include "xla/ffi/api/ffi.h" #include "xla/ffi/execution_context.h" #include "xla/ffi/ffi_api.h" @@ -215,6 +216,7 @@ TEST(PjrtCApiGpuKVStoreTest, CreateClientWithKVCallback) { auto kv_store = std::make_shared(); std::shared_ptr<::pjrt::PJRT_KeyValueCallbackData> kv_callback_data = ::pjrt::ConvertToCKeyValueCallbacks(kv_store); + xla::ClientLibrary::DestroyLocalInstances(); int num_nodes = 2; std::vector threads; @@ -225,7 +227,8 @@ TEST(PjrtCApiGpuKVStoreTest, CreateClientWithKVCallback) { kv_store = kv_store] { absl::flat_hash_map options = { {"num_nodes", static_cast(num_nodes)}, - {"node_id", static_cast(i)}}; + {"node_id", static_cast(i)}, + {"visible_devices", std::vector({0})}}; TF_ASSERT_OK_AND_ASSIGN(std::vector c_options, ::pjrt::ConvertToPjRtNamedValueList(options)); TF_ASSERT_OK_AND_ASSIGN( From a705146ca736400c2c50a94b2e2e49936ab0dada Mon Sep 17 00:00:00 2001 From: Tori Baker Date: Thu, 25 Jul 2024 02:00:37 -0700 Subject: [PATCH 137/376] Fill in parsers & printers for indexing_map attribute PiperOrigin-RevId: 655873427 --- xla/service/gpu/fusions/mlir/ir/BUILD | 1 + .../gpu/fusions/mlir/ir/xla_gpu_attrs.cc | 145 ++++++++++++++++-- .../gpu/fusions/mlir/ir/xla_gpu_attrs.h | 27 +--- .../gpu/fusions/mlir/ir/xla_gpu_attrs.td | 14 +- .../fusions/mlir/tests/indexing_map_attr.mlir | 118 ++++++++++++++ xla/service/gpu/model/BUILD | 1 + xla/service/gpu/model/indexing_map.h | 6 + 7 files changed, 260 insertions(+), 52 deletions(-) create mode 100644 xla/service/gpu/fusions/mlir/tests/indexing_map_attr.mlir diff --git a/xla/service/gpu/fusions/mlir/ir/BUILD b/xla/service/gpu/fusions/mlir/ir/BUILD index e3db7145981a16..ba19b9b81b5a4e 100644 --- a/xla/service/gpu/fusions/mlir/ir/BUILD +++ b/xla/service/gpu/fusions/mlir/ir/BUILD @@ -94,6 +94,7 @@ cc_library( ":xla_gpu_dialect_inc_gen", ":xla_gpu_ops_inc_gen", "//xla/service/gpu/model:indexing_analysis", + "@com_google_absl//absl/strings:str_format", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:BytecodeOpInterface", diff --git a/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc b/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc index 6feb0eadfaade9..2f51b2572831f7 100644 --- a/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc +++ b/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc @@ -15,39 +15,152 @@ limitations under the License. #include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h" +#include #include +#include "absl/strings/str_format.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/LogicalResult.h" #include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Types.h" #include "mlir/Support/LLVM.h" #include "xla/service/gpu/model/indexing_map.h" +#define GET_ATTRDEF_LIST +#define GET_ATTRDEF_CLASSES +#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h.inc" + namespace xla { namespace gpu { -void PrintDimVars(mlir::AsmPrinter& p, llvm::ArrayRef dim_vars) {} +using llvm::ParseResult; +using llvm::SmallVector; +using mlir::AffineExpr; +using mlir::ArrayRef; +using mlir::AsmParser; +using mlir::AsmPrinter; +using mlir::failed; +using mlir::failure; + +ParseResult ParseInterval(AsmParser& parser, Interval& interval) { + // ParseResult converts to `true` if parsing failed. + return failure(parser.parseLSquare() || parser.parseInteger(interval.lower) || + parser.parseComma() || parser.parseInteger(interval.upper) || + parser.parseRSquare()); +} + +void PrintDimVars(AsmPrinter& p, ArrayRef dim_vars) { + for (int i = 0; i < dim_vars.size(); ++i) { + p << "d" << i << " in " << dim_vars[i].bounds << "\n"; + } +} + +mlir::FailureOr> ParseDimVars( + AsmParser& parser, ArrayRef dim_names) { + SmallVector dim_vars; + for (const auto& dim_name : dim_names) { + if (parser.parseKeyword(dim_name) || parser.parseKeyword("in") || + ParseInterval(parser, dim_vars.emplace_back().bounds)) { + return failure(); + } + } + return dim_vars; +} -mlir::FailureOr> ParseDimVars( - mlir::AsmParser& parser) { - return mlir::failure(); +void PrintRangeVars(AsmPrinter& p, ArrayRef range_vars) { + for (int i = 0; i < range_vars.size(); ++i) { + p << "s" << i << " in " << range_vars[i].range << "\n"; + } } -void PrintRangeVars(mlir::AsmPrinter& p, llvm::ArrayRef range_vars) {} +mlir::FailureOr> ParseRangeVars( + AsmParser& parser, ArrayRef range_symbol_names) { + SmallVector range_vars; + for (const auto& range_symbol_name : range_symbol_names) { + if (parser.parseKeyword(range_symbol_name) || parser.parseKeyword("in") || + ParseInterval(parser, range_vars.emplace_back().range)) { + return failure(); + } + } + return range_vars; +} -mlir::FailureOr> ParseRangeVars( - mlir::AsmParser& parser) { - return mlir::failure(); +void PrintConstraints(AsmPrinter& p, + ArrayRef> constraints) { + for (const auto& [constrained_expression, range] : constraints) { + p << constrained_expression << " in " << range << "\n"; + } } -void PrintConstraints( - mlir::AsmPrinter& p, - mlir::ArrayRef> - range_vars) {} +mlir::FailureOr>> ParseConstraints( + AsmParser& parser, + ArrayRef> symbolSet) { + SmallVector> constraints; + while (failed(parser.parseOptionalGreater())) { + auto& constraint = constraints.emplace_back(); + if (parser.parseAffineExpr(symbolSet, constraint.first) || + parser.parseKeyword("in") || ParseInterval(parser, constraint.second)) { + return failure(); + } + } + return constraints; +} + +mlir::Attribute IndexingMapAttr::parse(mlir::AsmParser& parser, mlir::Type) { + mlir::AffineMap map; + if (parser.parseLess() || parser.parseAffineMap(map)) { + return {}; + } + + // Store real strings to back up StringRef throughout ParseConstraints. + SmallVector dim_strings(map.getNumDims()); + SmallVector symbol_strings(map.getNumSymbols()); + SmallVector> symbolSet; + symbolSet.reserve(map.getNumDims() + map.getNumSymbols()); + for (int i = 0; i < map.getNumDims(); ++i) { + dim_strings[i] = absl::StrFormat("d%d", i); + symbolSet.push_back( + {dim_strings[i], mlir::getAffineDimExpr(i, parser.getContext())}); + } + for (int i = 0; i < map.getNumSymbols(); ++i) { + symbol_strings[i] = absl::StrFormat("s%d", i); + symbolSet.push_back( + {symbol_strings[i], mlir::getAffineSymbolExpr(i, parser.getContext())}); + } + + if (parser.parseKeyword("domain") || parser.parseColon()) { + return {}; + } + auto maybe_dim_vars = ParseDimVars(parser, dim_strings); + if (failed(maybe_dim_vars)) { + return {}; + } + + auto maybe_range_vars = ParseRangeVars(parser, symbol_strings); + if (failed(maybe_range_vars)) { + return {}; + } + + auto maybe_constraints = ParseConstraints(parser, symbolSet); + if (failed(maybe_constraints)) { + return {}; + } + // ParseConstraints consumes the > to know when to stop. + return IndexingMapAttr::get(parser.getContext(), map, *maybe_dim_vars, + *maybe_range_vars, *maybe_constraints); +} -mlir::FailureOr< - llvm::SmallVector>> -ParseConstraints(mlir::AsmParser& parser) { - return mlir::failure(); +void IndexingMapAttr::print(mlir::AsmPrinter& printer) const { + printer << "<"; + printer.printStrippedAttrOrType(getMap()); + printer << "\ndomain:\n"; + PrintDimVars(printer, getDimVars()); + PrintRangeVars(printer, getRangeVars()); + PrintConstraints(printer, getConstraints()); + printer << ">"; } } // namespace gpu diff --git a/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h b/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h index bd6cf0424b1db7..fca921621ac4c1 100644 --- a/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h +++ b/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h @@ -16,35 +16,16 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_IR_XLA_GPU_ATTRS_H_ #define XLA_SERVICE_GPU_FUSIONS_MLIR_IR_XLA_GPU_ATTRS_H_ +#include "mlir/IR/Attributes.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Support/LLVM.h" -#include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/model/indexing_map.h" // IWYU pragma: keep namespace xla { namespace gpu { -// Custom printer to print an array of DimVar. -void PrintDimVars(mlir::AsmPrinter& p, mlir::ArrayRef dim_vars); - -// Custom parser to parse an array of DimVar. -mlir::FailureOr> ParseDimVars( - mlir::AsmParser& parser); - -// Custom printer to print an array of RangeVar. -void PrintRangeVars(mlir::AsmPrinter& p, mlir::ArrayRef range_vars); - -// Custom parser to parse an array of RangeVar. -mlir::FailureOr> ParseRangeVars( - mlir::AsmParser& parser); - -// Custom printer to print constraints. -void PrintConstraints( - mlir::AsmPrinter& p, - mlir::ArrayRef<::std::pair<::mlir::AffineExpr, Interval>> range_vars); - -// Custom parser to parse constraints. -mlir::FailureOr>> -ParseConstraints(mlir::AsmParser& parser); +// Custom parser to parse IndexingMapAttr. +mlir::FailureOr ParseIndexingMapAttr(mlir::AsmParser& parser); } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.td b/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.td index 51910d27c5a3cc..cf137686b23e4e 100644 --- a/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.td +++ b/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.td @@ -29,21 +29,15 @@ def XLAGPU_AffineMapParameter : def XLAGPU_DimVarsParameter : ArrayRefParameter<"::xla::gpu::DimVar", "DimVarArray"> { - let parser = "ParseDimVars($_parser)"; - let printer = "PrintDimVars($_printer, $_self)"; } def XLAGPU_RangeVarsParameter : ArrayRefParameter<"::xla::gpu::RangeVar", "RangeVarArray"> { - let parser = "ParseRangeVars($_parser)"; - let printer = "PrintRangeVars($_printer, $_self)"; } def XLAGPU_ConstraintsParameter : ArrayRefParameter<"::std::pair<::mlir::AffineExpr, ::xla::gpu::Interval>", "ContraintsArray"> { - let parser = "ParseConstraints($_parser)"; - let printer = "PrintConstraints($_printer, $_self)"; } def XLAGPU_IndexingMapAttr : XLAGPU_Attr<"IndexingMap"> { @@ -56,13 +50,7 @@ def XLAGPU_IndexingMapAttr : XLAGPU_Attr<"IndexingMap"> { XLAGPU_DimVarsParameter:$dim_vars, XLAGPU_RangeVarsParameter:$range_vars, XLAGPU_ConstraintsParameter:$constraints); - - let assemblyFormat = [{ - `<` `map` `=` $map `,` - `dim_vars` `=` $dim_vars`,` - `range_vars` `=` $range_vars `,` - `constraints` `=` $constraints `>` - }]; + let hasCustomAssemblyFormat = 1; } #endif // MLIR_HLO_DIALECT_MHLO_IR_HLO_ATTRS diff --git a/xla/service/gpu/fusions/mlir/tests/indexing_map_attr.mlir b/xla/service/gpu/fusions/mlir/tests/indexing_map_attr.mlir new file mode 100644 index 00000000000000..f6228b07aab50f --- /dev/null +++ b/xla/service/gpu/fusions/mlir/tests/indexing_map_attr.mlir @@ -0,0 +1,118 @@ +// RUN: mlir_fusions_opt %s -split-input-file | mlir_fusions_opt | FileCheck %s + +// CHECK: #xla_gpu.indexing_map<(d0, d1, d2)[s0] -> (d0) +// CHECK-NEXT: domain: +// CHECK-NEXT: d0 in [1, 2] +// CHECK-NEXT: d1 in [5, 8] +// CHECK-NEXT: d2 in [10, 12] +// CHECK-NEXT: s0 in [0, 32] +// CHECK-NEXT: d0 mod 2 in [0, 1] +// CHECK-NEXT: d0 + s0 in [1, 10] +// CHECK-NEXT: > +#map = #xla_gpu.indexing_map<(d0, d1, d2)[s0] -> (d0) + domain: + d0 in [1, 2] + d1 in [5, 8] + d2 in [10, 12] + s0 in [0, 32] + d0 mod 2 in [0, 1] + d0 + s0 in [1, 10] + > + +func.func private @indexing_map_attr(tensor<32xf64, #map>) + +// ----- + +// CHECK: #xla_gpu.indexing_map<(d0, d1)[s0, s1, s2] -> (d0 + s0, d1 + s1, d1 + s2) +// CHECK-NEXT: domain: +// CHECK-NEXT: d0 in [1, 2] +// CHECK-NEXT: d1 in [5, 8] +// CHECK-NEXT: s0 in [0, 10] +// CHECK-NEXT: s1 in [0, 5] +// CHECK-NEXT: s2 in [0, 32] +// CHECK-NEXT: d0 mod 2 in [0, 1] +// CHECK-NEXT: d0 + s0 in [1, 10] +// CHECK-NEXT: d1 + s1 + s2 in [1, 32] +// CHECK-NEXT: > +#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1, s2] -> (d0 + s0, d1 + s1, d1 + s2) + domain: + d0 in [1, 2] + d1 in [5, 8] + s0 in [0, 10] + s1 in [0, 5] + s2 in [0, 32] + d0 mod 2 in [0, 1] + d0 + s0 in [1, 10] + d1 + s1 + s2 in [1, 32] + > +func.func private @more_range_vars(tensor<32xf64, #map>) + +// ----- + +// CHECK: #xla_gpu.indexing_map<(d0)[s0] -> (d0) +// CHECK-NEXT: domain: +// CHECK-NEXT: d0 in [0, 100] +// CHECK-NEXT: s0 in [-3, -1] +// CHECK-NEXT: > +#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0) + domain: + d0 in [0, 100] + s0 in [-3, -1] + > +func.func private @indexing_map_small(tensor<100xf64, #map>) + +// ----- + +// CHECK: #xla_gpu.indexing_map<(d0, d1, d2)[s0] -> (d0) +// CHECK-NEXT: domain: +// CHECK-NEXT: d0 in [1, 2] +// CHECK-NEXT: d1 in [5, 8] +// CHECK-NEXT: d2 in [10, 12] +// CHECK-NEXT: s0 in [0, 32] +// CHECK-NEXT: > +#map = #xla_gpu.indexing_map<(d0, d1, d2)[s0] -> (d0) + domain: + d0 in [1, 2] + d1 in [5, 8] + d2 in [10, 12] + s0 in [0, 32] + > +func.func private @no_constraints(tensor<32xf64, #map>) + +// ----- + +// CHECK: #xla_gpu.indexing_map<()[s0] -> (s0) +// CHECK-NEXT: domain: +// CHECK-NEXT: s0 in [3, 5] +// CHECK-NEXT: s0 mod 2 in [0, 1] +// CHECK-NEXT: > +#map = #xla_gpu.indexing_map<()[s0] -> (s0) + domain: + s0 in [3, 5] + s0 mod 2 in [0, 1] + > +func.func private @no_dimensions(tensor<100xf64, #map>) + +// ----- + +// CHECK: #xla_gpu.indexing_map<(d0) -> (d0) +// CHECK-NEXT: domain: +// CHECK-NEXT: d0 in [3, 5] +// CHECK-NEXT: d0 mod 2 in [0, 1] +// CHECK-NEXT: > +#map = #xla_gpu.indexing_map<(d0) -> (d0) + domain: + d0 in [3, 5] + d0 mod 2 in [0, 1] + > +func.func private @no_symbols(tensor<100xf64, #map>) + +// ----- + +// CHECK: #xla_gpu.indexing_map<() -> () +// CHECK-NEXT: domain: +// CHECK-NEXT: > +#map = #xla_gpu.indexing_map<() -> () + domain: + > +func.func private @empty(tensor<100xf64, #map>) \ No newline at end of file diff --git a/xla/service/gpu/model/BUILD b/xla/service/gpu/model/BUILD index 57b619801fbf5c..ca44e81d66fee6 100644 --- a/xla/service/gpu/model/BUILD +++ b/xla/service/gpu/model/BUILD @@ -482,6 +482,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", diff --git a/xla/service/gpu/model/indexing_map.h b/xla/service/gpu/model/indexing_map.h index e38e68c1e76179..478e0ecd371bc5 100644 --- a/xla/service/gpu/model/indexing_map.h +++ b/xla/service/gpu/model/indexing_map.h @@ -26,6 +26,7 @@ limitations under the License. #include #include +#include "absl/strings/str_format.h" #include "absl/types/span.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/Hashing.h" @@ -144,6 +145,11 @@ struct Interval { }; std::ostream& operator<<(std::ostream& out, const Interval& range); +inline llvm::raw_ostream& operator<<(llvm::raw_ostream& os, + const Interval& interval) { + os << absl::StrFormat("[%d, %d]", interval.lower, interval.upper); + return os; +} template H AbslHashValue(H h, const Interval& range) { From 88c80003fa99874a0e8b08dc087e306e58ca1e96 Mon Sep 17 00:00:00 2001 From: Alexander Lyashuk Date: Thu, 25 Jul 2024 02:38:29 -0700 Subject: [PATCH 138/376] [XLA:GPU] Don't upcast supported fp8 dot operands when is are inside Triton fusion. Keep normalizing fp8 outside of Triton, but in the Triton fused computations, certain operand type combinations are fine. PiperOrigin-RevId: 655882666 --- xla/service/gpu/BUILD | 5 +- xla/service/gpu/fusions/triton/BUILD | 2 + .../gpu/fusions/triton/triton_support.cc | 12 ++ .../gpu/fusions/triton/triton_support.h | 6 + xla/service/gpu/gpu_float_support.cc | 17 +++ xla/service/gpu/gpu_float_support_test.cc | 114 +++++++++++++++++- 6 files changed, 149 insertions(+), 7 deletions(-) diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 6d1109cd14abd4..28ca8469086ce3 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -2795,8 +2795,8 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:float_support", + "//xla/service/gpu/fusions/triton:triton_support", "//xla/stream_executor:device_description", - "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log:check", ], ) @@ -4903,7 +4903,9 @@ xla_cc_test( name = "gpu_float_support_test", srcs = ["gpu_float_support_test.cc"], deps = [ + ":backend_configs_cc", ":gpu_float_support", + ":ir_emission_utils", "//xla:shape_util", "//xla:test_helpers", "//xla:xla_data_proto_cc", @@ -4912,6 +4914,7 @@ xla_cc_test( "//xla/service:hlo_verifier", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest_main", ], diff --git a/xla/service/gpu/fusions/triton/BUILD b/xla/service/gpu/fusions/triton/BUILD index ef9d9cfabff86f..e18844d5c3b074 100644 --- a/xla/service/gpu/fusions/triton/BUILD +++ b/xla/service/gpu/fusions/triton/BUILD @@ -399,6 +399,8 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:instruction_fusion", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:variant_visitor", "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", diff --git a/xla/service/gpu/fusions/triton/triton_support.cc b/xla/service/gpu/fusions/triton/triton_support.cc index 8f5246d6cd80ac..44c9d51c5921d0 100644 --- a/xla/service/gpu/fusions/triton/triton_support.cc +++ b/xla/service/gpu/fusions/triton/triton_support.cc @@ -32,6 +32,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/layout.h" #include "xla/primitive_util.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/variant_visitor.h" #include "xla/stream_executor/device_description.h" #include "xla/xla_data.pb.h" @@ -646,5 +648,15 @@ CodegenDecision IsTritonSupportedInstruction( return decision; } +bool IsTritonFusedComputation(const HloComputation& computation) { + HloFusionInstruction* fusion = + static_cast(computation.FusionInstruction()); + return fusion != nullptr && + fusion->fusion_kind() == HloInstruction::FusionKind::kCustom && + fusion->backend_config() + ->fusion_backend_config() + .kind() == kTritonGemmFusionKind; +} + } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/fusions/triton/triton_support.h b/xla/service/gpu/fusions/triton/triton_support.h index 25024c2c297f5d..abd2a4087216a7 100644 --- a/xla/service/gpu/fusions/triton/triton_support.h +++ b/xla/service/gpu/fusions/triton/triton_support.h @@ -122,6 +122,12 @@ absl::Status EnsureTritonSupportsComputeCapability( CodegenDecision IsTritonSupportedInstruction( const HloInstruction& instr, const se::GpuComputeCapability& gpu_version); +// Returns `true` if the parameter computation is a Triton fused computation, +// i.e. the calling fusion instruction has `FusionKind::kCustom` and +// `backend_config()` with `kind` set to +// `kTritonGemmFusionKind`. +bool IsTritonFusedComputation(const HloComputation& computation); + } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/gpu_float_support.cc b/xla/service/gpu/gpu_float_support.cc index 3bae5e6b8e7e72..1403ad021a217d 100644 --- a/xla/service/gpu/gpu_float_support.cc +++ b/xla/service/gpu/gpu_float_support.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/gpu/gpu_float_support.h" +#include #include #include "absl/log/check.h" @@ -22,6 +23,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/float_support.h" +#include "xla/service/gpu/fusions/triton/triton_support.h" #include "xla/stream_executor/device_description.h" #include "xla/xla_data.pb.h" @@ -55,6 +57,21 @@ bool GpuFloatSupport::IsSupported(const HloInstruction& hlo) const { case HloOpcode::kReduceScatter: // Handled by Triton GEMM. case HloOpcode::kDot: + using TypeAndCC = std::pair< + PrimitiveType, + stream_executor::CudaComputeCapability::CudaComputeCapabilities>; + for (auto [type, cc] : + {TypeAndCC(F8E4M3FN, se::CudaComputeCapability::AMPERE), + TypeAndCC(F8E5M2, se::CudaComputeCapability::HOPPER)}) { + if (LowPrecisionType() == type) { + auto* cuda_compute_capability = + std::get_if(&compute_capability_); + // Do not normalize supported types inside Triton fused computations. + return cuda_compute_capability && + cuda_compute_capability->IsAtLeast(cc) && + IsTritonFusedComputation(*hlo.parent()); + } + } return LowPrecisionType() == BF16; // Data movement only ops. case HloOpcode::kAllGather: diff --git a/xla/service/gpu/gpu_float_support_test.cc b/xla/service/gpu/gpu_float_support_test.cc index 3f2d34782e9377..1d2f6c167bb090 100644 --- a/xla/service/gpu/gpu_float_support_test.cc +++ b/xla/service/gpu/gpu_float_support_test.cc @@ -15,13 +15,19 @@ limitations under the License. #include "xla/service/gpu/gpu_float_support.h" +#include +#include + #include +#include "absl/log/check.h" #include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/float_normalization.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/hlo_verifier.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -54,11 +60,9 @@ class FloatSupportTest : public HloTestBase { return result.value(); } - void TestDotConversion(PrimitiveType lhs_type, PrimitiveType rhs_type, - PrimitiveType result_type, se::GpuComputeCapability cc, - bool should_convert_lhs, bool should_convert_rhs, - PrimitiveType low_precision_type, - PrimitiveType high_precision_type = F16) { + std::unique_ptr CreateComputation(PrimitiveType lhs_type, + PrimitiveType rhs_type, + PrimitiveType result_type) { auto builder = HloComputation::Builder(TestName()); Shape lhs_shape = ShapeUtil::MakeShape(lhs_type, {3, 3}); Shape rhs_shape = ShapeUtil::MakeShape(rhs_type, {3, 3}); @@ -76,8 +80,17 @@ class FloatSupportTest : public HloTestBase { builder.AddInstruction(HloInstruction::CreateDot( result_shape, a, b, dot_dnums, precision_config)); + return builder.Build(); + } + + void TestDotConversion(PrimitiveType lhs_type, PrimitiveType rhs_type, + PrimitiveType result_type, se::GpuComputeCapability cc, + bool should_convert_lhs, bool should_convert_rhs, + PrimitiveType low_precision_type, + PrimitiveType high_precision_type = F16) { auto module = CreateNewVerifiedModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module->AddEntryComputation( + CreateComputation(lhs_type, rhs_type, result_type)); EXPECT_EQ( Normalize(module.get(), cc, low_precision_type, high_precision_type), @@ -91,6 +104,49 @@ class FloatSupportTest : public HloTestBase { HloOpcode::kConvert, should_convert_rhs); } + + void TestTritonFusedDot(PrimitiveType lhs_type, PrimitiveType rhs_type, + PrimitiveType result_type, + se::GpuComputeCapability cc, bool should_convert_lhs, + bool should_convert_rhs, + PrimitiveType low_precision_type, + PrimitiveType high_precision_type = F16) { + auto module = CreateNewVerifiedModule(); + + auto computation = module->AddComputationAndUnifyNamesAndIds( + CreateComputation(lhs_type, rhs_type, result_type), /*is_entry=*/false); + + Shape lhs_shape = ShapeUtil::MakeShape(lhs_type, {3, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(rhs_type, {3, 3}); + Shape result_shape = ShapeUtil::MakeShape(result_type, {3, 3}); + + auto builder = HloComputation::Builder("main"); + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, lhs_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, rhs_shape, "b")); + HloInstruction* fusion = + builder.AddInstruction(HloInstruction::CreateFusion( + result_shape, HloInstruction::FusionKind::kCustom, {a, b}, + computation)); + GpuBackendConfig config; + config.mutable_fusion_backend_config()->set_kind( + std::string(kTritonGemmFusionKind)); + CHECK_OK(fusion->set_backend_config(config)); + + module->AddEntryComputation(builder.Build()); + + EXPECT_EQ( + Normalize(module.get(), cc, low_precision_type, high_precision_type), + should_convert_lhs || should_convert_rhs); + EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kDot); + EXPECT_EQ(computation->root_instruction()->operand(0)->opcode() == + HloOpcode::kConvert, + should_convert_lhs); + EXPECT_EQ(computation->root_instruction()->operand(1)->opcode() == + HloOpcode::kConvert, + should_convert_rhs); + } }; TEST_F(FloatSupportTest, ShouldAlwaysConvertFp8Dot) { @@ -139,6 +195,52 @@ TEST_F(FloatSupportTest, ShouldAlwaysConvertFp8Dot) { /*should_convert_rhs=*/false, F8E5M2); } +TEST_F(FloatSupportTest, ShouldConverTritonUnsupportedFp8Dot) { + TestTritonFusedDot(F8E4M3FN, F8E4M3FN, F16, + se::CudaComputeCapability::Hopper(), + /*should_convert_lhs=*/true, + /*should_convert_rhs=*/true, F8E4M3FN); + + TestTritonFusedDot(F8E4M3FN, F8E4M3FN, F32, + se::CudaComputeCapability::Hopper(), + /*should_convert_lhs=*/false, + /*should_convert_rhs=*/false, F8E4M3FN); + + TestTritonFusedDot(F8E4M3FN, F8E4M3FN, F16, + se::CudaComputeCapability::Ampere(), + /*should_convert_lhs=*/true, + /*should_convert_rhs=*/true, F8E4M3FN); + + TestTritonFusedDot(F8E4M3FN, F8E4M3FN, F32, + se::CudaComputeCapability::Hopper(), + /*should_convert_lhs=*/false, + /*should_convert_rhs=*/false, F8E4M3FN); + + TestTritonFusedDot(F8E5M2, F8E5M2, F16, se::CudaComputeCapability::Ampere(), + /*should_convert_lhs=*/true, + /*should_convert_rhs=*/true, F8E5M2); + + TestTritonFusedDot(F8E5M2, F8E5M2, F32, se::CudaComputeCapability::Ampere(), + /*should_convert_lhs=*/true, + /*should_convert_rhs=*/true, F8E5M2); + + TestTritonFusedDot(F8E5M2, F8E4M3FN, F16, se::CudaComputeCapability::Hopper(), + /*should_convert_lhs=*/true, + /*should_convert_rhs=*/false, F8E5M2); + + TestTritonFusedDot(F8E5M2, F8E4M3FN, F32, se::CudaComputeCapability::Hopper(), + /*should_convert_lhs=*/false, + /*should_convert_rhs=*/false, F8E5M2); + + TestTritonFusedDot(F8E5M2, F16, F16, se::CudaComputeCapability::Hopper(), + /*should_convert_lhs=*/true, + /*should_convert_rhs=*/false, F8E5M2); + + TestTritonFusedDot(F8E5M2, F16, F32, se::CudaComputeCapability::Hopper(), + /*should_convert_lhs=*/true, + /*should_convert_rhs=*/false, F8E5M2); +} + TEST_F(FloatSupportTest, ShouldKeepBf16OnAmpere) { TestDotConversion(BF16, BF16, F32, se::CudaComputeCapability::Ampere(), /*should_convert_lhs=*/false, From 7e97f67436c891141291c8f0577d1556d7e6e32e Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Thu, 25 Jul 2024 02:47:56 -0700 Subject: [PATCH 139/376] [XLA:GPU][MLIR-based emitters] Vectorize concats in concatenate emitter. PiperOrigin-RevId: 655884629 --- xla/service/gpu/fusions/BUILD | 1 + xla/service/gpu/fusions/concatenate_mlir.cc | 33 ++++++++++++++++--- xla/service/gpu/fusions/concatenate_mlir.h | 6 ++-- .../gpu/fusions/concatenate_mlir_test.cc | 28 ++++++++++++++++ 4 files changed, 61 insertions(+), 7 deletions(-) diff --git a/xla/service/gpu/fusions/BUILD b/xla/service/gpu/fusions/BUILD index d72b156c033754..f9edabbad5df6f 100644 --- a/xla/service/gpu/fusions/BUILD +++ b/xla/service/gpu/fusions/BUILD @@ -958,6 +958,7 @@ cc_library( hdrs = ["concatenate_mlir.h"], deps = [ ":concatenate", + ":loop", "//xla/hlo/ir:hlo", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:launch_dimensions", diff --git a/xla/service/gpu/fusions/concatenate_mlir.cc b/xla/service/gpu/fusions/concatenate_mlir.cc index 23dc9aa6a066b7..f2cecc5d6cac80 100644 --- a/xla/service/gpu/fusions/concatenate_mlir.cc +++ b/xla/service/gpu/fusions/concatenate_mlir.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include @@ -35,6 +36,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/gpu/fusions/concatenate.h" +#include "xla/service/gpu/fusions/loop.h" #include "xla/service/gpu/fusions/mlir/computation_partitioner.h" #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" #include "xla/service/gpu/hlo_fusion_analysis.h" @@ -44,14 +46,36 @@ limitations under the License. namespace xla { namespace gpu { +namespace { using llvm::SmallVector; using mlir::Value; using mlir::ValueRange; +// Computes the unroll factor that divides concat dimension of all operands. +int ComputeUnrollFactor(const HloFusionAnalysis& analysis, + int unroll_factor_for_the_largest_shape) { + auto& concat = analysis.fusion_hero(0).instruction(); + int unroll_factor = unroll_factor_for_the_largest_shape; + int64_t dim = concat.concatenate_dimension(); + for (const HloInstruction* operand : concat.operands()) { + if (unroll_factor == 1) return 1; + unroll_factor = std::gcd(unroll_factor, operand->shape().dimensions(dim)); + } + return unroll_factor; +} + +} // namespace + +MlirConcatenateFusion::MlirConcatenateFusion(const HloFusionAnalysis& analysis) + : analysis_(analysis), + largest_shape_(GetLargestConcatOperandShape(analysis_)), + config_(ComputeLoopFusionConfig(analysis_, largest_shape_)), + unroll_factor_(ComputeUnrollFactor(analysis_, config_.unroll_factor)) {} + LaunchDimensions MlirConcatenateFusion::launch_dimensions() const { - return CalculateLaunchDimensions(GetLargestConcatOperandShape(analysis_), - analysis_.device_info()); + return CalculateLaunchDimensions(largest_shape_, analysis_.device_info(), + config_); } std::optional @@ -65,9 +89,8 @@ MlirConcatenateFusion::ComputeThreadIdToInputIndexing( int64_t root_index, int64_t hero_operand_index, mlir::MLIRContext* ctx) const { // TODO(b/331356433): Add constraints depending on the `hero_operand_index`. - return GetDefaultThreadIdIndexingMap(launch_dimensions(), /*unroll_factor=*/1, - GetLargestConcatOperandShape(analysis_), - ctx); + return GetDefaultThreadIdIndexingMap(launch_dimensions(), unroll_factor_, + largest_shape_, ctx); } std::vector diff --git a/xla/service/gpu/fusions/concatenate_mlir.h b/xla/service/gpu/fusions/concatenate_mlir.h index ca8c1ec0dd1809..b98db45690389c 100644 --- a/xla/service/gpu/fusions/concatenate_mlir.h +++ b/xla/service/gpu/fusions/concatenate_mlir.h @@ -37,8 +37,7 @@ namespace gpu { class MlirConcatenateFusion : public MlirFusionEmitterBase { public: - explicit MlirConcatenateFusion(const HloFusionAnalysis& analysis) - : analysis_(analysis) {} + explicit MlirConcatenateFusion(const HloFusionAnalysis& analysis); LaunchDimensions launch_dimensions() const override; @@ -62,6 +61,9 @@ class MlirConcatenateFusion : public MlirFusionEmitterBase { private: const HloFusionAnalysis& analysis_; + Shape largest_shape_; + LaunchDimensionsConfig config_; + int unroll_factor_; }; } // namespace gpu diff --git a/xla/service/gpu/fusions/concatenate_mlir_test.cc b/xla/service/gpu/fusions/concatenate_mlir_test.cc index d969fe19cbda8f..d6ca14610c51c3 100644 --- a/xla/service/gpu/fusions/concatenate_mlir_test.cc +++ b/xla/service/gpu/fusions/concatenate_mlir_test.cc @@ -238,6 +238,34 @@ TEST_F(MlirConcatenateFusionTest, EpilogueBitcast) { EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); } +TEST_F(MlirConcatenateFusionTest, Vectorization) { + auto kHloString = R"( + HloModule module + + fused_computation { + param0 = f32[640002] parameter(0) + param1 = f32[640000] parameter(1) + ROOT concat = f32[1280002] concatenate(param0, param1), dimensions={0} + } + ENTRY main { + param0 = f32[640002] parameter(0) + param1 = f32[640000] parameter(1) + ROOT fusion = f32[1280002] fusion(param0, param1), calls=fused_computation, kind=kLoop + } + )"; + TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( + // CHECK-DAG: affine_map<(d0, d1) -> (d1 * 128 + d0)> + // CHECK-DAG: affine_map<(d0, d1)[s0] -> (d0 * 2 + d1 * 256 + s0)> + // CHECK-DAG: affine_map<(d0, d1)[s0] -> (d0 * 2 + d1 * 256 + s0 + 640002)> + + // CHECK-LABEL: fused_computation + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index + // CHECK-COUNT-2: scf.for %{{.*}} = %[[C0]] to %[[C2]] + )")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + } // namespace } // namespace gpu } // namespace xla From 7027ff2899c587578e1ecc23546cc7b879c4cd8e Mon Sep 17 00:00:00 2001 From: Dirk Hornung Date: Thu, 25 Jul 2024 03:08:53 -0700 Subject: [PATCH 140/376] Return vector of cutlass kernels instead of single kernel. This is in preparation to support selecting the best kernel from multiple kernels via autotuning. PiperOrigin-RevId: 655889266 --- xla/service/gpu/kernels/BUILD | 11 +++++++ .../gpu/kernels/cutlass_gemm_custom_kernel.cc | 31 ++++++++++--------- .../gpu/kernels/cutlass_gemm_custom_kernel.h | 5 +-- .../cutlass_gemm_custom_kernel_benchmarks.cc | 17 +++++----- .../cutlass_gemm_custom_kernel_stub.cc | 3 +- .../cutlass_gemm_custom_kernel_test.cc | 18 ++++++----- .../gpu/kernels/cutlass_gemm_fusion.cc | 15 +++------ .../gpu/kernels/cutlass_gemm_fusion_test.cc | 5 +-- 8 files changed, 61 insertions(+), 44 deletions(-) diff --git a/xla/service/gpu/kernels/BUILD b/xla/service/gpu/kernels/BUILD index 9d04094c5fd7cc..6e0a0d44d1523b 100644 --- a/xla/service/gpu/kernels/BUILD +++ b/xla/service/gpu/kernels/BUILD @@ -412,6 +412,17 @@ cuda_library( ]), ) +cuda_library( + name = "cutlass_gemm_kernel_bf16xbf16_to_f32", + srcs = if_cuda_is_configured(["cutlass_gemm_kernel_bf16xbf16_to_f32.cu.cc"]), + copts = ["-Wno-unknown-attributes"], + deps = if_cuda_is_configured([ + ":cutlass_gemm_adaptor", + "@local_config_cuda//cuda:cuda_headers", + "@cutlass_archive//:cutlass", + ]), +) + #===--------------------------------------------------------------------------------------------===# # CUTLASS Gemm kernel libraries #===--------------------------------------------------------------------------------------------===# diff --git a/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cc b/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cc index b97b8939b6f740..ae39cfbe293d1d 100644 --- a/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cc +++ b/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include #include +#include #include "absl/log/log.h" #include "absl/status/status.h" @@ -171,12 +172,12 @@ KernelArgsPacking ArgsPacking(int32_t m, int32_t n, int32_t k, //===----------------------------------------------------------------------===// template -static absl::StatusOr Load(std::string name, int32_t m, int32_t n, - int32_t k, const ArgsIndices& indices, - const DynamicSliceIndices& slices, - const se::DeviceDescription& device, - Adaptor adaptor = {}, - DeviceKernel kernel = {}) { +static CustomKernel Load(std::string name, int32_t m, int32_t n, int32_t k, + const ArgsIndices& indices, + const DynamicSliceIndices& slices, + const se::DeviceDescription& device, + Adaptor adaptor = {}, + DeviceKernel kernel = {}) { // Get the dispatch grid size and shared memory requirements. auto cluster_dim = As(adaptor.ClusterDim()); auto block_dim = As(adaptor.BlockDim(m, n, k)); @@ -198,7 +199,7 @@ static absl::StatusOr Load(std::string name, int32_t m, int32_t n, } } -absl::StatusOr GetCutlassGemmKernel( +absl::StatusOr> GetCutlassGemmKernels( std::string name, PrimitiveType dtype, int32_t m, int32_t n, int32_t k, const ArgsIndices& indices, const DynamicSliceIndices& slices, const se::DeviceDescription& device) { @@ -207,21 +208,21 @@ absl::StatusOr GetCutlassGemmKernel( switch (dtype) { case PrimitiveType::F32: - return Load>(std::move(name), m, n, k, indices, - slices, device); + return {{Load>(std::move(name), m, n, k, indices, + slices, device)}}; case PrimitiveType::BF16: #if CUDA_VERSION >= 12000 if (cuda_cc.IsAtLeastHopper()) { - return Load>(std::move(name), m, n, k, indices, - slices, device); + return {{Load>(std::move(name), m, n, k, indices, + slices, device)}}; } #endif if (cuda_cc.IsAtLeastAmpere()) { - return Load>(std::move(name), m, n, k, indices, - slices, device); + return {{Load>(std::move(name), m, n, k, indices, + slices, device)}}; } - return Load>(std::move(name), m, n, k, indices, - slices, device); + return {{Load>(std::move(name), m, n, k, indices, + slices, device)}}; default: return absl::InvalidArgumentError("Unsupported CUTLASS gemm data type"); diff --git a/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h b/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h index e4323730927524..37531ef0038f31 100644 --- a/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h +++ b/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include "absl/status/statusor.h" #include "xla/service/gpu/kernels/custom_kernel.h" @@ -27,8 +28,8 @@ limitations under the License. namespace xla::gpu::kernel::gemm_universal { -// Returns a pre-compiled custom kernel for a given data type and problem size. -absl::StatusOr GetCutlassGemmKernel( +// Returns pre-compiled custom kernels for a given data type and problem size. +absl::StatusOr> GetCutlassGemmKernels( std::string name, PrimitiveType dtype, int32_t m, int32_t n, int32_t k, const ArgsIndices& indices, const DynamicSliceIndices& slices, const se::DeviceDescription& device); diff --git a/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc b/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc index 4175412335ee84..8d44bb024294e3 100644 --- a/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc +++ b/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc @@ -19,6 +19,7 @@ limitations under the License. #include "xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h" #include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_factory.h" #include "xla/stream_executor/platform.h" @@ -52,13 +53,15 @@ static void BM_RowMajorGemm(benchmark::State& state) { int32_t n = 16384; int32_t k = 4096; - auto custom_kernel = - GetCutlassGemmKernel("cutlass_gemm", PrimitiveType::BF16, m, n, k, - /*indices=*/{0, 1, 2}, /*slices=*/{}, device); + TF_ASSERT_OK_AND_ASSIGN( + auto custom_kernels, + GetCutlassGemmKernels("cutlass_gemm", PrimitiveType::BF16, m, n, k, + /*indices=*/{0, 1, 2}, /*slices=*/{}, device)); + const auto& custom_kernel = custom_kernels[0]; TF_ASSERT_OK_AND_ASSIGN( auto gemm, - se::KernelFactory::Create(executor, custom_kernel->kernel_spec())); + se::KernelFactory::Create(executor, custom_kernel.kernel_spec())); // Prepare arguments: a=1.1, b=1.2, c=0.0 se::DeviceMemory a = executor->AllocateArray(m * k, 0); @@ -71,11 +74,11 @@ static void BM_RowMajorGemm(benchmark::State& state) { se::KernelArgsDeviceMemoryArray args( std::vector({a, b, c}), - custom_kernel->shared_memory_bytes()); + custom_kernel.shared_memory_bytes()); for (auto s : state) { - TF_CHECK_OK(stream->Launch(custom_kernel->thread_dims(), - custom_kernel->block_dims(), *gemm, args)); + TF_CHECK_OK(stream->Launch(custom_kernel.thread_dims(), + custom_kernel.block_dims(), *gemm, args)); TF_CHECK_OK(stream->BlockHostUntilDone()); } } diff --git a/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_stub.cc b/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_stub.cc index 88e11184beb1fc..8e231ee3b8e6e9 100644 --- a/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_stub.cc +++ b/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_stub.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include "absl/status/statusor.h" #include "xla/service/gpu/kernels/custom_kernel.h" @@ -24,7 +25,7 @@ limitations under the License. namespace xla::gpu::kernel::gemm_universal { -absl::StatusOr GetCutlassGemmKernel( +absl::StatusOr> GetCutlassGemmKernels( std::string name, PrimitiveType dtype, int32_t m, int32_t n, int32_t k, const ArgsIndices& indices, const DynamicSliceIndices& slices, const se::DeviceDescription& device) { diff --git a/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc b/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc index 9d94ed77dfb275..458f31ae88a836 100644 --- a/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc +++ b/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_factory.h" #include "xla/stream_executor/platform.h" @@ -42,13 +43,16 @@ TEST(CutlassGemmKernelTest, SimpleGemm) { auto stream = executor->CreateStream().value(); // Load [4, 4] x [4, 4] gemm kernel written in CUDA C++ with CUTLASS. - auto custom_kernel = GetCutlassGemmKernel( - "cutlass_gemm", PrimitiveType::F32, 4, 4, 4, - /*indices=*/{0, 1, 2}, /*slices=*/{}, executor->GetDeviceDescription()); + TF_ASSERT_OK_AND_ASSIGN( + auto custom_kernels, + GetCutlassGemmKernels("cutlass_gemm", PrimitiveType::F32, 4, 4, 4, + /*indices=*/{0, 1, 2}, /*slices=*/{}, + executor->GetDeviceDescription())); + auto custom_kernel = custom_kernels[0]; TF_ASSERT_OK_AND_ASSIGN( auto gemm, - se::KernelFactory::Create(executor, custom_kernel->kernel_spec())); + se::KernelFactory::Create(executor, custom_kernel.kernel_spec())); int64_t length = 4 * 4; int64_t byte_length = sizeof(float) * length; @@ -69,9 +73,9 @@ TEST(CutlassGemmKernelTest, SimpleGemm) { // Launch gemm kernel with device memory arguments. se::KernelArgsDeviceMemoryArray arr( std::vector({a, b, c}), - custom_kernel->shared_memory_bytes()); - TF_ASSERT_OK(stream->Launch(custom_kernel->thread_dims(), - custom_kernel->block_dims(), *gemm, arr)); + custom_kernel.shared_memory_bytes()); + TF_ASSERT_OK(stream->Launch(custom_kernel.thread_dims(), + custom_kernel.block_dims(), *gemm, arr)); // Copy `c` data back to host. std::vector dst(length, -1.0f); diff --git a/xla/service/gpu/kernels/cutlass_gemm_fusion.cc b/xla/service/gpu/kernels/cutlass_gemm_fusion.cc index e57cb2425fb4a0..a392801e25c578 100644 --- a/xla/service/gpu/kernels/cutlass_gemm_fusion.cc +++ b/xla/service/gpu/kernels/cutlass_gemm_fusion.cc @@ -300,11 +300,8 @@ class CutlassGemmFusion : public CustomKernelFusion { size_t k = lhs_shape.dimensions(1); size_t n = rhs_shape.dimensions(1); - TF_ASSIGN_OR_RETURN( - auto kernel, - kernel::gemm_universal::GetCutlassGemmKernel( - "cutlass_gemm", dtype, m, n, k, indices, /*slices=*/{}, device)); - return std::vector{std::move(kernel)}; + return kernel::gemm_universal::GetCutlassGemmKernels( + "cutlass_gemm", dtype, m, n, k, indices, /*slices=*/{}, device); } }; @@ -380,11 +377,9 @@ class CutlassGemmWithDynamicUpdateSliceFusion : public CustomKernelFusion { size_t k = lhs_shape.dimensions(1); size_t n = rhs_shape.dimensions(1); - TF_ASSIGN_OR_RETURN( - auto kernel, kernel::gemm_universal::GetCutlassGemmKernel( - "cutlass_gemm_with_dynamic_update_slice", dtype, m, n, - k, args_indices, slices, device)); - return std::vector{std::move(kernel)}; + return kernel::gemm_universal::GetCutlassGemmKernels( + "cutlass_gemm_with_dynamic_update_slice", dtype, m, n, k, args_indices, + slices, device); } }; diff --git a/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc b/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc index 59a6363dd44d81..1fec453765c4d8 100644 --- a/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc +++ b/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc @@ -43,11 +43,12 @@ class CutlassFusionTest : public HloTestBase { } int CutlassGemmKernelSharedMemorySize(PrimitiveType dtype, int m, int n, int k) { - return kernel::gemm_universal::GetCutlassGemmKernel( + return kernel::gemm_universal::GetCutlassGemmKernels( "cutlass_gemm", dtype, m, n, k, /*indices=*/{0, 1, 2}, /*slices=*/{}, backend().default_stream_executor()->GetDeviceDescription()) - ->shared_memory_bytes(); + ->at(0) + .shared_memory_bytes(); }; }; From 0663be1ea4a99294b25b777c1db7d745fbf0fde6 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Thu, 25 Jul 2024 04:11:23 -0700 Subject: [PATCH 141/376] [XLA:GPU] Print Interval as a closed interval. Replaced by %s/in \[\(\d\+\), \(\d\+\))/\='in ['.(submatch(1)).', '.(submatch(2)-1).']'/g PiperOrigin-RevId: 655903009 --- docs/indexing.md | 286 ++--- .../gpu/fusions/concatenate_mlir_test.cc | 18 +- xla/service/gpu/fusions/concatenate_test.cc | 18 +- ...in_place_dynamic_update_slice_mlir_test.cc | 24 +- .../in_place_dynamic_update_slice_test.cc | 16 +- .../gpu/fusions/input_slices_mlir_test.cc | 36 +- xla/service/gpu/fusions/input_slices_test.cc | 18 +- xla/service/gpu/fusions/loop_mlir_test.cc | 86 +- xla/service/gpu/fusions/loop_test.cc | 86 +- .../mlir/elemental_hlo_to_mlir_test.cc | 70 +- .../gpu/fusions/mlir/ir/xla_gpu_ops.cc | 8 +- .../gpu/fusions/mlir/tests/canonicalize.mlir | 44 +- .../gpu/fusions/mlir/tests/invalid.mlir | 2 +- .../gpu/fusions/mlir/tests/lower_tensors.mlir | 2 +- xla/service/gpu/fusions/mlir/tests/ops.mlir | 12 +- .../fusions/mlir/tests/optimize_loops.mlir | 16 +- .../fusions/mlir/tests/simplify_affine.mlir | 8 +- .../mlir/tests/vectorize_loads_stores.mlir | 18 +- .../gpu/fusions/reduction_mlir_test.cc | 42 +- xla/service/gpu/fusions/reduction_test.cc | 34 +- xla/service/gpu/fusions/scatter_mlir_test.cc | 38 +- xla/service/gpu/fusions/scatter_test.cc | 38 +- .../gpu/fusions/transpose_mlir_test.cc | 172 +-- xla/service/gpu/fusions/transpose_test.cc | 160 +-- xla/service/gpu/model/coalescing_analysis.cc | 8 +- .../gpu/model/indexing_analysis_test.cc | 1026 ++++++++--------- xla/service/gpu/model/indexing_map.cc | 3 +- xla/service/gpu/model/indexing_map_test.cc | 336 +++--- .../gpu/model/symbolic_tile_analysis_test.cc | 66 +- xla/service/gpu/model/symbolic_tile_test.cc | 54 +- 30 files changed, 1372 insertions(+), 1373 deletions(-) diff --git a/docs/indexing.md b/docs/indexing.md index 7d1722de431beb..c44f2845358ee6 100644 --- a/docs/indexing.md +++ b/docs/indexing.md @@ -15,7 +15,7 @@ bc0 = f32[10, 20, 30] broadcast(p0), dimensions={1} ``` the indexing map from the output to input is `(i, j, k) -> (j)` for `i in -[0, 10]`, `j in [0, 21)` and `k in [0, 31)`. +[0, 10]`, `j in [0, 20]` and `k in [0, 30]`. ## Motivation @@ -60,8 +60,8 @@ For example, if we have a reduction from `tensor<2x4x8x16xf32>` to `(d0, d1) -> (r0, d0, d1, r1)`, where `d_i` are the dimension variables that correspond to the indices of the output tensor. Range variables `r_j` encode multiple values, i.e. to compute a `(d0, d1)` element of the output, we need -`(r0, d0, d1, r1)` elements of the input, where `r0 in [0, 2)` and -`r1 in [0, 16)`. +`(r0, d0, d1, r1)` elements of the input, where `r0 in [0, 1]` and +`r1 in [0, 15]`. This mapping can be constructed from the attributes of HLO instructions or the mappings of unfused instructions can be composed to get indexing for a fusion. @@ -166,8 +166,8 @@ The output to input maps: ``` (d0, d1) -> (d0, d1) domain: -d0 in [0, 10) -d1 in [0, 20) +d0 in [0, 9] +d1 in [0, 19] ``` The input to output maps @@ -177,8 +177,8 @@ The input to output maps ``` (d0, d1) -> (d0, d1) domain: -d0 in [0, 10) -d1 in [0, 20) +d0 in [0, 9] +d1 in [0, 19] ``` ### [Broadcast](https://openxla.org/xla/operation_semantics#broadcastindim) @@ -196,9 +196,9 @@ The output to input map: ``` (d0, d1, d2) -> (d1) domain: -d0 in [0, 10) -d1 in [0, 20) -d2 in [0, 30) +d0 in [0, 9] +d1 in [0, 19] +d2 in [0, 29] ``` The input to output map @@ -206,9 +206,9 @@ The input to output map ``` (d0)[s0, s1] -> (s0, d0, s1) domain: -d0 in [0, 20) -s0 in [0, 10) -s1 in [0, 30) +d0 in [0, 19] +s0 in [0, 9] +s1 in [0, 29] ``` Note that now we have **s** on the right side for the input-to-output @@ -235,16 +235,16 @@ The output to input map for `src`: ``` (d0, d1, d2)[s0, s1, s2] -> (d0 + s0, d1 + s1, d2 + s2) domain: -d0 in [0, 1) -d1 in [0, 2) -d2 in [0, 32) -s0 in [0, 2) +d0 in [0, 0] +d1 in [0, 1] +d2 in [0, 31] +s0 in [0, 1] hlo: of1 = s32[] parameter(1) (d0, d1, d2) -> () -s1 in [0, 1) +s1 in [0, 0] hlo: of2 = s32[] parameter(2) (d0, d1, d2) -> () -s2 in [0, 227) +s2 in [0, 226] hlo: of3 = s32[] parameter(3) (d0, d1, d2) -> () ``` @@ -260,9 +260,9 @@ The output to input map for `of1`, `of2` and `of3`: ``` (d0, d1, d2) -> () domain: -d0 in [0, 1) -d1 in [0, 2) -d2 in [0, 32) +d0 in [0, 0] +d1 in [0, 1] +d2 in [0, 31] ``` ### [DynamicUpdateSlice](https://openxla.org/xla/operation_semantics#dynamicupdateslice) @@ -281,20 +281,20 @@ do not support inqequality constraints. ``` (d0, d1) -> (d0, d1) domain: -d0 in [0, 20) -d1 in [0, 30) +d0 in [0, 19] +d1 in [0, 29] ``` The output to input map for `upd`: ``` (d0, d1)[s0, s1] -> (d0 - s0, d1 - s1) domain: -d0 in [0, 20) -d1 in [0, 30) -s0 in [0, 16) +d0 in [0, 19] +d1 in [0, 29] +s0 in [0, 15] hlo: of1 = s32[] parameter(2) (d0, d1) -> () -s1 in [0, 21) +s1 in [0, 20] hlo: of2 = s32[] parameter(3) (d0, d1) -> () ``` @@ -311,8 +311,8 @@ The output to input map for `of1` and `of2`: ``` (d0, d1) -> () domain: -d0 in [0, 20) -d1 in [0, 30) +d0 in [0, 19] +d1 in [0, 29] ``` ### [Gather](https://openxla.org/xla/operation_semantics#gather) @@ -334,14 +334,14 @@ The output to input map for `operand`: (d0, d1, d2, d3)[s0, s1] -> (d1 + s0, d2 + s1, d3) domain: -d0 in [0, 1806) -d1 in [0, 7) -d2 in [0, 8) -d3 in [0, 4) -s0 in [0, 27) +d0 in [0, 1805] +d1 in [0, 6] +d2 in [0, 7] +d3 in [0, 3] +s0 in [0, 26] hlo: indices = s32[1806,2]{1,0} parameter(1) (d0, d1, d2, d3) -> (d0, 0) -s1 in [0, 69) +s1 in [0, 68] hlo: indices = s32[1806,2]{1,0} parameter(1) (d0, d1, d2, d3) -> (d0, 1) ``` @@ -356,11 +356,11 @@ The output to input map for `indices`: ``` (d0, d1, d2, d3)[s0] -> (d0, s0) domain: - d0 in [0, 1806) - d1 in [0, 7) - d2 in [0, 8) - d3 in [0, 4) - s0 in [0, 2) + d0 in [0, 1805] + d1 in [0, 6] + d2 in [0, 7] + d3 in [0, 3] + s0 in [0, 1] ``` The range variable `s0` shows that we need the entire row (d0, *) of the `indices` tensor to compute an element of the output. @@ -380,10 +380,10 @@ The output to input map: ``` (d0, d1, d2, d3) -> (d0, d3, d1, d2) domain: -d0 in [0, 3) -d1 in [0, 6) -d2 in [0, 128) -d3 in [0, 12288) +d0 in [0, 2] +d1 in [0, 5] +d2 in [0, 127] +d3 in [0, 12287] ``` The input to output map: @@ -391,10 +391,10 @@ The input to output map: ``` (d0, d1, d2, d3) -> (d0, d2, d3, d1) domain: -d0 in [0, 3) -d1 in [0, 12288) -d2 in [0, 6) -d3 in [0, 128) +d0 in [0, 2] +d1 in [0, 12287] +d2 in [0, 5] +d3 in [0, 127] ``` ### [Reverse](https://openxla.org/xla/operation_semantics#rev_reverse) @@ -412,10 +412,10 @@ The output to input map: ``` (d0, d1, d2, d3) -> (d0, -d1 + 16, -d2 + 8, d3) domain: -d0 in [0, 1) -d1 in [0, 17) -d2 in [0, 9) -d3 in [0, 9) +d0 in [0, 0] +d1 in [0, 16] +d2 in [0, 8] +d3 in [0, 8] ``` The input to output map: @@ -423,10 +423,10 @@ The input to output map: ``` (d0, d1, d2, d3) -> (d0, -d1 + 16, -d2 + 8, d3) domain: -d0 in [0, 1) -d1 in [0, 17) -d2 in [0, 9) -d3 in [0, 9) +d0 in [0, 0] +d1 in [0, 16] +d2 in [0, 8] +d3 in [0, 8] ``` ### **[(Variadic)Reduce](https://openxla.org/xla/operation_semantics#reduce)** @@ -451,8 +451,8 @@ The output to input maps: ``` (d0)[s0] -> (s0, d0) domain: -d0 in [0, 10) -s0 in [0, 256) +d0 in [0, 9] +s0 in [0, 255] ``` - output -> init_j: @@ -460,7 +460,7 @@ s0 in [0, 256) ``` (d0) -> () domain: -d0 in [0, 10) +d0 in [0, 9] ``` The input to output maps: @@ -470,8 +470,8 @@ The input to output maps: ``` (d0, d1) -> (d1) domain: -d0 in [0, 256) -d1 in [0, 10) +d0 in [0, 255] +d1 in [0, 9] ``` - init_i -> output_j: @@ -479,7 +479,7 @@ d1 in [0, 10) ``` ()[s0] -> (s0) domain: -s0 in [0, 10) +s0 in [0, 9] ``` for i, j = 0, ... INPUT_COUNT. @@ -501,9 +501,9 @@ The output to input map: ``` (d0, d1, d2) -> (d0 + 5, d1 * 7 + 3, d2 * 2) domain: -d0 in [0, 5) -d1 in [0, 3) -d2 in [0, 25) +d0 in [0, 4] +d1 in [0, 2] +d2 in [0, 24] ``` The input to output map: @@ -511,11 +511,11 @@ The input to output map: ``` (d0, d1, d2) -> (d0 - 5, (d1 - 3) floordiv 7, d2 floordiv 2) domain: -d0 in [5, 10) -d1 in [3, 18) -d2 in [0, 49) -(d1 - 3) mod 7 in [0, 1) -d2 mod 2 in [0, 1) +d0 in [5, 9] +d1 in [3, 17] +d2 in [0, 48] +(d1 - 3) mod 7 in [0, 0] +d2 mod 2 in [0, 0] ``` ### [Reshape](https://openxla.org/xla/operation_semantics#reshape) @@ -536,7 +536,7 @@ The output to input map: ``` (d0) -> (d0 floordiv 8, d0 mod 8) domain: -d0 in [0, 32) +d0 in [0, 31] ``` The input to output map: @@ -544,8 +544,8 @@ The input to output map: ``` (d0, d1) -> (d0 * 8 + d1) domain: -d0 in [0, 4) -d1 in [0, 8) +d0 in [0, 3] +d1 in [0, 7] ``` #### Expand shape @@ -562,8 +562,8 @@ The output to input map: ``` (d0, d1) -> (d0 * 8 + d1) domain: -d0 in [0, 4) -d1 in [0, 8) +d0 in [0, 3] +d1 in [0, 7] ``` The input to output map: @@ -571,7 +571,7 @@ The input to output map: ``` (d0) -> (d0 floordiv 8, d0 mod 8) domain: -d0 in [0, 32) +d0 in [0, 31] ``` #### Generic reshape @@ -596,9 +596,9 @@ The output to input map: ``` (d0, d1, d2) -> (d0 * 2 + d1 floordiv 2, d2 + (d1 mod 2) * 4) domain: -d0 in [0, 2) -d1 in [0, 4) -d2 in [0, 4) +d0 in [0, 1] +d1 in [0, 3] +d2 in [0, 3] ``` The input to output map: @@ -606,8 +606,8 @@ The input to output map: ``` (d0, d1) -> (d0 floordiv 2, d1 floordiv 4 + (d0 mod 2) * 2, d1 mod 4) domain: -d0 in [0, 4) -d1 in [0, 8) +d0 in [0, 3] +d1 in [0, 7] ``` ##### Example 2: Expanded and collapsed subshapes @@ -627,9 +627,9 @@ The output to input map: ``` (d0, d1, d2) -> (d0 floordiv 8, d0 mod 8, d1 * 4 + d2) domain: -d0 in [0, 32) -d1 in [0, 3) -d2 in [0, 4) +d0 in [0, 31] +d1 in [0, 2] +d2 in [0, 3] ``` The input to output map: @@ -637,9 +637,9 @@ The input to output map: ``` (d0, d1, d2) -> (d0 * 8 + d1, d2 floordiv 4, d2 mod 4) domain: -d0 in [0, 4) -d1 in [0, 8) -d2 in [0, 12) +d0 in [0, 3] +d1 in [0, 7] +d2 in [0, 11] ``` ### Bitcast @@ -668,9 +668,9 @@ The output to inputs maps: ``` (d0, d1, d2) -> (d0, d1, d2) domain: -d0 in [0, 2) -d1 in [0, 5) -d2 in [0, 7) +d0 in [0, 1] +d1 in [0, 4] +d2 in [0, 6] ``` - output -> input 2: @@ -678,9 +678,9 @@ d2 in [0, 7) ``` (d0, d1, d2) -> (d0, d1 - 5, d2) domain: -d0 in [0, 2) -d1 in [5, 16) -d2 in [0, 7) +d0 in [0, 1] +d1 in [5, 15] +d2 in [0, 6] ``` - output -> input 3: @@ -688,9 +688,9 @@ d2 in [0, 7) ``` (d0, d1, d2) -> (d0, d1 - 16, d2) domain: -d0 in [0, 2) -d1 in [16, 33) -d2 in [0, 7) +d0 in [0, 1] +d1 in [16, 32] +d2 in [0, 6] ``` @@ -701,9 +701,9 @@ The inputs to output maps: ``` (d0, d1, d2) -> (d0, d1, d2) domain: -d0 in [0, 2) -d1 in [0, 5) -d2 in [0, 7) +d0 in [0, 1] +d1 in [0, 4] +d2 in [0, 6] ``` - input 2 -> output: @@ -711,9 +711,9 @@ d2 in [0, 7) ``` (d0, d1, d2) -> (d0, d1 + 5, d2) domain: -d0 in [0, 2) -d1 in [0, 11) -d2 in [0, 7) +d0 in [0, 1] +d1 in [0, 10] +d2 in [0, 6] ``` - input 3 -> output: @@ -721,9 +721,9 @@ d2 in [0, 7) ``` (d0, d1, d2) -> (d0, d1 + 16, d2) domain: -d0 in [0, 2) -d1 in [0, 17) -d2 in [0, 7) +d0 in [0, 1] +d1 in [0, 16] +d2 in [0, 6] ``` ### [Dot](https://openxla.org/xla/operation_semantics#dot) @@ -745,10 +745,10 @@ The output to inputs maps: ``` (d0, d1, d2)[s0] -> (d0, d1, s0) domain: -d0 in [0, 4) -d1 in [0, 128) -d2 in [0, 64) -s0 in [0, 256) +d0 in [0, 3] +d1 in [0, 127] +d2 in [0, 63] +s0 in [0, 255] ``` - output -> input_2: @@ -756,10 +756,10 @@ s0 in [0, 256) ``` (d0, d1, d2)[s0] -> (d0, s0, d2) domain: -d0 in [0, 4) -d1 in [0, 128) -d2 in [0, 64) -s0 in [0, 256) +d0 in [0, 3] +d1 in [0, 127] +d2 in [0, 63] +s0 in [0, 255] ``` The inputs to output maps: @@ -769,10 +769,10 @@ The inputs to output maps: ``` (d0, d1, d2)[s0] -> (d0, d1, s0) domain: -d0 in [0, 4) -d1 in [0, 128) -d2 in [0, 256) -s0 in [0, 64) +d0 in [0, 3] +d1 in [0, 127] +d2 in [0, 255] +s0 in [0, 63] ``` - input_2 -> output: @@ -780,10 +780,10 @@ s0 in [0, 64) ``` (d0, d1, d2)[s0] -> (d0, s0, d1) domain: -d0 in [0, 4) -d1 in [0, 256) -d2 in [0, 64) -s0 in [0, 128) +d0 in [0, 3] +d1 in [0, 255] +d2 in [0, 63] +s0 in [0, 127] ``` ### [Pad](https://openxla.org/xla/operation_semantics#pad) @@ -805,9 +805,9 @@ The output to input maps: ``` (d0, d1) -> ((d0 - 1) floordiv 2, d1 - 4) domain: -d0 in [1, 8) -d1 in [4, 8) -(d0 - 1) mod 2 in [0, 1) +d0 in [1, 7] +d1 in [4, 7] +(d0 - 1) mod 2 in [0, 0] ``` - output -> init: @@ -815,8 +815,8 @@ d1 in [4, 8) ``` (d0, d1) -> () domain: -d0 in [0, 12) -d1 in [0, 16) +d0 in [0, 11] +d1 in [0, 15] ``` @@ -841,9 +841,9 @@ The output to input maps: ``` (d0, d1)[s0] -> (d0, d1 + s0) domain: -d0 in [0, 1024) -d1 in [0, 3) -s0 in [0, 512) +d0 in [0, 1023] +d1 in [0, 2] +s0 in [0, 511] ``` - output -> init: @@ -851,8 +851,8 @@ s0 in [0, 512) ``` (d0, d1) -> () domain: -d0 in [0, 1024) -d1 in [0, 3) +d0 in [0, 1023] +d1 in [0, 2] ``` ## Indexing Maps for Fusion @@ -909,10 +909,10 @@ The output-to-input indexing maps for `parameter 0` for softmax: ``` (d0, d1, d2)[s0] -> (d0, d1, s0) domain: -d0 in [0, 2) -d1 in [0, 65) -d2 in [0, 125) -s0 in [0, 125) +d0 in [0, 1] +d1 in [0, 64] +d2 in [0, 124] +s0 in [0, 124] ``` and @@ -920,9 +920,9 @@ and ``` (d0, d1, d2) -> (d0, d1, d2) domain: -d0 in [0, 2) -d1 in [0, 65) -d2 in [0, 125) +d0 in [0, 1] +d1 in [0, 64] +d2 in [0, 124] ``` where `s0` refers to the inner-most dimension of the input. @@ -941,10 +941,10 @@ The simplifier can rewrite the following expressions. 1. `(d0, d1) -> (d0 + d1 floordiv 16, d1 mod 16)` for **d** in `[0, 6] x [0, 14]` becomes `(d0, d1) -> (d0, d1)` 2. `(d0, d1, d2) -> ((100d0 + 10d1 + d2) floorDiv 100, ((100d0 + 10d1 + - d2) mod 100) floordiv 10, d2 mod 10)` for `di in [0, 10)` becomes `(d0, d1, + d2) mod 100) floordiv 10, d2 mod 10)` for `di in [0, 9]` becomes `(d0, d1, d2) -> (d0, d1, d2)`. 3. `(d0, d1, d2) -> ((16d0 + 4d1 + d2) floordiv 8, (16d0 + 4d1 + d2) mod - 8)` for `d_i in [0, 10)` becomes `(d0, d1, d2) -> (2d0 + (4d1 + + 8)` for `d_i in [0, 9]` becomes `(d0, d1, d2) -> (2d0 + (4d1 + d2) floordiv 8,(4d1 + d2) mod 8)`. 4. `(d0, d1) -> (-(-11d0 - d1 + 109) floordiv 11 + 9)` for **d** in `[0, 9] x [0, 10]` becomes `(d0, d1) -> (d0)`. @@ -967,8 +967,8 @@ Indexing map simplification also simplifies the constraints. 1. Constraints of type `lower_bound <= affine_expr (floordiv, +, -, *) constant <= upper_bound` are rewritten as `updated_lower_bound <= affine_expr <= updated_upped_bound`. -2. Constraints that are always satisfied, e.g. `d0 + s0 in [0, 21)` -for `d0 in [0, 6)` and `s0 in [1, 4)` are eliminated. +2. Constraints that are always satisfied, e.g. `d0 + s0 in [0, 20]` +for `d0 in [0, 5]` and `s0 in [1, 3]` are eliminated. 3. Affine expressions in the constraints are optimized as the indexing affine map above. diff --git a/xla/service/gpu/fusions/concatenate_mlir_test.cc b/xla/service/gpu/fusions/concatenate_mlir_test.cc index d6ca14610c51c3..c0637cbe12dc74 100644 --- a/xla/service/gpu/fusions/concatenate_mlir_test.cc +++ b/xla/service/gpu/fusions/concatenate_mlir_test.cc @@ -59,15 +59,15 @@ TEST_F(MlirConcatenateFusionTest, ThreadIdIndexing) { (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (bl_x * 128 + th_x) domain: - th_x in [0, 128) - th_y in [0, 1) - th_z in [0, 1) - bl_x in [0, 4) - bl_y in [0, 1) - bl_z in [0, 1) - chunk_id in [0, 1) - unroll_id in [0, 1) - bl_x * 128 + th_x in [0, 400) + th_x in [0, 127] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 3] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] + bl_x * 128 + th_x in [0, 399] )"; auto thread_id_to_output_indexing_0 = fusion.ComputeThreadIdToInputIndexing( /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); diff --git a/xla/service/gpu/fusions/concatenate_test.cc b/xla/service/gpu/fusions/concatenate_test.cc index 56c4f765491472..8192b33c37a1cb 100644 --- a/xla/service/gpu/fusions/concatenate_test.cc +++ b/xla/service/gpu/fusions/concatenate_test.cc @@ -85,15 +85,15 @@ TEST_F(ConcatenateTest, ThreadIndexing) { (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (bl_x * 128 + th_x) domain: - th_x in [0, 128) - th_y in [0, 1) - th_z in [0, 1) - bl_x in [0, 4) - bl_y in [0, 1) - bl_z in [0, 1) - chunk_id in [0, 1) - unroll_id in [0, 1) - bl_x * 128 + th_x in [0, 400) + th_x in [0, 127] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 3] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] + bl_x * 128 + th_x in [0, 399] )"; EXPECT_THAT( fusion diff --git a/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc b/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc index c3ef11a04f1de8..b68a95e9516bfd 100644 --- a/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc +++ b/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc @@ -66,14 +66,14 @@ TEST_F(MlirInPlaceDynamicUpdateSliceFusionTest, ThreadIndexing) { (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( th_x floordiv 6, th_x mod 6) domain: - th_x in [0, 30) - th_y in [0, 1) - th_z in [0, 1) - bl_x in [0, 1) - bl_y in [0, 1) - bl_z in [0, 1) - chunk_id in [0, 1) - unroll_id in [0, 1) + th_x in [0, 29] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 0] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] )")); auto thread_id_dst_indexing = fusion.ComputeThreadIdToInputIndexing( /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); @@ -112,8 +112,8 @@ TEST_F(MlirInPlaceDynamicUpdateSliceFusionTest, SimpleDUS) { // CHECK-DAG: %[[C_15:.*]] = arith.constant 15 // CHECK-DAG: %[[C_0:.*]] = arith.constant 0 // CHECK: %[[THREAD_ID:.*]] = gpu.thread_id x - // CHECK: %[[INPUT_INDEX_0:.*]] = xla_gpu.apply_indexing #[[MAP_1]](%[[THREAD_ID]] in [0, 30)) - // CHECK: %[[INPUT_INDEX_1:.*]] = xla_gpu.apply_indexing #[[MAP_2]](%[[THREAD_ID]] in [0, 30)) + // CHECK: %[[INPUT_INDEX_0:.*]] = xla_gpu.apply_indexing #[[MAP_1]](%[[THREAD_ID]] in [0, 29]) + // CHECK: %[[INPUT_INDEX_1:.*]] = xla_gpu.apply_indexing #[[MAP_2]](%[[THREAD_ID]] in [0, 29]) // CHECK: %[[I0:.*]] = xla_gpu.pure_call @fused_computation_i0 // CHECK: %[[I1:.*]] = xla_gpu.pure_call @fused_computation_i1 // CHECK: %[[IDX0:.*]] = arith.index_cast %[[I0]] @@ -162,8 +162,8 @@ TEST_F(MlirInPlaceDynamicUpdateSliceFusionTest, OutOfBoundDUS) { // CHECK-DAG: %[[C_5:.*]] = arith.constant 5 // CHECK-DAG: %[[C_0:.*]] = arith.constant 0 // CHECK: %[[THREAD_ID:.*]] = gpu.thread_id x - // CHECK: %[[INPUT_INDEX_0:.*]] = xla_gpu.apply_indexing #[[MAP_1]](%[[THREAD_ID]] in [0, 6)) - // CHECK: %[[INPUT_INDEX_1:.*]] = xla_gpu.apply_indexing #[[MAP_2]](%[[THREAD_ID]] in [0, 6)) + // CHECK: %[[INPUT_INDEX_0:.*]] = xla_gpu.apply_indexing #[[MAP_1]](%[[THREAD_ID]] in [0, 5]) + // CHECK: %[[INPUT_INDEX_1:.*]] = xla_gpu.apply_indexing #[[MAP_2]](%[[THREAD_ID]] in [0, 5]) // CHECK: %[[I0:.*]] = xla_gpu.pure_call @fused_computation_i0 // CHECK: %[[I1:.*]] = xla_gpu.pure_call @fused_computation_i1 // CHECK: %[[IDX0:.*]] = arith.index_cast %[[I0]] diff --git a/xla/service/gpu/fusions/in_place_dynamic_update_slice_test.cc b/xla/service/gpu/fusions/in_place_dynamic_update_slice_test.cc index 61648b422e73e1..e48cee0cd473b5 100644 --- a/xla/service/gpu/fusions/in_place_dynamic_update_slice_test.cc +++ b/xla/service/gpu/fusions/in_place_dynamic_update_slice_test.cc @@ -88,14 +88,14 @@ TEST_F(InPlaceDynamicUpdateSliceFusionTest, ThreadIndexing) { (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( th_x floordiv 6, th_x mod 6) domain: - th_x in [0, 30) - th_y in [0, 1) - th_z in [0, 1) - bl_x in [0, 1) - bl_y in [0, 1) - bl_z in [0, 1) - chunk_id in [0, 1) - unroll_id in [0, 1) + th_x in [0, 29] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 0] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] )")); auto thread_id_dst_indexing = fusion->ComputeThreadIdToInputIndexing( /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); diff --git a/xla/service/gpu/fusions/input_slices_mlir_test.cc b/xla/service/gpu/fusions/input_slices_mlir_test.cc index 73a2e98899f48f..abeb57accdabdf 100644 --- a/xla/service/gpu/fusions/input_slices_mlir_test.cc +++ b/xla/service/gpu/fusions/input_slices_mlir_test.cc @@ -59,15 +59,15 @@ TEST_F(MlirInputSlicesFusionTest, ThreadIndexing) { th_x mod 5 ) domain: - th_x in [5, 20) - th_y in [0, 1) - th_z in [0, 1) - bl_x in [0, 1) - bl_y in [0, 1) - bl_z in [0, 1) - s0 in [0, 1) - s1 in [0, 1) - th_x mod 5 in [0, 3) + th_x in [5, 19] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 0] + bl_y in [0, 0] + bl_z in [0, 0] + s0 in [0, 0] + s1 in [0, 0] + th_x mod 5 in [0, 2] )")); auto thread_id_to_output_indexing_1 = emitter->ComputeThreadIdToOutputIndexing(1, &mlir_context_); @@ -79,15 +79,15 @@ TEST_F(MlirInputSlicesFusionTest, ThreadIndexing) { th_x mod 5 ) domain: - th_x in [0, 10) - th_y in [0, 1) - th_z in [0, 1) - bl_x in [0, 1) - bl_y in [0, 1) - bl_z in [0, 1) - s0 in [0, 1) - s1 in [0, 1) - th_x mod 5 in [0, 3) + th_x in [0, 9] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 0] + bl_y in [0, 0] + bl_z in [0, 0] + s0 in [0, 0] + s1 in [0, 0] + th_x mod 5 in [0, 2] )")); } diff --git a/xla/service/gpu/fusions/input_slices_test.cc b/xla/service/gpu/fusions/input_slices_test.cc index f979f6885db1a5..689727aed734ec 100644 --- a/xla/service/gpu/fusions/input_slices_test.cc +++ b/xla/service/gpu/fusions/input_slices_test.cc @@ -87,15 +87,15 @@ TEST_F(InputSlicesTest, ThreadIndexing) { (bl_x * 128 + th_x) mod 3, (bl_x * 128 + th_x) floordiv 6) domain: - th_x in [0, 128) - th_y in [0, 1) - th_z in [0, 1) - bl_x in [0, 2) - bl_y in [0, 1) - bl_z in [0, 1) - chunk_id in [0, 1) - unroll_id in [0, 1) - bl_x * 128 + th_x in [0, 30) + th_x in [0, 127] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 1] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] + bl_x * 128 + th_x in [0, 29] )")); } diff --git a/xla/service/gpu/fusions/loop_mlir_test.cc b/xla/service/gpu/fusions/loop_mlir_test.cc index 906227730f892a..08dcb4df490e54 100644 --- a/xla/service/gpu/fusions/loop_mlir_test.cc +++ b/xla/service/gpu/fusions/loop_mlir_test.cc @@ -59,15 +59,15 @@ TEST_F(MlirLoopFusionTest, ThreadId_IndexingUnrolled) { ((bl_x * 128 + chunk_id * 129024 + th_x) mod 75) * 4 + unroll_id ) domain: - th_x in [0, 128) - th_y in [0, 1) - th_z in [0, 1) - bl_x in [0, 1008) - bl_y in [0, 1) - bl_z in [0, 1) - chunk_id in [0, 12) - unroll_id in [0, 4) - bl_x * 128 + chunk_id * 129024 + th_x in [0, 1500000) + th_x in [0, 127] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 1007] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 11] + unroll_id in [0, 3] + bl_x * 128 + chunk_id * 129024 + th_x in [0, 1499999] )")); } @@ -97,14 +97,14 @@ TEST_F(MlirLoopFusionTest, ThreadId_IndexingNotUnrolled) { MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x) domain: - th_x in [0, 20) - th_y in [0, 1) - th_z in [0, 1) - bl_x in [0, 1) - bl_y in [0, 1) - bl_z in [0, 1) - chunk_id in [0, 1) - unroll_id in [0, 1) + th_x in [0, 19] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 0] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] )")); auto thread_id_to_input_indexing = fusion.ComputeThreadIdToInputIndexing( /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); @@ -112,14 +112,14 @@ TEST_F(MlirLoopFusionTest, ThreadId_IndexingNotUnrolled) { MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x) domain: - th_x in [0, 20) - th_y in [0, 1) - th_z in [0, 1) - bl_x in [0, 1) - bl_y in [0, 1) - bl_z in [0, 1) - chunk_id in [0, 1) - unroll_id in [0, 1) + th_x in [0, 19] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 0] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] )")); } @@ -153,15 +153,15 @@ TEST_F(MlirLoopFusionTest, ThreadId_Broadcast) { (bl_x * 128 + th_x) mod 30 ) domain: - th_x in [0, 128) - th_y in [0, 1) - th_z in [0, 1) - bl_x in [0, 47) - bl_y in [0, 1) - bl_z in [0, 1) - chunk_id in [0, 1) - unroll_id in [0, 1) - bl_x * 128 + th_x in [0, 6000) + th_x in [0, 127] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 46] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] + bl_x * 128 + th_x in [0, 5999] )")); auto thread_id_to_input_indexing = fusion.ComputeThreadIdToInputIndexing( /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); @@ -170,15 +170,15 @@ TEST_F(MlirLoopFusionTest, ThreadId_Broadcast) { (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (((bl_x * 128 + th_x) floordiv 30) mod 20) domain: - th_x in [0, 128) - th_y in [0, 1) - th_z in [0, 1) - bl_x in [0, 47) - bl_y in [0, 1) - bl_z in [0, 1) - chunk_id in [0, 1) - unroll_id in [0, 1) - bl_x * 128 + th_x in [0, 6000) + th_x in [0, 127] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 46] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] + bl_x * 128 + th_x in [0, 5999] )")); } diff --git a/xla/service/gpu/fusions/loop_test.cc b/xla/service/gpu/fusions/loop_test.cc index 0a487ee43c8cf2..69c41ec0c932b9 100644 --- a/xla/service/gpu/fusions/loop_test.cc +++ b/xla/service/gpu/fusions/loop_test.cc @@ -93,15 +93,15 @@ TEST_F(LoopTest, ThreadIndexingUnrolled) { ((bl_x * 128 + chunk_id * 129024 + th_x) mod 75) * 4 + unroll_id ) domain: - th_x in [0, 128) - th_y in [0, 1) - th_z in [0, 1) - bl_x in [0, 1008) - bl_y in [0, 1) - bl_z in [0, 1) - chunk_id in [0, 12) - unroll_id in [0, 4) - bl_x * 128 + chunk_id * 129024 + th_x in [0, 1500000) + th_x in [0, 127] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 1007] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 11] + unroll_id in [0, 3] + bl_x * 128 + chunk_id * 129024 + th_x in [0, 1499999] )")); } @@ -131,14 +131,14 @@ TEST_F(LoopTest, ThreadIndexingNotUnrolled) { MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x) domain: - th_x in [0, 20) - th_y in [0, 1) - th_z in [0, 1) - bl_x in [0, 1) - bl_y in [0, 1) - bl_z in [0, 1) - chunk_id in [0, 1) - unroll_id in [0, 1) + th_x in [0, 19] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 0] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] )")); auto thread_id_to_input_indexing = loop_fusion->ComputeThreadIdToInputIndexing( @@ -147,14 +147,14 @@ TEST_F(LoopTest, ThreadIndexingNotUnrolled) { MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x) domain: - th_x in [0, 20) - th_y in [0, 1) - th_z in [0, 1) - bl_x in [0, 1) - bl_y in [0, 1) - bl_z in [0, 1) - chunk_id in [0, 1) - unroll_id in [0, 1) + th_x in [0, 19] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 0] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] )")); } @@ -187,15 +187,15 @@ TEST_F(LoopTest, Broadcast) { ((bl_x * 128 + th_x) floordiv 30) mod 20, (bl_x * 128 + th_x) mod 30) domain: - th_x in [0, 128) - th_y in [0, 1) - th_z in [0, 1) - bl_x in [0, 47) - bl_y in [0, 1) - bl_z in [0, 1) - chunk_id in [0, 1) - unroll_id in [0, 1) - bl_x * 128 + th_x in [0, 6000) + th_x in [0, 127] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 46] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] + bl_x * 128 + th_x in [0, 5999] )")); auto thread_id_to_input_indexing = loop_fusion->ComputeThreadIdToInputIndexing( @@ -205,15 +205,15 @@ TEST_F(LoopTest, Broadcast) { (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (((bl_x * 128 + th_x) floordiv 30) mod 20) domain: - th_x in [0, 128) - th_y in [0, 1) - th_z in [0, 1) - bl_x in [0, 47) - bl_y in [0, 1) - bl_z in [0, 1) - chunk_id in [0, 1) - unroll_id in [0, 1) - bl_x * 128 + th_x in [0, 6000) + th_x in [0, 127] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 46] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] + bl_x * 128 + th_x in [0, 5999] )")); } diff --git a/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc b/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc index 3b03393788538a..6a27e548ca932f 100644 --- a/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc +++ b/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc @@ -236,9 +236,9 @@ TEST_F(ElementalHloToMlirTest, ReduceWindow) { // CHECK: %[[RET:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C7]] // CHECK-SAME: step %[[C1]] iter_args(%[[ACC:.*]] = %[[INIT]]) // CHECK: %[[J0:.*]] = xla_gpu.apply_indexing affine_map<(d0) -> (d0 * 4)> - // CHECK-SAME: (%[[Y]] in [0, 3)) + // CHECK-SAME: (%[[Y]] in [0, 2]) // CHECK: %[[J1:.*]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> (d0 + s0 - 3)> - // CHECK-SAME: (%[[Z]] in [0, 8))[%[[I]] in [0, 7)] + // CHECK-SAME: (%[[Z]] in [0, 7])[%[[I]] in [0, 6]] // CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]] // CHECK-SAME: [%[[X]], %[[J0]], %[[J1]]] // CHECK: %[[UPD:.*]] = func.call @add_sum(%[[ACC]], @@ -286,7 +286,7 @@ TEST_F(ElementalHloToMlirTest, ReduceWindowWithRescaling) { // `s0 floordiv ` in the map: // CHECK: %[[K:.*]] = xla_gpu.apply_indexing // CHECK-SAME: affine_map<(d0)[s0] -> (d0 * 2 + s0)> - // CHECK-SAME: (%[[X]] in [0, 19))[%[[I]] in [0, 4)] + // CHECK-SAME: (%[[X]] in [0, 18])[%[[I]] in [0, 3]] // CHECK: tensor.extract %[[ARG0]][%[[K]], %[[Y]], %[[Z]]] )")); @@ -433,7 +433,7 @@ TEST_F(ElementalHloToMlirTest, Pad) { // CHECK-DAG: %[[C4:.*]] = arith.constant 4 // CHECK-DAG: %[[C7:.*]] = arith.constant 7 // CHECK: %[[CONSTRAINT_VAL:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <(d0) -> ((d0 - 1) mod 2)>(%[[X]] in [1, 8)) + // CHECK-SAME: <(d0) -> ((d0 - 1) mod 2)>(%[[X]] in [1, 7]) // CHECK: %[[CONSTRAINT:.*]] = arith.cmpi eq, %[[CONSTRAINT_VAL]], %[[C0]] // CHECK-DAG: %[[X_L:.*]] = arith.cmpi sge, %[[X]], %[[C1]] // CHECK-DAG: %[[X_H:.*]] = arith.cmpi sle, %[[X]], %[[C7]] @@ -446,10 +446,10 @@ TEST_F(ElementalHloToMlirTest, Pad) { // CHECK: %[[RET:.*]] = scf.if %[[FROM_INPUT]] // CHECK: %[[IN0:.*]] = xla_gpu.apply_indexing // CHECK-SAME: <(d0) -> ((d0 - 1) floordiv 2)> - // CHECK-SAME: (%[[X]] in [1, 8)) + // CHECK-SAME: (%[[X]] in [1, 7]) // CHECK: %[[IN1:.*]] = xla_gpu.apply_indexing // CHECK-SAME: <(d0) -> (d0 - 4)> - // CHECK-SAME: (%[[Y]] in [4, 8)) + // CHECK-SAME: (%[[Y]] in [4, 7]) // CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]][%[[IN0]], %[[IN1]]] // CHECK: scf.yield %[[VAL]] // CHECK: } else { @@ -477,7 +477,7 @@ TEST_F(ElementalHloToMlirTest, PadUnsigned) { // CHECK-DAG: %[[C4:.*]] = arith.constant 4 // CHECK-DAG: %[[C7:.*]] = arith.constant 7 // CHECK: %[[CONSTRAINT_VAL:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <(d0) -> ((d0 - 1) mod 2)>(%[[X]] in [1, 8)) + // CHECK-SAME: <(d0) -> ((d0 - 1) mod 2)>(%[[X]] in [1, 7]) // CHECK: %[[CONSTRAINT:.*]] = arith.cmpi eq, %[[CONSTRAINT_VAL]], %[[C0]] // CHECK-DAG: %[[X_L:.*]] = arith.cmpi sge, %[[X]], %[[C1]] // CHECK-DAG: %[[X_H:.*]] = arith.cmpi sle, %[[X]], %[[C7]] @@ -490,10 +490,10 @@ TEST_F(ElementalHloToMlirTest, PadUnsigned) { // CHECK: %[[RET:.*]] = scf.if %[[FROM_INPUT]] // CHECK: %[[IN0:.*]] = xla_gpu.apply_indexing // CHECK-SAME: <(d0) -> ((d0 - 1) floordiv 2)> - // CHECK-SAME: (%[[X]] in [1, 8)) + // CHECK-SAME: (%[[X]] in [1, 7]) // CHECK: %[[IN1:.*]] = xla_gpu.apply_indexing // CHECK-SAME: <(d0) -> (d0 - 4)> - // CHECK-SAME: (%[[Y]] in [4, 8)) + // CHECK-SAME: (%[[Y]] in [4, 7]) // CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]][%[[IN0]], %[[IN1]]] // CHECK: scf.yield %[[VAL]] // CHECK: } else { @@ -811,10 +811,10 @@ TEST_F(ElementalHloToMlirTest, ConvolutionSimple) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0)> - // CHECK-SAME: (%[[W]] in [0, 6))[%[[X]] in [0, 3)] + // CHECK-SAME: (%[[W]] in [0, 5])[%[[X]] in [0, 2]] // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0)> - // CHECK-SAME: (%[[H]] in [0, 8))[%[[Y]] in [0, 5)] + // CHECK-SAME: (%[[H]] in [0, 7])[%[[Y]] in [0, 4]] // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -857,10 +857,10 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithWindowStrides) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing // CHECK-SAME: affine_map<(d0)[s0] -> (d0 * 2 + s0)> - // CHECK-SAME: (%[[W]] in [0, 3))[%[[X]] in [0, 3)] + // CHECK-SAME: (%[[W]] in [0, 2])[%[[X]] in [0, 2]] // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing // CHECK-SAME: affine_map<(d0)[s0] -> (d0 * 2 + s0)> - // CHECK-SAME: (%[[H]] in [0, 4))[%[[Y]] in [0, 5)] + // CHECK-SAME: (%[[H]] in [0, 3])[%[[Y]] in [0, 4]] // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -903,21 +903,21 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithPadding) { // CHECK: %[[R0:.+]] = scf.for %[[X:.+]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[A0:.+]] = %[[INIT]]) -> (f32) { // CHECK-NEXT: %[[R1:.+]] = scf.for %[[Y:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[A1:.+]] = %[[A0]]) -> (f32) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { - // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> (d0 + s0)>(%[[W]] in [0, 8))[%[[X]] in [0, 3)] + // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> (d0 + s0)>(%[[W]] in [0, 7])[%[[X]] in [0, 2]] // CHECK-DAG: %[[TXGE:.+]] = arith.cmpi sge, %[[TESTX]], %[[C1]] : index // CHECK-DAG: %[[TXLE:.+]] = arith.cmpi sle, %[[TESTX]], %[[C8]] : index // CHECK-DAG: %[[TX:.+]] = arith.andi %[[TXGE]], %[[TXLE]] : i1 - // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> (d0 + s0)>(%[[H]] in [0, 12))[%[[Y]] in [0, 5)] + // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> (d0 + s0)>(%[[H]] in [0, 11])[%[[Y]] in [0, 4]] // CHECK-DAG: %[[TYGE:.+]] = arith.cmpi sge, %[[TESTY]], %[[C2]] : index // CHECK-DAG: %[[TYLE:.+]] = arith.cmpi sle, %[[TESTY]], %[[C13]] : index // CHECK-DAG: %[[TY:.+]] = arith.andi %[[TYGE]], %[[TYLE]] : i1 // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0 - 1)> - // CHECK-SAME: (%[[W]] in [0, 8))[%[[X]] in [0, 3)] + // CHECK-SAME: (%[[W]] in [0, 7])[%[[X]] in [0, 2]] // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0 - 2)> - // CHECK-SAME: (%[[H]] in [0, 12))[%[[Y]] in [0, 5)] + // CHECK-SAME: (%[[H]] in [0, 11])[%[[Y]] in [0, 4]] // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -957,17 +957,17 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithLhsDilation) { // CHECK: %[[R0:.+]] = scf.for %[[X:.+]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[A0:.+]] = %[[INIT]]) -> (f32) { // CHECK-NEXT: %[[R1:.+]] = scf.for %[[Y:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[A1:.+]] = %[[A0]]) -> (f32) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { - // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> ((d0 + s0) mod 2)>(%[[W]] in [0, 13))[%[[X]] in [0, 3)] + // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> ((d0 + s0) mod 2)>(%[[W]] in [0, 12])[%[[X]] in [0, 2]] // CHECK-DAG: %[[TX:.+]] = arith.cmpi eq, %[[TESTX]], %[[C0]] : index - // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> ((d0 + s0) mod 2)>(%[[H]] in [0, 19))[%[[Y]] in [0, 5)] + // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> ((d0 + s0) mod 2)>(%[[H]] in [0, 18])[%[[Y]] in [0, 4]] // CHECK-DAG: %[[TY:.+]] = arith.cmpi eq, %[[TESTY]], %[[C0]] : index // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing // CHECK-SAME: affine_map<(d0)[s0] -> ((d0 + s0) floordiv 2)> - // CHECK-SAME: (%[[W]] in [0, 13))[%[[X]] in [0, 3)] + // CHECK-SAME: (%[[W]] in [0, 12])[%[[X]] in [0, 2]] // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing // CHECK-SAME: affine_map<(d0)[s0] -> ((d0 + s0) floordiv 2)> - // CHECK-SAME: (%[[H]] in [0, 19))[%[[Y]] in [0, 5)] + // CHECK-SAME: (%[[H]] in [0, 18])[%[[Y]] in [0, 4]] // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1010,10 +1010,10 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithRhsDilation) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0 * 2)> - // CHECK-SAME: (%[[W]] in [0, 4))[%[[X]] in [0, 3)] + // CHECK-SAME: (%[[W]] in [0, 3])[%[[X]] in [0, 2]] // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0 * 2)> - // CHECK-SAME: (%[[H]] in [0, 4))[%[[Y]] in [0, 5)] + // CHECK-SAME: (%[[H]] in [0, 3])[%[[Y]] in [0, 4]] // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1056,16 +1056,16 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithFeatureGroupCount) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0)> - // CHECK-SAME: (%[[W]] in [0, 6)) - // CHECK-SAME: [%[[X]] in [0, 3)] + // CHECK-SAME: (%[[W]] in [0, 5]) + // CHECK-SAME: [%[[X]] in [0, 2]] // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0)> - // CHECK-SAME: (%[[H]] in [0, 8)) - // CHECK-SAME: [%[[Y]] in [0, 5)] + // CHECK-SAME: (%[[H]] in [0, 7]) + // CHECK-SAME: [%[[Y]] in [0, 4]] // CHECK: %[[XX2:.+]] = xla_gpu.apply_indexing // CHECK-SAME: affine_map<(d0)[s0] -> ((d0 floordiv 8) * 2 + s0)> - // CHECK-SAME: (%[[O]] in [0, 16)) - // CHECK-SAME: [%[[I]] in [0, 2)] + // CHECK-SAME: (%[[O]] in [0, 15]) + // CHECK-SAME: [%[[I]] in [0, 1]] // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[XX2]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<2x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1110,12 +1110,12 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithBatchGroupCount) { // CHECK: %[[R4:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0)> - // CHECK-SAME: (%[[W]] in [0, 6)) - // CHECK-SAME: [%[[X]] in [0, 3)] + // CHECK-SAME: (%[[W]] in [0, 5]) + // CHECK-SAME: [%[[X]] in [0, 2]] // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0)> - // CHECK-SAME: (%[[H]] in [0, 8)) - // CHECK-SAME: [%[[Y]] in [0, 5)] + // CHECK-SAME: (%[[H]] in [0, 7]) + // CHECK-SAME: [%[[Y]] in [0, 4]] // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[G]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1582,7 +1582,7 @@ TEST_F(ElementalHloToMlirTest, MixedIndexingTuple) { // CHECK: %[[A:.*]] = tensor.extract %[[P0]][%[[X]], %[[Y]]] // CHECK: %[[IDX:.*]] = xla_gpu.apply_indexing // CHECK-SAME: affine_map<(d0, d1) -> (d0 * 10 + d1)> - // CHECK-SAME: (%[[X]] in [0, 10), %[[Y]] in [0, 10)) + // CHECK-SAME: (%[[X]] in [0, 9], %[[Y]] in [0, 9]) // CHECK: %[[B:.*]] = tensor.extract %[[P1]][%[[IDX]]] // CHECK: return %[[A]], %[[B]] )")); @@ -1606,7 +1606,7 @@ TEST_F(ElementalHloToMlirTest, NestedTuple) { // CHECK: %[[P0_V:.*]] = xla_gpu.pure_call @main_p0 // CHECK: %[[IDX:.*]] = // CHECK-SAME: affine_map<(d0, d1) -> (d0 * 10 + d1)> - // CHECK-SAME: (%[[X]] in [0, 10), %[[Y]] in [0, 10)) + // CHECK-SAME: (%[[X]] in [0, 9], %[[Y]] in [0, 9]) // CHECK: %[[P1_V:.*]] = xla_gpu.pure_call @main_p1 // CHECK-SAME: (%[[P0]], %[[P1]], %[[IDX]]) // CHECK: return %[[P0_V]], %[[P1_V]], %[[P1_V]], %[[P1_V]], %[[P0_V]] diff --git a/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc b/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc index 8bdc4c95985c9e..6f6662686553ea 100644 --- a/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc +++ b/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc @@ -255,12 +255,12 @@ mlir::ParseResult parseOperandsWithBoundsList( if (parser.parseOperand(operand) || parser.parseKeyword("in") || parser.parseLSquare() || parser.parseInteger(lower_bound) || parser.parseComma() || parser.parseInteger(upper_bound) || - parser.parseRParen()) { + parser.parseRSquare()) { return failure(); } operands->push_back(operand); lower_bounds->push_back(lower_bound); - upper_bounds->push_back(upper_bound - 1); + upper_bounds->push_back(upper_bound); return success(); })) { return failure(); @@ -321,7 +321,7 @@ void ApplyIndexingOp::print(mlir::OpAsmPrinter& p) { p << '('; for (int dim_id = 0; dim_id < num_dimensions; ++dim_id) { p << operands[dim_id] << " in " << '[' << lower_bounds[dim_id] << ", " - << upper_bounds[dim_id] + 1 << ')'; + << upper_bounds[dim_id] << ']'; if (dim_id != num_dimensions - 1) { p << ", "; } @@ -334,7 +334,7 @@ void ApplyIndexingOp::print(mlir::OpAsmPrinter& p) { for (int symbol_id = 0; symbol_id < num_symbols; ++symbol_id) { unsigned operand_id = num_dimensions + symbol_id; p << operands[operand_id] << " in " << '[' << lower_bounds[operand_id] - << ", " << upper_bounds[operand_id] + 1 << ')'; + << ", " << upper_bounds[operand_id] << ']'; if (symbol_id != num_symbols - 1) { p << ", "; } diff --git a/xla/service/gpu/fusions/mlir/tests/canonicalize.mlir b/xla/service/gpu/fusions/mlir/tests/canonicalize.mlir index bac04b4af5cefc..17b0f8d9b45c88 100644 --- a/xla/service/gpu/fusions/mlir/tests/canonicalize.mlir +++ b/xla/service/gpu/fusions/mlir/tests/canonicalize.mlir @@ -2,14 +2,14 @@ #map0 = affine_map<()[s0, s1] -> (1 + s0 + s1 mod 3 - s1, s0 mod 2)> func.func @simplify_apply_indexing(%s0: index, %s1: index) -> (index, index) { - %0:2 = xla_gpu.apply_indexing #map0 [%s0 in [-10, 10), %s1 in [0, 3)] + %0:2 = xla_gpu.apply_indexing #map0 [%s0 in [-10, 10], %s1 in [0, 2]] func.return %0#0, %0#1 : index, index } // CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 + 1, s0 mod 2)> // CHECK-LABEL: func.func @simplify_apply_indexing // CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -// CHECK: xla_gpu.apply_indexing #[[$MAP]][%[[ARG_0]] in [-10, 10)] +// CHECK: xla_gpu.apply_indexing #[[$MAP]][%[[ARG_0]] in [-10, 10]] // ----- @@ -17,8 +17,8 @@ func.func @simplify_apply_indexing(%s0: index, %s1: index) -> (index, index) { func.func @simplify_apply_indexing_remove_dims(%d0: index, %d1: index, %d2: index, %s0: index, %s1: index) -> (index, index, index) { %0:3 = xla_gpu.apply_indexing #map0 - (%d0 in [0, 2), %d1 in [0, 3), %d2 in [0, 4)) - [%s0 in [-11, 11), %s1 in [0, 4)] + (%d0 in [0, 1], %d1 in [0, 2], %d2 in [0, 3]) + [%s0 in [-11, 11], %s1 in [0, 3]] func.return %0#0, %0#1, %0#2 : index, index, index } // CHECK: #[[$MAP:.*]] = affine_map<(d0, d1)[s0] -> (s0 + 1, s0 mod 2, d0 + d1)> @@ -30,15 +30,15 @@ func.func @simplify_apply_indexing_remove_dims(%d0: index, %d1: index, // CHECK-SAME: %[[ARG_3:[a-zA-Z0-9_]+]]: index, // CHECK-SAME: %[[ARG_4:[a-zA-Z0-9_]+]]: index) // CHECK: xla_gpu.apply_indexing #[[$MAP]] -// CHECK-SAME: (%[[ARG_0]] in [0, 2), %[[ARG_2]] in [0, 4)) -// CHECK-SAME: [%[[ARG_3]] in [-11, 11)] +// CHECK-SAME: (%[[ARG_0]] in [0, 1], %[[ARG_2]] in [0, 3]) +// CHECK-SAME: [%[[ARG_3]] in [-11, 11]] // ----- #map0 = affine_map<(d0, d1)[s0] -> (d0 + s0, 4, d1, 1, s0)> func.func @fold_indexing_map_results(%d0: index, %d1: index, %s0: index) -> (index, index, index, index, index) { - %0:5 = xla_gpu.apply_indexing #map0 (%d0 in [-10, 10), %d1 in [0, 3))[%s0 in [-1, 2)] + %0:5 = xla_gpu.apply_indexing #map0 (%d0 in [-10, 10], %d1 in [0, 2])[%s0 in [-1, 1]] func.return %0#0, %0#1, %0#2, %0#3, %0#4 : index, index, index, index, index } // CHECK: #[[$MAP:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> @@ -56,7 +56,7 @@ func.func @fold_indexing_map_results(%d0: index, %d1: index, %s0: index) #map0 = affine_map<(d0, d1)[s0] -> (d0 + s0, s0 + 4, d1 mod 2, 1 + d1, s0)> func.func @remove_unused_results(%d0: index, %d1: index, %s0: index) -> (index) { - %0:5 = xla_gpu.apply_indexing #map0 (%d0 in [-10, 10), %d1 in [0, 3))[%s0 in [-1, 2)] + %0:5 = xla_gpu.apply_indexing #map0 (%d0 in [-10, 10], %d1 in [0, 2])[%s0 in [-1, 1]] func.return %0#2 : index } // CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0 mod 2)> @@ -64,7 +64,7 @@ func.func @remove_unused_results(%d0: index, %d1: index, %s0: index) -> (index) // CHECK-LABEL: func.func @remove_unused_results // CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) -// CHECK: %[[NEW_RESULT:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[ARG_1]] in [0, 3)) +// CHECK: %[[NEW_RESULT:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[ARG_1]] in [0, 2]) // CHECK: return %[[NEW_RESULT]] // ----- @@ -74,22 +74,22 @@ func.func @fold_operands(%d0: index) -> index { %d1 = arith.constant 1 : index %s0 = arith.constant 2 : index %s1 = arith.constant 3 : index - %0 = xla_gpu.apply_indexing #map0 (%d0 in [0, 11), %d1 in [0, 6)) - [%s0 in [-10, 10), %s1 in [0, 5)] + %0 = xla_gpu.apply_indexing #map0 (%d0 in [0, 10], %d1 in [0, 5]) + [%s0 in [-10, 10], %s1 in [0, 4]] func.return %0 : index } // CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0 + 3)> // CHECK-LABEL: func.func @fold_operands // CHECK-SAME: %[[ARG_0:.*]]: index) -// CHECK: xla_gpu.apply_indexing #[[$MAP]](%[[ARG_0]] in [0, 11)) +// CHECK: xla_gpu.apply_indexing #[[$MAP]](%[[ARG_0]] in [0, 10]) // ----- func.func @fold_operands_and_results(%arg0: index, %arg1: index) -> (index, index) { %0:2 = xla_gpu.apply_indexing affine_map<(d0, d1) -> (0, d1)> - (%arg0 in [0, 5), %arg1 in [0, 6)) + (%arg0 in [0, 4], %arg1 in [0, 5]) return %0#0, %0#1 : index, index } @@ -102,9 +102,9 @@ func.func @fold_operands_and_results(%arg0: index, %arg1: index) func.func @fold_sequence(%arg0: index, %arg1: index) -> index { %0 = xla_gpu.apply_indexing affine_map<(d0, d1) -> (d0 + d1)> - (%arg0 in [0, 6), %arg1 in [0, 5)) + (%arg0 in [0, 5], %arg1 in [0, 4]) %1 = xla_gpu.apply_indexing affine_map<(d0) -> (d0 mod 100 + 42)> - (%0 in [0, 10001)) + (%0 in [0, 10000]) func.return %1 : index } @@ -112,15 +112,15 @@ func.func @fold_sequence(%arg0: index, %arg1: index) -> index { // CHECK-LABEL: func.func @fold_sequence // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index) // CHECK-NEXT: xla_gpu.apply_indexing #[[$MAP]] -// CHECK-SAME: (%[[ARG0]] in [0, 6), %[[ARG1]] in [0, 5)) +// CHECK-SAME: (%[[ARG0]] in [0, 5], %[[ARG1]] in [0, 4]) // ----- func.func @fold_sequence_sym(%arg0: index, %arg1: index) -> index { %0 = xla_gpu.apply_indexing affine_map<(d0, d1) -> (d0 + d1)> - (%arg0 in [0, 6), %arg1 in [0, 5)) + (%arg0 in [0, 5], %arg1 in [0, 4]) %1 = xla_gpu.apply_indexing affine_map<()[s0] -> (s0 mod 100 + 42)> - [%0 in [0, 10001)] + [%0 in [0, 10000]] func.return %1 : index } @@ -128,15 +128,15 @@ func.func @fold_sequence_sym(%arg0: index, %arg1: index) -> index { // CHECK-LABEL: func.func @fold_sequence_sym // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index) // CHECK-NEXT: xla_gpu.apply_indexing #[[$MAP]] -// CHECK-SAME: (%[[ARG0]] in [0, 6), %[[ARG1]] in [0, 5)) +// CHECK-SAME: (%[[ARG0]] in [0, 5], %[[ARG1]] in [0, 4]) // ----- func.func @fold_sequence_shared_operands(%arg0: index, %arg1: index) -> index { %0 = xla_gpu.apply_indexing affine_map<(d0, d1) -> (d0 + d1)> - (%arg0 in [0, 6), %arg1 in [0, 5)) + (%arg0 in [0, 5], %arg1 in [0, 4]) %1 = xla_gpu.apply_indexing affine_map<(d0, d1) -> (d0 + d1)> - (%arg1 in [0, 5), %0 in [0, 10001)) + (%arg1 in [0, 4], %0 in [0, 10000]) func.return %1 : index } @@ -144,7 +144,7 @@ func.func @fold_sequence_shared_operands(%arg0: index, %arg1: index) -> index { // CHECK-LABEL: func.func @fold_sequence_shared_operands // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index) // CHECK-NEXT: xla_gpu.apply_indexing #[[$MAP]] -// CHECK-SAME: (%[[ARG1]] in [0, 5), %[[ARG0]] in [0, 6)) +// CHECK-SAME: (%[[ARG1]] in [0, 4], %[[ARG0]] in [0, 5]) // ----- diff --git a/xla/service/gpu/fusions/mlir/tests/invalid.mlir b/xla/service/gpu/fusions/mlir/tests/invalid.mlir index 55aa9c78512034..fbef7c049db487 100644 --- a/xla/service/gpu/fusions/mlir/tests/invalid.mlir +++ b/xla/service/gpu/fusions/mlir/tests/invalid.mlir @@ -3,6 +3,6 @@ #map0 = affine_map<(d0, d1)[s0] -> (d0, d1 + s0)> func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) { // expected-error @+1 {{operand, lower_bounds, upper_bounds count and affine map dimension and symbol count must match}} - %0:2 = xla_gpu.apply_indexing #map0 (%d0 in [0, 3)) + %0:2 = xla_gpu.apply_indexing #map0 (%d0 in [0, 2]) func.return %0#0, %0#1 : index, index } diff --git a/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir b/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir index 5b46947767211b..2125e6f4d70c8f 100644 --- a/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir +++ b/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir @@ -95,7 +95,7 @@ module { // CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr, // CHECK-SAME: %[[X:.*]]: index, %[[Y:.*]]: index // CHECK: %[[IDX:.*]] = xla_gpu.apply_indexing #[[$MAP]] -// CHECK-SAME: (%[[X]] in [0, 2), %[[Y]] in [0, 3)) +// CHECK-SAME: (%[[X]] in [0, 1], %[[Y]] in [0, 2]) // CHECK: %[[IDX_CAST:.*]] = arith.index_castui %[[IDX]] : index to i64 // CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %[[ARG0]][%[[IDX_CAST]]] // CHECK: llvm.load %[[PTR]] diff --git a/xla/service/gpu/fusions/mlir/tests/ops.mlir b/xla/service/gpu/fusions/mlir/tests/ops.mlir index da3ef936395b72..c7f15073b5e0ed 100644 --- a/xla/service/gpu/fusions/mlir/tests/ops.mlir +++ b/xla/service/gpu/fusions/mlir/tests/ops.mlir @@ -58,7 +58,7 @@ func.func @caller(%a: f32, %b: f32) -> f32 { #map0 = affine_map<(d0, d1)[s0] -> (d0, d1 + s0)> func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) { - %0:2 = xla_gpu.apply_indexing #map0 (%d0 in [0, 3), %d1 in [1, 4))[%s0 in [2, 5)] + %0:2 = xla_gpu.apply_indexing #map0 (%d0 in [0, 2], %d1 in [1, 3])[%s0 in [2, 4]] func.return %0#0, %0#1 : index, index } // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0] -> (d0, d1 + s0)> @@ -66,13 +66,13 @@ func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) // CHECK-LABEL: @apply_indexing // CHECK: (%[[d0:.*]]: index, %[[d1:.*]]: index, %[[s0:.*]]: index) // CHECK: xla_gpu.apply_indexing #[[$MAP0]] -// CHECK-SAME: (%[[d0]] in [0, 3), %[[d1]] in [1, 4))[%[[s0]] in [2, 5)] +// CHECK-SAME: (%[[d0]] in [0, 2], %[[d1]] in [1, 3])[%[[s0]] in [2, 4]] // ----- #map0 = affine_map<(d0, d1) -> (d0, d1)> func.func @apply_indexing_no_symbols(%d0: index, %d1: index) -> (index, index) { - %0:2 = xla_gpu.apply_indexing #map0 (%d0 in [0, 3), %d1 in [1, 4)) + %0:2 = xla_gpu.apply_indexing #map0 (%d0 in [0, 2], %d1 in [1, 3]) func.return %0#0, %0#1 : index, index } // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> @@ -80,17 +80,17 @@ func.func @apply_indexing_no_symbols(%d0: index, %d1: index) -> (index, index) { // CHECK-LABEL: @apply_indexing_no_symbols // CHECK: (%[[d0:.*]]: index, %[[d1:.*]]: index) // CHECK: xla_gpu.apply_indexing #[[$MAP0]] -// CHECK-SAME: (%[[d0]] in [0, 3), %[[d1]] in [1, 4)) +// CHECK-SAME: (%[[d0]] in [0, 2], %[[d1]] in [1, 3]) // ----- #map0 = affine_map<()[s0] -> (s0, s0)> func.func @apply_indexing_no_dims(%s0: index) -> (index, index) { - %0:2 = xla_gpu.apply_indexing #map0 [%s0 in [2, 5)] + %0:2 = xla_gpu.apply_indexing #map0 [%s0 in [2, 4]] func.return %0#0, %0#1 : index, index } // CHECK: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0, s0)> // CHECK-LABEL: @apply_indexing_no_dims // CHECK: (%[[s0:.*]]: index) -// CHECK: xla_gpu.apply_indexing #[[$MAP0]][%[[s0]] in [2, 5)] +// CHECK: xla_gpu.apply_indexing #[[$MAP0]][%[[s0]] in [2, 4]] diff --git a/xla/service/gpu/fusions/mlir/tests/optimize_loops.mlir b/xla/service/gpu/fusions/mlir/tests/optimize_loops.mlir index 1f173fa26b1d47..6f903f3ace4748 100644 --- a/xla/service/gpu/fusions/mlir/tests/optimize_loops.mlir +++ b/xla/service/gpu/fusions/mlir/tests/optimize_loops.mlir @@ -21,23 +21,23 @@ module { %1 = arith.cmpi eq, %0, %c0 : index %2 = arith.divui %thread_id_x, %c32 : index %3 = arith.cmpi ult, %thread_id_x, %c8 : index - %4 = xla_gpu.apply_indexing #map(%block_id_x in [0, 32)) - %5 = xla_gpu.apply_indexing #map1(%block_id_x in [0, 32)) + %4 = xla_gpu.apply_indexing #map(%block_id_x in [0, 31]) + %5 = xla_gpu.apply_indexing #map1(%block_id_x in [0, 31]) %extracted = tensor.extract %arg2[%4, %5] : tensor<4x8xf32> %6 = arith.mulf %extracted, %cst : f32 %7 = arith.addf %6, %cst : f32 %8 = math.rsqrt %7 : f32 %9:2 = scf.for %arg7 = %c0 to %c8 step %c1 iter_args(%arg8 = %arg6, %arg9 = %cst) -> (tensor<4x8x4096xf32>, f32) { - %18 = xla_gpu.apply_indexing #map2(%c0 in [0, 2), %thread_id_x in [0, 256))[%arg7 in [0, 8)] + %18 = xla_gpu.apply_indexing #map2(%c0 in [0, 1], %thread_id_x in [0, 255])[%arg7 in [0, 7]] %19 = vector.transfer_read %arg1[%18], %cst_1 {in_bounds = [true]} : tensor<4096xbf16>, vector<2xbf16> - %20 = xla_gpu.apply_indexing #map2(%c0 in [0, 2), %thread_id_x in [0, 256))[%arg7 in [0, 8)] + %20 = xla_gpu.apply_indexing #map2(%c0 in [0, 1], %thread_id_x in [0, 255])[%arg7 in [0, 7]] %21 = vector.transfer_read %arg3[%20], %cst_1 {in_bounds = [true]} : tensor<4096xbf16>, vector<2xbf16> - %22 = xla_gpu.apply_indexing #map2(%c0 in [0, 2), %thread_id_x in [0, 256))[%arg7 in [0, 8)] + %22 = xla_gpu.apply_indexing #map2(%c0 in [0, 1], %thread_id_x in [0, 255])[%arg7 in [0, 7]] %23 = vector.transfer_read %arg4[%4, %5, %22], %cst_1 {in_bounds = [true]} : tensor<4x8x4096xbf16>, vector<2xbf16> - %24 = xla_gpu.apply_indexing #map2(%c0 in [0, 2), %thread_id_x in [0, 256))[%arg7 in [0, 8)] + %24 = xla_gpu.apply_indexing #map2(%c0 in [0, 1], %thread_id_x in [0, 255])[%arg7 in [0, 7]] %25 = vector.transfer_read %arg0[%4, %5, %24], %cst {in_bounds = [true]} : tensor<4x8x4096xf32>, vector<2xf32> %26:2 = scf.for %arg10 = %c0 to %c2 step %c1 iter_args(%arg11 = %arg8, %arg12 = %arg9) -> (tensor<4x8x4096xf32>, f32) { - %27 = xla_gpu.apply_indexing #map2(%arg10 in [0, 2), %thread_id_x in [0, 256))[%arg7 in [0, 8)] + %27 = xla_gpu.apply_indexing #map2(%arg10 in [0, 1], %thread_id_x in [0, 255])[%arg7 in [0, 7]] %28 = vector.extract %25[%arg10] : f32 from vector<2xf32> %29 = vector.extract %23[%arg10] : bf16 from vector<2xbf16> %30 = arith.extf %29 : bf16 to f32 @@ -151,7 +151,7 @@ module { %cst = arith.constant dense<[0.0, 0.0]> : vector<2xf32> %cst0 = arith.constant 0.0 : f32 %ret = scf.for %i = %c0 to %c17 step %c1 iter_args (%iter = %cst) -> (vector<2xf32>) { - %base = xla_gpu.apply_indexing affine_map<(d0) -> (d0 * 2)>(%i in [0, 16)) + %base = xla_gpu.apply_indexing affine_map<(d0) -> (d0 * 2)>(%i in [0, 15]) %val = vector.transfer_read %arg[%base], %cst0 : tensor<34xf32>, vector<2xf32> %log = math.log %val : vector<2xf32> %add = arith.addf %log, %iter : vector<2xf32> diff --git a/xla/service/gpu/fusions/mlir/tests/simplify_affine.mlir b/xla/service/gpu/fusions/mlir/tests/simplify_affine.mlir index acd6cc097f9e63..ec1a726da9db13 100644 --- a/xla/service/gpu/fusions/mlir/tests/simplify_affine.mlir +++ b/xla/service/gpu/fusions/mlir/tests/simplify_affine.mlir @@ -63,7 +63,7 @@ func.func @op_and_for_ranges(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.pt %1 = gpu.block_id x scf.for %i = %c0 to %c4 step %c1 { %2 = xla_gpu.apply_indexing affine_map<()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2 + (s1 floordiv 128) + (s2 floordiv 4))> - [%1 in [0, 3072), %0 in [0, 128), %i in [0, 4)] + [%1 in [0, 3071], %0 in [0, 127], %i in [0, 3]] %3 = arith.index_castui %2 : index to i64 %4 = llvm.getelementptr %arg0[%3] : (!llvm.ptr, i64) -> !llvm.ptr, f32 %5 = llvm.load %4 invariant : !llvm.ptr -> f32 @@ -92,7 +92,7 @@ func.func @op_and_for_ranges(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.pt func.func @arg_ranges(%arg0: index, %arg1: index) -> index { %0 = xla_gpu.apply_indexing affine_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100)> - [%arg0 in [0, 43), %arg1 in [0, 1001)] + [%arg0 in [0, 42], %arg1 in [0, 1000]] return %0 : index } @@ -106,7 +106,7 @@ func.func @arg_ranges(%arg0: index, %arg1: index) -> index { func.func @cant_lower(%arg0: index, %arg1: index) -> (index, index) { %0:2 = xla_gpu.apply_indexing affine_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100, s0 + s1)> - [%arg0 in [-10, 43), %arg1 in [0, 1001)] + [%arg0 in [-10, 42], %arg1 in [0, 1000]] return %0#0, %0#1 : index, index } @@ -124,7 +124,7 @@ func.func @order_summands(%arg1: index) { scf.for %arg3 = %c0 to %c4 step %c1 { %0 = xla_gpu.apply_indexing affine_map<()[s0, s1, s2] -> ((s0 + s1) floordiv 3 + s0 * 512 + s1 * 4 + s2 * 10)> - [%arg2 in [0, 4), %arg1 in [0, 4), %arg3 in [0, 4)] + [%arg2 in [0, 3], %arg1 in [0, 3], %arg3 in [0, 3]] "dummy.op"(%0) : (index) -> () } } diff --git a/xla/service/gpu/fusions/mlir/tests/vectorize_loads_stores.mlir b/xla/service/gpu/fusions/mlir/tests/vectorize_loads_stores.mlir index fb944a5dcb923d..1141d1581505ea 100644 --- a/xla/service/gpu/fusions/mlir/tests/vectorize_loads_stores.mlir +++ b/xla/service/gpu/fusions/mlir/tests/vectorize_loads_stores.mlir @@ -10,7 +10,7 @@ module { %cst = arith.constant 0.0 : f32 %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 { %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i in [0, 64))[%j in [0, 2)] + %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]] %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> %added = arith.addf %iter1, %extracted : f32 scf.yield %added : f32 @@ -29,7 +29,7 @@ module { // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index // CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C64]] step %[[C1]] iter_args(%[[ITER:.*]] = -// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #map(%[[I]] in [0, 64)) +// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #map(%[[I]] in [0, 63]) // CHECK-NEXT: %[[V:.*]] = vector.transfer_read %[[ARG0]][%[[BASE]]] // CHECK-NEXT: scf.for %[[J:.*]] = %[[C0]] // CHECK-NEXT: vector.extract %[[V]][%[[J]]] @@ -76,7 +76,7 @@ module { %cst = arith.constant 0.0 : f32 %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i in [0, 64))[%j in [0, 2)] + %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]] %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> %added = arith.addf %iter1, %extracted : f32 scf.yield %added : f32 @@ -102,7 +102,7 @@ module { %cst = arith.constant 0.0 : f32 %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i in [0, 64))[%j in [0, 2)] + %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]] %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> %added = arith.addf %iter1, %extracted : f32 scf.yield %added : f32 @@ -152,7 +152,7 @@ module { %cst = arith.constant 0.0 : f32 %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i in [0, 64))[%j in [0, 2)] + %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]] %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> %added = arith.addf %iter1, %extracted : f32 scf.yield %added : f32 @@ -312,7 +312,7 @@ module { %extracted1 = tensor.extract %arg2[%arg4] : tensor<32xf32> %0:2 = scf.for %i = %c0 to %c8 step %c1 iter_args(%iter0 = %arg3, %iter1 = %cst) -> (tensor<32x4096xf32>, f32) { %1:2 = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter2 = %iter0, %iter3 = %iter1) -> (tensor<32x4096xf32>, f32) { - %2 = xla_gpu.apply_indexing #map(%j in [0, 2), %arg4 in [0, 256))[%i in [0, 8)] + %2 = xla_gpu.apply_indexing #map(%j in [0, 1], %arg4 in [0, 255])[%i in [0, 7]] %extracted2 = tensor.extract %arg0[%i, %2] : tensor<32x4096xf32> %extracted3 = tensor.extract %arg1[%2] : tensor<4096xbf16> %3 = arith.extf %extracted3 : bf16 to f32 @@ -333,7 +333,7 @@ module { // CHECK-SAME: (%[[ARG0:.*]]: tensor{{.*}}, %[[ARG1:.*]]: tensor{{.*}}, %[[ARG2:.*]]: tensor{{.*}}, %[[ARG3:.*]]: tensor{{.*}}, %[[ARG4:.*]]: index) // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: scf.for %[[I:.*]] = %[[C0]] -// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[ARG4]] in [0, 256))[%[[I]] in [0, 8)] +// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[ARG4]] in [0, 255])[%[[I]] in [0, 7]] // CHECK: %[[READ1:.*]] = vector.transfer_read %[[ARG1]][%[[BASE]]] // CHECK: %[[READ2:.*]] = vector.transfer_read %[[ARG0]][%[[I]], %[[BASE]]] // CHECK: %[[INNER:.*]]:2 = scf.for %[[J:.*]] = %[[C0]] {{.*}} iter_args(%[[F:.*]] = {{.*}}, %[[V:.*]] = {{.*}}) -> (f32, vector<2xf32>) @@ -360,7 +360,7 @@ module { %cst = arith.constant 0.0 : f32 %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i in [0, 64))[%j in [0, 2)] + %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]] %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> %added = arith.addf %iter1, %extracted : f32 scf.yield %added : f32 @@ -390,7 +390,7 @@ module { %cst = arith.constant 0.0 : f32 %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i in [0, 64))[%j in [0, 2)] + %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]] %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> %added = arith.addf %iter1, %extracted : f32 scf.yield %added : f32 diff --git a/xla/service/gpu/fusions/reduction_mlir_test.cc b/xla/service/gpu/fusions/reduction_mlir_test.cc index 214e9b582cb123..c583088a245beb 100644 --- a/xla/service/gpu/fusions/reduction_mlir_test.cc +++ b/xla/service/gpu/fusions/reduction_mlir_test.cc @@ -408,10 +408,10 @@ TEST_F(MlirRowReductionTest, NonPowerOfTwoRowReduction) { // CHECK: %[[FULL_TILES:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C4]] step %[[C1]] // CHECK-NEXT: scf.for %[[J:.*]] = %[[C0]] to %[[C2]] step %[[C1]] // CHECK-NOT: scf.if - // CHECK: xla_gpu.apply_indexing #[[MAP1]](%[[J]] in [0, 2), %thread_id_x in [0, 256))[%[[I]] in [0, 5)] + // CHECK: xla_gpu.apply_indexing #[[MAP1]](%[[J]] in [0, 1], %thread_id_x in [0, 255])[%[[I]] in [0, 4]] // CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%{{.*}} = %[[FULL_TILES]]) // CHECK: scf.if - // CHECK: xla_gpu.apply_indexing #[[MAP2]](%[[J]] in [0, 2), %thread_id_x in [0, 256)) + // CHECK: xla_gpu.apply_indexing #[[MAP2]](%[[J]] in [0, 1], %thread_id_x in [0, 255]) )")); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); } @@ -691,16 +691,16 @@ TEST_F(MlirColumnReductionTest, ColumnReduction) { (d3 mod 11) * 32 + d0 mod 32 ) domain: - d0 in [0, 1024) - d1 in [0, 1) - d2 in [0, 1) - d3 in [0, 143) - d4 in [0, 1) - d5 in [0, 1) - s0 in [0, 33) - s1 in [0, 1) - (d3 mod 11) * 32 + d0 mod 32 in [0, 321) - d0 floordiv 32 + s0 * 32 in [0, 1051) + d0 in [0, 1023] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 142] + d4 in [0, 0] + d5 in [0, 0] + s0 in [0, 32] + s1 in [0, 0] + (d3 mod 11) * 32 + d0 mod 32 in [0, 320] + d0 floordiv 32 + s0 * 32 in [0, 1050] )")); EXPECT_THAT( fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), @@ -709,15 +709,15 @@ TEST_F(MlirColumnReductionTest, ColumnReduction) { d3 floordiv 11, (d3 mod 11) * 32 + d0 floordiv 32 ) domain: - d0 in [0, 993) - d1 in [0, 1) - d2 in [0, 1) - d3 in [0, 143) - d4 in [0, 1) - d5 in [0, 1) - s0 in [0, 1) - (d3 mod 11) * 32 + d0 floordiv 32 in [0, 321) - d0 mod 32 in [0, 1) + d0 in [0, 992] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 142] + d4 in [0, 0] + d5 in [0, 0] + s0 in [0, 0] + (d3 mod 11) * 32 + d0 floordiv 32 in [0, 320] + d0 mod 32 in [0, 0] )")); TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( // CHECK: xla_gpu.pure_call @Add_add diff --git a/xla/service/gpu/fusions/reduction_test.cc b/xla/service/gpu/fusions/reduction_test.cc index 520f8cb9d7f02f..81649a735e8329 100644 --- a/xla/service/gpu/fusions/reduction_test.cc +++ b/xla/service/gpu/fusions/reduction_test.cc @@ -81,16 +81,16 @@ TEST_F(ReductionTest, ThreadIndexingRowReduction) { (d0 mod 32) * 2 + s2 * 64 + s3 ) domain: - d0 in [0, 256) - d1 in [0, 1) - d2 in [0, 1) - d3 in [0, 800) - d4 in [0, 1) - d5 in [0, 1) - s0 in [0, 1) - s1 in [0, 1) - s2 in [0, 8) - s3 in [0, 2) + d0 in [0, 255] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 799] + d4 in [0, 0] + d5 in [0, 0] + s0 in [0, 0] + s1 in [0, 0] + s2 in [0, 7] + s3 in [0, 1] )")); EXPECT_THAT( fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), @@ -100,13 +100,13 @@ TEST_F(ReductionTest, ThreadIndexingRowReduction) { (d3 mod 8) * 8 + d0 floordiv 32 ) domain: - d0 in [0, 225) - d1 in [0, 1) - d2 in [0, 1) - d3 in [0, 800) - d4 in [0, 1) - d5 in [0, 1) - d0 mod 32 in [0, 1) + d0 in [0, 224] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 799] + d4 in [0, 0] + d5 in [0, 0] + d0 mod 32 in [0, 0] )")); } diff --git a/xla/service/gpu/fusions/scatter_mlir_test.cc b/xla/service/gpu/fusions/scatter_mlir_test.cc index a112629ebee9e5..869d2335001825 100644 --- a/xla/service/gpu/fusions/scatter_mlir_test.cc +++ b/xla/service/gpu/fusions/scatter_mlir_test.cc @@ -87,15 +87,15 @@ TEST_F(MlirScatterFusionTest, ThreadIdIndexing) { (bl_x * 128 + th_x) mod 20 ) domain: - th_x in [0, 128) - th_y in [0, 1) - th_z in [0, 1) - bl_x in [0, 66) - bl_y in [0, 1) - bl_z in [0, 1) - chunk_id in [0, 1) - unroll_id in [0, 1) - bl_x * 128 + th_x in [0, 8400) + th_x in [0, 127] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 65] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] + bl_x * 128 + th_x in [0, 8399] )"; EXPECT_THAT( fusion @@ -126,16 +126,16 @@ TEST_F(MlirScatterFusionTest, ThreadIdIndexing) { (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id, index_id] -> ((bl_x * 128 + th_x) floordiv 200, 0) domain: - th_x in [0, 128) - th_y in [0, 1) - th_z in [0, 1) - bl_x in [0, 66) - bl_y in [0, 1) - bl_z in [0, 1) - chunk_id in [0, 1) - unroll_id in [0, 1) - index_id in [0, 1) - bl_x * 128 + th_x in [0, 8400) + th_x in [0, 127] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 65] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] + index_id in [0, 0] + bl_x * 128 + th_x in [0, 8399] )"; EXPECT_THAT( fusion diff --git a/xla/service/gpu/fusions/scatter_test.cc b/xla/service/gpu/fusions/scatter_test.cc index 3dee8912fc06a3..284d308ad5a190 100644 --- a/xla/service/gpu/fusions/scatter_test.cc +++ b/xla/service/gpu/fusions/scatter_test.cc @@ -155,15 +155,15 @@ TEST_F(ScatterFusionTest, ThreadIdIndexing) { (bl_x * 128 + th_x) mod 20 ) domain: - th_x in [0, 128) - th_y in [0, 1) - th_z in [0, 1) - bl_x in [0, 66) - bl_y in [0, 1) - bl_z in [0, 1) - chunk_id in [0, 1) - unroll_id in [0, 1) - bl_x * 128 + th_x in [0, 8400) + th_x in [0, 127] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 65] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] + bl_x * 128 + th_x in [0, 8399] )"; EXPECT_THAT( fusion @@ -194,16 +194,16 @@ TEST_F(ScatterFusionTest, ThreadIdIndexing) { (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id, index_id] -> ((bl_x * 128 + th_x) floordiv 200, 0) domain: - th_x in [0, 128) - th_y in [0, 1) - th_z in [0, 1) - bl_x in [0, 66) - bl_y in [0, 1) - bl_z in [0, 1) - chunk_id in [0, 1) - unroll_id in [0, 1) - index_id in [0, 1) - bl_x * 128 + th_x in [0, 8400) + th_x in [0, 127] + th_y in [0, 0] + th_z in [0, 0] + bl_x in [0, 65] + bl_y in [0, 0] + bl_z in [0, 0] + chunk_id in [0, 0] + unroll_id in [0, 0] + index_id in [0, 0] + bl_x * 128 + th_x in [0, 8399] )"; EXPECT_THAT( fusion diff --git a/xla/service/gpu/fusions/transpose_mlir_test.cc b/xla/service/gpu/fusions/transpose_mlir_test.cc index dd06b695fdfdc8..1861672a82279d 100644 --- a/xla/service/gpu/fusions/transpose_mlir_test.cc +++ b/xla/service/gpu/fusions/transpose_mlir_test.cc @@ -56,15 +56,15 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexing021) { (d3 mod 2) * 32 + d0 mod 32 ) domain: - d0 in [0, 128) - d1 in [0, 1) - d2 in [0, 1) - d3 in [0, 200) - d4 in [0, 1) - d5 in [0, 1) - - s0 in [0, 8) - s1 in [0, 1) + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 199] + d4 in [0, 0] + d5 in [0, 0] + + s0 in [0, 7] + s1 in [0, 0] )")); EXPECT_THAT( fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), @@ -75,15 +75,15 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexing021) { d0 mod 32 ) domain: - d0 in [0, 128) - d1 in [0, 1) - d2 in [0, 1) - d3 in [0, 200) - d4 in [0, 1) - d5 in [0, 1) - - s0 in [0, 8) - s1 in [0, 1) + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 199] + d4 in [0, 0] + d5 in [0, 0] + + s0 in [0, 7] + s1 in [0, 0] )")); } @@ -113,15 +113,15 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexing201) { d0 mod 32 ) domain: - d0 in [0, 128) - d1 in [0, 1) - d2 in [0, 1) - d3 in [0, 200) - d4 in [0, 1) - d5 in [0, 1) - - s0 in [0, 8) - s1 in [0, 1) + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 199] + d4 in [0, 0] + d5 in [0, 0] + + s0 in [0, 7] + s1 in [0, 0] )")); EXPECT_THAT( fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), @@ -132,15 +132,15 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexing201) { (d3 mod 2) * 32 + d0 mod 32 ) domain: - d0 in [0, 128) - d1 in [0, 1) - d2 in [0, 1) - d3 in [0, 200) - d4 in [0, 1) - d5 in [0, 1) - - s0 in [0, 8) - s1 in [0, 1) + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 199] + d4 in [0, 0] + d5 in [0, 0] + + s0 in [0, 7] + s1 in [0, 0] )")); } @@ -170,14 +170,14 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexingVectorized021) { (d0 mod 32) * 2 + s1 ) domain: - d0 in [0, 128) - d1 in [0, 1) - d2 in [0, 1) - d3 in [0, 8192) - d4 in [0, 1) - d5 in [0, 1) - s0 in [0, 16) - s1 in [0, 2) + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 8191] + d4 in [0, 0] + d5 in [0, 0] + s0 in [0, 15] + s1 in [0, 1] )")); EXPECT_THAT( fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), @@ -188,14 +188,14 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexingVectorized021) { (d0 mod 32) * 2 + s1 ) domain: - d0 in [0, 128) - d1 in [0, 1) - d2 in [0, 1) - d3 in [0, 8192) - d4 in [0, 1) - d5 in [0, 1) - s0 in [0, 16) - s1 in [0, 2) + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 8191] + d4 in [0, 0] + d5 in [0, 0] + s0 in [0, 15] + s1 in [0, 1] )")); } @@ -224,14 +224,14 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexingVectorized210) { (d0 mod 32) * 2 + (d3 mod 128) * 64 + s1 ) domain: - d0 in [0, 128) - d1 in [0, 1) - d2 in [0, 1) - d3 in [0, 8192) - d4 in [0, 1) - d5 in [0, 1) - s0 in [0, 16) - s1 in [0, 2) + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 8191] + d4 in [0, 0] + d5 in [0, 0] + s0 in [0, 15] + s1 in [0, 1] )")); EXPECT_THAT( fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), @@ -242,14 +242,14 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexingVectorized210) { (d0 mod 32) * 2 + s1 ) domain: - d0 in [0, 128) - d1 in [0, 1) - d2 in [0, 1) - d3 in [0, 8192) - d4 in [0, 1) - d5 in [0, 1) - s0 in [0, 16) - s1 in [0, 2) + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 8191] + d4 in [0, 0] + d5 in [0, 0] + s0 in [0, 15] + s1 in [0, 1] )")); } @@ -621,15 +621,15 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexingSideOutput) { d0 floordiv 32 + s0 * 4 ) domain: - d0 in [0, 128) - d1 in [0, 1) - d2 in [0, 1) - d3 in [0, 200) - d4 in [0, 1) - d5 in [0, 1) - - s0 in [0, 8) - s1 in [0, 1) + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 199] + d4 in [0, 0] + d5 in [0, 0] + + s0 in [0, 7] + s1 in [0, 0] )")); EXPECT_THAT( fusion.ComputeThreadIdToOutputIndexing(1, &mlir_context)->ToString(), @@ -640,15 +640,15 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexingSideOutput) { (d3 mod 2) * 32 + d0 mod 32 ) domain: - d0 in [0, 128) - d1 in [0, 1) - d2 in [0, 1) - d3 in [0, 200) - d4 in [0, 1) - d5 in [0, 1) - - s0 in [0, 8) - s1 in [0, 1) + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 199] + d4 in [0, 0] + d5 in [0, 0] + + s0 in [0, 7] + s1 in [0, 0] )")); } diff --git a/xla/service/gpu/fusions/transpose_test.cc b/xla/service/gpu/fusions/transpose_test.cc index f0998eec4ccb01..f94246916406c9 100644 --- a/xla/service/gpu/fusions/transpose_test.cc +++ b/xla/service/gpu/fusions/transpose_test.cc @@ -85,16 +85,16 @@ TEST_F(TransposeTest, ThreadIndexing021) { (d3 mod 2) * 32 + d0 mod 32 ) domain: - d0 in [0, 128) - d1 in [0, 1) - d2 in [0, 1) - d3 in [0, 200) - d4 in [0, 1) - d5 in [0, 1) - - s0 in [0, 1) - s1 in [0, 8) - s2 in [0, 1) + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 199] + d4 in [0, 0] + d5 in [0, 0] + + s0 in [0, 0] + s1 in [0, 7] + s2 in [0, 0] )")); EXPECT_THAT( fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), @@ -105,16 +105,16 @@ TEST_F(TransposeTest, ThreadIndexing021) { d0 mod 32 ) domain: - d0 in [0, 128) - d1 in [0, 1) - d2 in [0, 1) - d3 in [0, 200) - d4 in [0, 1) - d5 in [0, 1) - - s0 in [0, 1) - s1 in [0, 8) - s2 in [0, 1) + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 199] + d4 in [0, 0] + d5 in [0, 0] + + s0 in [0, 0] + s1 in [0, 7] + s2 in [0, 0] )")); } @@ -147,16 +147,16 @@ TEST_F(TransposeTest, ThreadIndexing201) { d0 mod 32 ) domain: - d0 in [0, 128) - d1 in [0, 1) - d2 in [0, 1) - d3 in [0, 200) - d4 in [0, 1) - d5 in [0, 1) - - s0 in [0, 1) - s1 in [0, 8) - s2 in [0, 1) + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 199] + d4 in [0, 0] + d5 in [0, 0] + + s0 in [0, 0] + s1 in [0, 7] + s2 in [0, 0] )")); EXPECT_THAT( fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), @@ -167,16 +167,16 @@ TEST_F(TransposeTest, ThreadIndexing201) { (d3 mod 2) * 32 + d0 mod 32 ) domain: - d0 in [0, 128) - d1 in [0, 1) - d2 in [0, 1) - d3 in [0, 200) - d4 in [0, 1) - d5 in [0, 1) - - s0 in [0, 1) - s1 in [0, 8) - s2 in [0, 1) + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 199] + d4 in [0, 0] + d5 in [0, 0] + + s0 in [0, 0] + s1 in [0, 7] + s2 in [0, 0] )")); } @@ -212,16 +212,16 @@ TEST_F(TransposeTest, ThreadIndexingPartialBlock) { d0 mod 4 ) domain: - d0 in [0, 128) - d1 in [0, 1) - d2 in [0, 1) - d3 in [0, 2) - d4 in [0, 1) - d5 in [0, 1) - s0 in [0, 6) - s1 in [0, 1) - s2 in [0, 1) - d0 mod 32 in [0, 24) + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 1] + d4 in [0, 0] + d5 in [0, 0] + s0 in [0, 5] + s1 in [0, 0] + s2 in [0, 0] + d0 mod 32 in [0, 23] )")); EXPECT_THAT( fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), @@ -233,16 +233,16 @@ TEST_F(TransposeTest, ThreadIndexingPartialBlock) { d0 mod 32 ) domain: - d0 in [0, 128) - d1 in [0, 1) - d2 in [0, 1) - d3 in [0, 2) - d4 in [0, 1) - d5 in [0, 1) - s0 in [0, 6) - s1 in [0, 1) - s2 in [0, 1) - d0 mod 32 in [0, 24) + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 1] + d4 in [0, 0] + d5 in [0, 0] + s0 in [0, 5] + s1 in [0, 0] + s2 in [0, 0] + d0 mod 32 in [0, 23] )")); } @@ -308,16 +308,16 @@ TEST_F(TransposeTest, ThreadIndexingSideOutput) { d0 floordiv 32 + s1 * 4 ) domain: - d0 in [0, 128) - d1 in [0, 1) - d2 in [0, 1) - d3 in [0, 200) - d4 in [0, 1) - d5 in [0, 1) - - s0 in [0, 1) - s1 in [0, 8) - s2 in [0, 1) + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 199] + d4 in [0, 0] + d5 in [0, 0] + + s0 in [0, 0] + s1 in [0, 7] + s2 in [0, 0] )")); EXPECT_THAT( fusion->ComputeThreadIdToOutputIndexing(1, &mlir_context)->ToString(), @@ -328,16 +328,16 @@ TEST_F(TransposeTest, ThreadIndexingSideOutput) { (d3 mod 2) * 32 + d0 mod 32 ) domain: - d0 in [0, 128) - d1 in [0, 1) - d2 in [0, 1) - d3 in [0, 200) - d4 in [0, 1) - d5 in [0, 1) - - s0 in [0, 1) - s1 in [0, 8) - s2 in [0, 1) + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 199] + d4 in [0, 0] + d5 in [0, 0] + + s0 in [0, 0] + s1 in [0, 7] + s2 in [0, 0] )")); } diff --git a/xla/service/gpu/model/coalescing_analysis.cc b/xla/service/gpu/model/coalescing_analysis.cc index 2f707e6b589e2b..9e7a685d590a29 100644 --- a/xla/service/gpu/model/coalescing_analysis.cc +++ b/xla/service/gpu/model/coalescing_analysis.cc @@ -327,9 +327,9 @@ std::optional Partition(AffineExpr expr) { // For example, for the following indexing map: // (d0)[s0] -> (d0 + s0) // domain: -// d0 in [0, 4) +// d0 in [0, 3] // s0 in [0, 1, 2] -// s0 mod 2 in [0, 1) +// s0 mod 2 in [0, 0] // The function will compute the following indices [0, 2, 1, 3, 2, 4, 3, 5]. void FindAllIndices(AffineExpr expr, int dim_id, int symbol_id, const std::vector& dimension_ranges, @@ -365,8 +365,8 @@ void FindAllIndices(AffineExpr expr, int dim_id, int symbol_id, // Computes contiguous intervals of accessed elements. // For example, for an indexing map // (thread_x) -> (thread_x * 4 + s0 + (thread_x floordiv 16) * 1984) -// d0 in [0, 32) -// s0 in [0, 4) +// d0 in [0, 31] +// s0 in [0, 3] // The intervals are [0, 63] and [2047, 2111]. std::vector FindIntervals( AffineExpr expr, const std::vector& dimension_ranges, diff --git a/xla/service/gpu/model/indexing_analysis_test.cc b/xla/service/gpu/model/indexing_analysis_test.cc index 122461fd255cbc..30fd8056697498 100644 --- a/xla/service/gpu/model/indexing_analysis_test.cc +++ b/xla/service/gpu/model/indexing_analysis_test.cc @@ -65,14 +65,14 @@ TEST_F(IndexingAnalysisTest, FuseProducerConsumerOutputToInputIndexing) { UnorderedElementsAre(Pair(parameter, ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 1000) - d1 in [0, 1000) + d0 in [0, 999] + d1 in [0, 999] )"))), Pair(transpose, ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 1000) - d1 in [0, 1000) + d0 in [0, 999] + d1 in [0, 999] )"))))); } @@ -97,26 +97,26 @@ TEST_F(IndexingAnalysisTest, ComputeGroupedOutputToInputIndexing) { Pair(root, ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 1000) - d1 in [0, 1000) + d0 in [0, 999] + d1 in [0, 999] )"))), Pair(transpose, ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 1000) - d1 in [0, 1000) + d0 in [0, 999] + d1 in [0, 999] )"))), Pair(parameter, UnorderedElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 1000) - d1 in [0, 1000) + d0 in [0, 999] + d1 in [0, 999] )"), MatchIndexingMap(R"( (d0, d1) -> (d1, d0) domain: - d0 in [0, 1000) - d1 in [0, 1000) + d0 in [0, 999] + d1 in [0, 999] )"))))); } @@ -155,29 +155,29 @@ TEST_F(IndexingAnalysisTest, Pair(root, ElementsAre(MatchIndexingMap(R"( (d0) -> (d0) domain: - d0 in [0, 32) + d0 in [0, 31] )"))), Pair(root->operand(0), ElementsAre(MatchIndexingMap(R"( (d0)[s0] -> (d0, s0) domain: - d0 in [0, 32) - s0 in [0, 40) + d0 in [0, 31] + s0 in [0, 39] )"))), Pair(root->operand(1), ElementsAre(MatchIndexingMap(R"( (d0)[s0] -> (d0, s0) domain: - d0 in [0, 32) - s0 in [0, 40) + d0 in [0, 31] + s0 in [0, 39] )"))), Pair(root->operand(2), ElementsAre(MatchIndexingMap(R"( (d0) -> () domain: - d0 in [0, 32) + d0 in [0, 31] )"))), Pair(root->operand(3), ElementsAre(MatchIndexingMap(R"( (d0) -> () domain: - d0 in [0, 32) + d0 in [0, 31] )"))))); } @@ -206,8 +206,8 @@ TEST_F(IndexingAnalysisTest, ComputeGroupedOutputToInputIndexing_SingleOp) { parameter, ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 1000) - d1 in [0, 1000) + d0 in [0, 999] + d1 in [0, 999] )"))))); } @@ -248,18 +248,18 @@ TEST_F(IndexingAnalysisTest, Pair(&bcast.instruction(), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3) -> (d0, d1, d2, d3) domain: - d0 in [0, 15) - d1 in [0, 32) - d2 in [0, 20) - d3 in [0, 64) + d0 in [0, 14] + d1 in [0, 31] + d2 in [0, 19] + d3 in [0, 63] )"))), Pair(¶meter_0.instruction(), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3) -> (d0, d2) domain: - d0 in [0, 15) - d1 in [0, 32) - d2 in [0, 20) - d3 in [0, 64) + d0 in [0, 14] + d1 in [0, 31] + d2 in [0, 19] + d3 in [0, 63] )"))))); } @@ -277,9 +277,9 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestOutputPermutation) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d1, d2, d0) domain: - d0 in [0, 30) - d1 in [0, 10) - d2 in [0, 20) + d0 in [0, 29] + d1 in [0, 9] + d2 in [0, 19] )")))); auto output_indexing = GetInputToOutputIndexing(root, /*input_id=*/0, @@ -288,9 +288,9 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestOutputPermutation) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d2, d0, d1) domain: - d0 in [0, 10) - d1 in [0, 20) - d2 in [0, 30) + d0 in [0, 9] + d1 in [0, 19] + d2 in [0, 29] )")))); } @@ -351,9 +351,9 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestInputPermutation) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d2, d0, d1) domain: - d0 in [0, 10) - d1 in [0, 20) - d2 in [0, 30) + d0 in [0, 9] + d1 in [0, 19] + d2 in [0, 29] )")))); auto output_indexing = GetInputToOutputIndexing(root, /*input_id=*/0, @@ -362,9 +362,9 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestInputPermutation) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d1, d2, d0) domain: - d0 in [0, 30) - d1 in [0, 10) - d2 in [0, 20) + d0 in [0, 29] + d1 in [0, 9] + d2 in [0, 19] )")))); } @@ -382,9 +382,9 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestInputAndOutputPermutation) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1, d2) domain: - d0 in [0, 30) - d1 in [0, 10) - d2 in [0, 20) + d0 in [0, 29] + d1 in [0, 9] + d2 in [0, 19] )")))); auto output_indexing = GetInputToOutputIndexing(root, /*input_id=*/0, @@ -393,9 +393,9 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestInputAndOutputPermutation) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1, d2) domain: - d0 in [0, 30) - d1 in [0, 10) - d2 in [0, 20) + d0 in [0, 29] + d1 in [0, 9] + d2 in [0, 19] )")))); } @@ -413,14 +413,14 @@ TEST_F(IndexingAnalysisTest, ElementwiseOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 10) - d1 in [0, 20) + d0 in [0, 9] + d1 in [0, 19] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 10) - d1 in [0, 20) + d0 in [0, 9] + d1 in [0, 19] )")))); auto output_indexing_0 = GetInputToOutputIndexing(root, /*input_id=*/0); @@ -428,8 +428,8 @@ TEST_F(IndexingAnalysisTest, ElementwiseOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 10) - d1 in [0, 20) + d0 in [0, 9] + d1 in [0, 19] )")))); auto output_indexing_1 = GetInputToOutputIndexing(root, /*input_id=*/1); @@ -437,8 +437,8 @@ TEST_F(IndexingAnalysisTest, ElementwiseOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 10) - d1 in [0, 20) + d0 in [0, 9] + d1 in [0, 19] )")))); } @@ -461,14 +461,14 @@ TEST_F(IndexingAnalysisTest, Map) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 10) - d1 in [0, 20) + d0 in [0, 9] + d1 in [0, 19] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 10) - d1 in [0, 20) + d0 in [0, 9] + d1 in [0, 19] )")))); auto output_indexing_0 = GetInputToOutputIndexing(root, /*input_id=*/0); @@ -476,8 +476,8 @@ TEST_F(IndexingAnalysisTest, Map) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 10) - d1 in [0, 20) + d0 in [0, 9] + d1 in [0, 19] )")))); auto output_indexing_1 = GetInputToOutputIndexing(root, /*input_id=*/1); @@ -485,8 +485,8 @@ TEST_F(IndexingAnalysisTest, Map) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 10) - d1 in [0, 20) + d0 in [0, 9] + d1 in [0, 19] )")))); } @@ -502,9 +502,9 @@ TEST_F(IndexingAnalysisTest, BitcastIsReshape) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1 * 4 + d2) domain: - d0 in [0, 4) - d1 in [0, 8) - d2 in [0, 4) + d0 in [0, 3] + d1 in [0, 7] + d2 in [0, 3] )")))); } @@ -520,10 +520,10 @@ TEST_F(IndexingAnalysisTest, BitcastIsTranspose) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3) -> (d0, d3, d1, d2) domain: - d0 in [0, 3) - d1 in [0, 6) - d2 in [0, 128) - d3 in [0, 12288) + d0 in [0, 2] + d1 in [0, 5] + d2 in [0, 127] + d3 in [0, 12287] )")))); } @@ -540,17 +540,17 @@ TEST_F(IndexingAnalysisTest, BitcastIsTransposeReshapeTranspose) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d1, d0 floordiv 3, d0 mod 3) domain: - d0 in [0, 51) - d1 in [0, 16) + d0 in [0, 50] + d1 in [0, 15] )")))); auto output_indexing = GetInputToOutputIndexing(root); EXPECT_THAT(output_indexing.indexing_maps, ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d1 * 3 + d2, d0) domain: - d0 in [0, 16) - d1 in [0, 17) - d2 in [0, 3) + d0 in [0, 15] + d1 in [0, 16] + d2 in [0, 2] )")))); } @@ -567,9 +567,9 @@ TEST_F(IndexingAnalysisTest, BroadcastOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d1) domain: - d0 in [0, 10) - d1 in [0, 20) - d2 in [0, 30) + d0 in [0, 9] + d1 in [0, 19] + d2 in [0, 29] )")))); auto output_indexing = GetInputToOutputIndexing(root); @@ -577,9 +577,9 @@ TEST_F(IndexingAnalysisTest, BroadcastOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0)[s0, s1] -> (s0, d0, s1) domain: - d0 in [0, 20) - s0 in [0, 10) - s1 in [0, 30) + d0 in [0, 19] + s0 in [0, 9] + s1 in [0, 29] )")))); } @@ -610,23 +610,23 @@ TEST_F(IndexingAnalysisTest, ConcatenateOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1, d2) domain: - d0 in [0, 2) - d1 in [0, 5) - d2 in [0, 7) + d0 in [0, 1] + d1 in [0, 4] + d2 in [0, 6] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1 - 5, d2) domain: - d0 in [0, 2) - d1 in [5, 16) - d2 in [0, 7) + d0 in [0, 1] + d1 in [5, 15] + d2 in [0, 6] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1 - 16, d2) domain: - d0 in [0, 2) - d1 in [16, 33) - d2 in [0, 7) + d0 in [0, 1] + d1 in [16, 32] + d2 in [0, 6] )")))); auto output_indexing_0 = GetInputToOutputIndexing(root, /*input_id=*/0); @@ -634,9 +634,9 @@ TEST_F(IndexingAnalysisTest, ConcatenateOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1, d2) domain: - d0 in [0, 2) - d1 in [0, 5) - d2 in [0, 7) + d0 in [0, 1] + d1 in [0, 4] + d2 in [0, 6] )")))); auto output_indexing_1 = GetInputToOutputIndexing(root, /*input_id=*/1); @@ -644,9 +644,9 @@ TEST_F(IndexingAnalysisTest, ConcatenateOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1 + 5, d2) domain: - d0 in [0, 2) - d1 in [0, 11) - d2 in [0, 7) + d0 in [0, 1] + d1 in [0, 10] + d2 in [0, 6] )")))); auto output_indexing_2 = GetInputToOutputIndexing(root, /*input_id=*/2); @@ -654,9 +654,9 @@ TEST_F(IndexingAnalysisTest, ConcatenateOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1 + 16, d2) domain: - d0 in [0, 2) - d1 in [0, 17) - d2 in [0, 7) + d0 in [0, 1] + d1 in [0, 16] + d2 in [0, 6] )")))); } @@ -677,39 +677,39 @@ TEST_F(IndexingAnalysisTest, DynamicSliceOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2)[s0, s1, s2] -> (d0 + s0, d1 + s1, d2 + s2) domain: - d0 in [0, 1) - d1 in [0, 2) - d2 in [0, 32) - s0 in [0, 2) + d0 in [0, 0] + d1 in [0, 1] + d2 in [0, 31] + s0 in [0, 1] hlo: %of1 = s32[] parameter(1) (d0, d1, d2) -> () - s1 in [0, 1) + s1 in [0, 0] hlo: %of2 = s32[] parameter(2) (d0, d1, d2) -> () - s2 in [0, 227) + s2 in [0, 226] hlo: %of3 = s32[] parameter(3) (d0, d1, d2) -> () )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> () domain: - d0 in [0, 1) - d1 in [0, 2) - d2 in [0, 32) + d0 in [0, 0] + d1 in [0, 1] + d2 in [0, 31] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> () domain: - d0 in [0, 1) - d1 in [0, 2) - d2 in [0, 32) + d0 in [0, 0] + d1 in [0, 1] + d2 in [0, 31] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> () domain: - d0 in [0, 1) - d1 in [0, 2) - d2 in [0, 32) + d0 in [0, 0] + d1 in [0, 1] + d2 in [0, 31] )")))); } @@ -729,32 +729,32 @@ TEST_F(IndexingAnalysisTest, DynamicUpdateSliceOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 20) - d1 in [0, 30) + d0 in [0, 19] + d1 in [0, 29] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1)[s0, s1] -> (d0 - s0, d1 - s1) domain: - d0 in [0, 20) - d1 in [0, 30) - s0 in [0, 16) + d0 in [0, 19] + d1 in [0, 29] + s0 in [0, 15] hlo: %of1 = s32[] parameter(2) (d0, d1) -> () - s1 in [0, 21) + s1 in [0, 20] hlo: %of2 = s32[] parameter(3) (d0, d1) -> () )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 20) - d1 in [0, 30) + d0 in [0, 19] + d1 in [0, 29] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 20) - d1 in [0, 30) + d0 in [0, 19] + d1 in [0, 29] )")))); } @@ -776,12 +776,12 @@ TEST_F(IndexingAnalysisTest, FusionOpWithSingleBinaryOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0) -> (d0) domain: - d0 in [0, 100) + d0 in [0, 99] )")), ElementsAre(MatchIndexingMap(R"( (d0) -> (d0) domain: - d0 in [0, 100) + d0 in [0, 99] )")))); } @@ -849,66 +849,66 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDot) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3, d4, d5)[s0] -> (d2, d0 * 768 + s0, d4, d5) domain: - d0 in [0, 16) - d1 in [0, 16) - d2 in [0, 3) - d3 in [0, 1) - d4 in [0, 6) - d5 in [0, 128) - s0 in [0, 768) + d0 in [0, 15] + d1 in [0, 15] + d2 in [0, 2] + d3 in [0, 0] + d4 in [0, 5] + d5 in [0, 127] + s0 in [0, 767] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3, d4, d5)[s0] -> (d0 * 768 + s0) domain: - d0 in [0, 16) - d1 in [0, 16) - d2 in [0, 3) - d3 in [0, 1) - d4 in [0, 6) - d5 in [0, 128) - s0 in [0, 768) + d0 in [0, 15] + d1 in [0, 15] + d2 in [0, 2] + d3 in [0, 0] + d4 in [0, 5] + d5 in [0, 127] + s0 in [0, 767] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3, d4, d5) -> (d1) domain: - d0 in [0, 16) - d1 in [0, 16) - d2 in [0, 3) - d3 in [0, 1) - d4 in [0, 6) - d5 in [0, 128) + d0 in [0, 15] + d1 in [0, 15] + d2 in [0, 2] + d3 in [0, 0] + d4 in [0, 5] + d5 in [0, 127] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3, d4, d5)[s0] -> (d1, d0 * 768 + s0) domain: - d0 in [0, 16) - d1 in [0, 16) - d2 in [0, 3) - d3 in [0, 1) - d4 in [0, 6) - d5 in [0, 128) - s0 in [0, 768) + d0 in [0, 15] + d1 in [0, 15] + d2 in [0, 2] + d3 in [0, 0] + d4 in [0, 5] + d5 in [0, 127] + s0 in [0, 767] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3, d4, d5)[s0] -> (d1, d0 * 768 + s0) domain: - d0 in [0, 16) - d1 in [0, 16) - d2 in [0, 3) - d3 in [0, 1) - d4 in [0, 6) - d5 in [0, 128) - s0 in [0, 768) + d0 in [0, 15] + d1 in [0, 15] + d2 in [0, 2] + d3 in [0, 0] + d4 in [0, 5] + d5 in [0, 127] + s0 in [0, 767] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3, d4, d5) -> (d2, d4, d5) domain: - d0 in [0, 16) - d1 in [0, 16) - d2 in [0, 3) - d3 in [0, 1) - d4 in [0, 6) - d5 in [0, 128) + d0 in [0, 15] + d1 in [0, 15] + d2 in [0, 2] + d3 in [0, 0] + d4 in [0, 5] + d5 in [0, 127] )")))); } @@ -962,17 +962,17 @@ TEST_F(IndexingAnalysisTest, FusionOpWithSoftmax) { ElementsAre(UnorderedElementsAre(MatchIndexingMap(R"( (d0, d1, d2)[s0] -> (d0, d1, s0) domain: - d0 in [0, 2) - d1 in [0, 65) - d2 in [0, 125) - s0 in [0, 125) + d0 in [0, 1] + d1 in [0, 64] + d2 in [0, 124] + s0 in [0, 124] )"), MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1, d2) domain: - d0 in [0, 2) - d1 in [0, 65) - d2 in [0, 125) + d0 in [0, 1] + d1 in [0, 64] + d2 in [0, 124] )")))); } @@ -993,14 +993,14 @@ TEST_F(IndexingAnalysisTest, FusionOpTensorPlusTransposedTensor) { ElementsAre(UnorderedElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 1000) - d1 in [0, 1000) + d0 in [0, 999] + d1 in [0, 999] )"), MatchIndexingMap(R"( (d0, d1) -> (d1, d0) domain: - d0 in [0, 1000) - d1 in [0, 1000) + d0 in [0, 999] + d1 in [0, 999] )")))); } @@ -1030,32 +1030,32 @@ TEST_F(IndexingAnalysisTest, FusionExponentialDuplication) { ElementsAre(UnorderedElementsAre(MatchIndexingMap(R"( (d0) -> (d0 + 1) domain: - d0 in [0, 2) + d0 in [0, 1] )"), MatchIndexingMap(R"( (d0) -> (d0) domain: - d0 in [0, 2) + d0 in [0, 1] )"), MatchIndexingMap(R"( (d0) -> (d0 + 2) domain: - d0 in [0, 2) + d0 in [0, 1] )")), UnorderedElementsAre(MatchIndexingMap(R"( (d0) -> (d0 + 2) domain: - d0 in [0, 2) + d0 in [0, 1] )"), MatchIndexingMap(R"( (d0) -> (d0 + 1) domain: - d0 in [0, 2) + d0 in [0, 1] )"), MatchIndexingMap(R"( (d0) -> (d0) domain: - d0 in [0, 2) + d0 in [0, 1] )")))); } @@ -1074,25 +1074,25 @@ TEST_F(IndexingAnalysisTest, GatherOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3)[s0, s1] -> (d1 + s0, d2 + s1, d3) domain: - d0 in [0, 1806) - d1 in [0, 7) - d2 in [0, 8) - d3 in [0, 4) - s0 in [0, 27) + d0 in [0, 1805] + d1 in [0, 6] + d2 in [0, 7] + d3 in [0, 3] + s0 in [0, 26] hlo: %indices = s32[1806,2]{1,0} parameter(1) (d0, d1, d2, d3) -> (d0, 0) - s1 in [0, 69) + s1 in [0, 68] hlo: %indices = s32[1806,2]{1,0} parameter(1) (d0, d1, d2, d3) -> (d0, 1) )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3)[s0] -> (d0, s0) domain: - d0 in [0, 1806) - d1 in [0, 7) - d2 in [0, 8) - d3 in [0, 4) - s0 in [0, 2) + d0 in [0, 1805] + d1 in [0, 6] + d2 in [0, 7] + d3 in [0, 3] + s0 in [0, 1] )")))); } @@ -1122,15 +1122,15 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReduceOfReduce) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0)[s0, s1, s2] -> (s0, s2, d0, s1) domain: - d0 in [0, 10) - s0 in [0, 150) - s1 in [0, 50) - s2 in [0, 20) + d0 in [0, 9] + s0 in [0, 149] + s1 in [0, 49] + s2 in [0, 19] )")), ElementsAre(MatchIndexingMap(R"( (d0) -> () domain: - d0 in [0, 10) + d0 in [0, 9] )")))); } @@ -1160,15 +1160,15 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReduceOfBroadcast) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1)[s0] -> (d0, s0) domain: - d0 in [0, 15) - d1 in [0, 64) - s0 in [0, 20) + d0 in [0, 14] + d1 in [0, 63] + s0 in [0, 19] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 15) - d1 in [0, 64) + d0 in [0, 14] + d1 in [0, 63] )")))); } @@ -1201,9 +1201,9 @@ TEST_F(IndexingAnalysisTest, FusionOpWithTransposeOfTranspose) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d2, d0, d1) domain: - d0 in [0, 10) - d1 in [0, 50) - d2 in [0, 20) + d0 in [0, 9] + d1 in [0, 49] + d2 in [0, 19] )")))); } @@ -1233,14 +1233,14 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReducedSlice) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0)[s0, s1] -> (s0 + 5, d0 * 2, s1 * 3 + 50) domain: - d0 in [0, 32) - s0 in [0, 16) - s1 in [0, 128) + d0 in [0, 31] + s0 in [0, 15] + s1 in [0, 127] )")), ElementsAre(MatchIndexingMap(R"( (d0) -> () domain: - d0 in [0, 32) + d0 in [0, 31] )")))); } @@ -1261,7 +1261,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReshape_CollapseOfExpand) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0) -> (d0) domain: - d0 in [0, 128) + d0 in [0, 127] )")))); } @@ -1282,8 +1282,8 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReshape_ExpandOfCollapse) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 8) - d1 in [0, 16) + d0 in [0, 7] + d1 in [0, 15] )")))); } @@ -1304,9 +1304,9 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReshape_ChainedGenericReshapes) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1, d2) domain: - d0 in [0, 10) - d1 in [0, 10) - d2 in [0, 10) + d0 in [0, 9] + d1 in [0, 9] + d2 in [0, 9] )")))); } @@ -1331,9 +1331,9 @@ TEST_F(IndexingAnalysisTest, FusionOpWithSliceOfSlice) { d1 * 6 + 8, d2 * 12 + 65) domain: - d0 in [0, 7) - d1 in [0, 9) - d2 in [0, 24) + d0 in [0, 6] + d1 in [0, 8] + d2 in [0, 23] )")))); } @@ -1367,44 +1367,44 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDynSliceOfDynSlice) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1)[s0, s1, s2, s3] -> (d0 + s0 + s2, d1 + s1 + s3) domain: - d0 in [0, 25) - d1 in [0, 16) - s0 in [0, 101) + d0 in [0, 24] + d1 in [0, 15] + s0 in [0, 100] hlo: %of11 = s32[] parameter(1) (d0, d1) -> () - s1 in [0, 33) + s1 in [0, 32] hlo: %of12 = s32[] parameter(2) (d0, d1) -> () - s2 in [0, 26) + s2 in [0, 25] hlo: %of21 = s32[] parameter(3) (d0, d1) -> () - s3 in [0, 17) + s3 in [0, 16] hlo: %of22 = s32[] parameter(4) (d0, d1) -> () )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 25) - d1 in [0, 16) + d0 in [0, 24] + d1 in [0, 15] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 25) - d1 in [0, 16) + d0 in [0, 24] + d1 in [0, 15] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 25) - d1 in [0, 16) + d0 in [0, 24] + d1 in [0, 15] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 25) - d1 in [0, 16) + d0 in [0, 24] + d1 in [0, 15] )")))); } @@ -1431,23 +1431,23 @@ TEST_F(IndexingAnalysisTest, FusionOpSliceOfAllConcatenateOpInputs) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1 * 3, d2) domain: - d0 in [0, 2) - d1 in [0, 2) - d2 in [0, 7) + d0 in [0, 1] + d1 in [0, 1] + d2 in [0, 6] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1 * 3 - 5, d2) domain: - d0 in [0, 2) - d1 in [2, 6) - d2 in [0, 7) + d0 in [0, 1] + d1 in [2, 5] + d2 in [0, 6] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1 * 3 - 16, d2) domain: - d0 in [0, 2) - d1 in [6, 11) - d2 in [0, 7) + d0 in [0, 1] + d1 in [6, 10] + d2 in [0, 6] )")))); } @@ -1474,9 +1474,9 @@ TEST_F(IndexingAnalysisTest, FusionOpSliceOfOneOfConcatenateOpInputs) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1 * 2, d2) domain: - d0 in [0, 2) - d1 in [0, 3) - d2 in [0, 7) + d0 in [0, 1] + d1 in [0, 2] + d2 in [0, 6] )")), ElementsAre(MatchIndexingMap("KNOWN EMPTY")), ElementsAre(MatchIndexingMap("KNOWN EMPTY")))); @@ -1501,16 +1501,16 @@ TEST_F(IndexingAnalysisTest, FusionOpReshapeOfConcat) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0 * 8 + d1) domain: - d0 in [0, 4) - d1 in [0, 8) - d0 * 8 + d1 in [0, 2) + d0 in [0, 3] + d1 in [0, 7] + d0 * 8 + d1 in [0, 1] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0 * 8 + d1 - 2) domain: - d0 in [0, 4) - d1 in [0, 8) - d0 * 8 + d1 in [2, 32) + d0 in [0, 3] + d1 in [0, 7] + d0 * 8 + d1 in [2, 31] )")))); } @@ -1537,7 +1537,7 @@ TEST_F(IndexingAnalysisTest, ReshapeOpCollapseShape) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0) -> (d0 floordiv 8, d0 mod 8) domain: - d0 in [0, 32) + d0 in [0, 31] )")))); } @@ -1553,8 +1553,8 @@ TEST_F(IndexingAnalysisTest, ReshapeOpExpandShape) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0 * 8 + d1) domain: - d0 in [0, 4) - d1 in [0, 8) + d0 in [0, 3] + d1 in [0, 7] )")))); } @@ -1571,9 +1571,9 @@ TEST_F(IndexingAnalysisTest, ReshapeOpExpandAndCollapseShape) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0 floordiv 8, d0 mod 8, d1 * 4 + d2) domain: - d0 in [0, 32) - d1 in [0, 3) - d2 in [0, 4) + d0 in [0, 31] + d1 in [0, 2] + d2 in [0, 3] )")))); auto output_indexing = GetInputToOutputIndexing(root); @@ -1581,9 +1581,9 @@ TEST_F(IndexingAnalysisTest, ReshapeOpExpandAndCollapseShape) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0 * 8 + d1, d2 floordiv 4, d2 mod 4) domain: - d0 in [0, 4) - d1 in [0, 8) - d2 in [0, 12) + d0 in [0, 3] + d1 in [0, 7] + d2 in [0, 11] )")))); } @@ -1599,9 +1599,9 @@ TEST_F(IndexingAnalysisTest, ReshapeOpExpandSubshapeOnly) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0 * 4 + d1, d2) domain: - d0 in [0, 4) - d1 in [0, 4) - d2 in [0, 8) + d0 in [0, 3] + d1 in [0, 3] + d2 in [0, 7] )")))); } @@ -1618,9 +1618,9 @@ TEST_F(IndexingAnalysisTest, ReshapeOpGenericReshape2DTo3D) { (d0, d1, d2) -> (d0 * 2 + d1 floordiv 2, (d1 mod 2) * 4 + d2) domain: - d0 in [0, 2) - d1 in [0, 4) - d2 in [0, 4) + d0 in [0, 1] + d1 in [0, 3] + d2 in [0, 3] )")))); } @@ -1638,8 +1638,8 @@ TEST_F(IndexingAnalysisTest, ReshapeOpGenericReshape3DTo2D) { (d0 mod 2) * 2 + d1 floordiv 4, d1 mod 4) domain: - d0 in [0, 4) - d1 in [0, 8) + d0 in [0, 3] + d1 in [0, 7] )")))); } @@ -1659,15 +1659,15 @@ TEST_F(IndexingAnalysisTest, PadOp) { d1 - 4 ) domain: - d0 in [1, 8) - d1 in [4, 8) - (d0 - 1) mod 2 in [0, 1) + d0 in [1, 7] + d1 in [4, 7] + (d0 - 1) mod 2 in [0, 0] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 12) - d1 in [0, 16) + d0 in [0, 11] + d1 in [0, 15] )")))); } @@ -1684,14 +1684,14 @@ TEST_F(IndexingAnalysisTest, PadOpNoInterior) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0 - 1, d1) domain: - d0 in [1, 3) - d1 in [0, 8) + d0 in [1, 2] + d1 in [0, 7] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 10) - d1 in [0, 8) + d0 in [0, 9] + d1 in [0, 7] )")))); } @@ -1713,13 +1713,13 @@ TEST_F(IndexingAnalysisTest, PadOpNegativePadding) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0) -> ((d0 + 3) floordiv 2) domain: - d0 in [0, 5) - (d0 + 3) mod 2 in [0, 1) + d0 in [0, 4] + (d0 + 3) mod 2 in [0, 0] )")), ElementsAre(MatchIndexingMap(R"( (d0) -> () domain: - d0 in [0, 5) + d0 in [0, 4] )")))); } @@ -1743,16 +1743,16 @@ TEST_F(IndexingAnalysisTest, ReduceOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1)[s0, s1] -> (d0, s0, d1, s1) domain: - d0 in [0, 150) - d1 in [0, 10) - s0 in [0, 20) - s1 in [0, 50) + d0 in [0, 149] + d1 in [0, 9] + s0 in [0, 19] + s1 in [0, 49] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 150) - d1 in [0, 10) + d0 in [0, 149] + d1 in [0, 9] )")))); auto output_indexing_0 = GetInputToOutputIndexing(root, 0); @@ -1760,18 +1760,18 @@ TEST_F(IndexingAnalysisTest, ReduceOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3) -> (d0, d2) domain: - d0 in [0, 150) - d1 in [0, 20) - d2 in [0, 10) - d3 in [0, 50) + d0 in [0, 149] + d1 in [0, 19] + d2 in [0, 9] + d3 in [0, 49] )")))); auto output_indexing_1 = GetInputToOutputIndexing(root, 1); EXPECT_THAT(output_indexing_1.indexing_maps, ElementsAre(ElementsAre(MatchIndexingMap(R"( ()[s0, s1] -> (s0, s1) domain: - s0 in [0, 150) - s1 in [0, 10) + s0 in [0, 149] + s1 in [0, 9] )")))); } @@ -1803,26 +1803,26 @@ TEST_F(IndexingAnalysisTest, VariadicReduceOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0)[s0] -> (s0, d0) domain: - d0 in [0, 10) - s0 in [0, 256) + d0 in [0, 9] + s0 in [0, 255] )")), ElementsAre(MatchIndexingMap(R"( (d0)[s0] -> (s0, d0) domain: - d0 in [0, 10) - s0 in [0, 256) + d0 in [0, 9] + s0 in [0, 255] )")), ElementsAre(MatchIndexingMap(R"( (d0) -> () domain: - d0 in [0, 10) + d0 in [0, 9] )")), ElementsAre(MatchIndexingMap(R"( (d0) -> () domain: - d0 in [0, 10) + d0 in [0, 9] )")))); auto output_indexing_1 = GetOutputToInputIndexing(root, /*output_id=*/1); @@ -1830,31 +1830,31 @@ TEST_F(IndexingAnalysisTest, VariadicReduceOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0)[s0] -> (s0, d0) domain: - d0 in [0, 10) - s0 in [0, 256) + d0 in [0, 9] + s0 in [0, 255] )")), ElementsAre(MatchIndexingMap(R"( (d0)[s0] -> (s0, d0) domain: - d0 in [0, 10) - s0 in [0, 256) + d0 in [0, 9] + s0 in [0, 255] )")), ElementsAre(MatchIndexingMap(R"( (d0) -> () domain: - d0 in [0, 10) + d0 in [0, 9] )")), ElementsAre(MatchIndexingMap(R"( (d0) -> () domain: - d0 in [0, 10) + d0 in [0, 9] )")))); constexpr std::string_view kInputToOutputIndexing = R"( (d0, d1) -> (d1) domain: - d0 in [0, 256) - d1 in [0, 10) + d0 in [0, 255] + d1 in [0, 9] )"; auto input_indexing_0 = GetInputToOutputIndexing(root, /*input_id=*/0); EXPECT_THAT( @@ -1871,7 +1871,7 @@ TEST_F(IndexingAnalysisTest, VariadicReduceOp) { constexpr std::string_view kInitToOutputIndexing = R"( ()[s0] -> (s0) domain: - s0 in [0, 10) + s0 in [0, 9] )"; auto input_indexing_2 = GetInputToOutputIndexing(root, /*input_id=*/2); EXPECT_THAT( @@ -1905,15 +1905,15 @@ TEST_F(IndexingAnalysisTest, ReduceWindowOp_NoPadding) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1)[s0] -> (d0, d1 + s0) domain: - d0 in [0, 1024) - d1 in [0, 3) - s0 in [0, 512) + d0 in [0, 1023] + d1 in [0, 2] + s0 in [0, 511] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 1024) - d1 in [0, 3) + d0 in [0, 1023] + d1 in [0, 2] )")))); } @@ -1937,18 +1937,18 @@ TEST_F(IndexingAnalysisTest, ReduceWindowOp_PaddingAndWindowStride) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1)[s0, s1] -> (d0 * 2 + s0 - 1, d1 + s1) domain: - d0 in [0, 7) - d1 in [0, 17) - s0 in [0, 3) - s1 in [0, 2) - d0 * 2 + s0 in [1, 14) - d1 + s1 in [0, 17) + d0 in [0, 6] + d1 in [0, 16] + s0 in [0, 2] + s1 in [0, 1] + d0 * 2 + s0 in [1, 13] + d1 + s1 in [0, 16] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 7) - d1 in [0, 17) + d0 in [0, 6] + d1 in [0, 16] )")))); } @@ -1972,16 +1972,16 @@ TEST_F(IndexingAnalysisTest, ReduceWindowOp_BaseDilation) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0 floordiv 2, d1 floordiv 2) domain: - d0 in [0, 3) - d1 in [0, 5) - d0 mod 2 in [0, 1) - d1 mod 2 in [0, 1) + d0 in [0, 2] + d1 in [0, 4] + d0 mod 2 in [0, 0] + d1 mod 2 in [0, 0] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 3) - d1 in [0, 5) + d0 in [0, 2] + d1 in [0, 4] )")))); } @@ -2005,15 +2005,15 @@ TEST_F(IndexingAnalysisTest, ReduceWindowOp_WindowDilation) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1)[s0] -> (d0 + s0 * 3, d1) domain: - d0 in [0, 4) - d1 in [0, 3) - s0 in [0, 2) + d0 in [0, 3] + d1 in [0, 2] + s0 in [0, 1] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 4) - d1 in [0, 3) + d0 in [0, 3] + d1 in [0, 2] )")))); } @@ -2044,60 +2044,60 @@ TEST_F(IndexingAnalysisTest, ReduceWindowOp_Variadic) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1)[s0, s1] -> (s0, d1 + s1) domain: - d0 in [0, 1) - d1 in [0, 2) - s0 in [0, 2) - s1 in [0, 2) + d0 in [0, 0] + d1 in [0, 1] + s0 in [0, 1] + s1 in [0, 1] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1)[s0, s1] -> (s0, d1 + s1) domain: - d0 in [0, 1) - d1 in [0, 2) - s0 in [0, 2) - s1 in [0, 2) + d0 in [0, 0] + d1 in [0, 1] + s0 in [0, 1] + s1 in [0, 1] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 1) - d1 in [0, 2) + d0 in [0, 0] + d1 in [0, 1] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 1) - d1 in [0, 2) + d0 in [0, 0] + d1 in [0, 1] )")))); auto input_indexing_1 = GetOutputToInputIndexing(root, /*output_id=*/1); EXPECT_THAT(input_indexing_1.indexing_maps, ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1)[s0, s1] -> (s0, d1 + s1) domain: - d0 in [0, 1) - d1 in [0, 2) - s0 in [0, 2) - s1 in [0, 2) + d0 in [0, 0] + d1 in [0, 1] + s0 in [0, 1] + s1 in [0, 1] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1)[s0, s1] -> (s0, d1 + s1) domain: - d0 in [0, 1) - d1 in [0, 2) - s0 in [0, 2) - s1 in [0, 2) + d0 in [0, 0] + d1 in [0, 1] + s0 in [0, 1] + s1 in [0, 1] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 1) - d1 in [0, 2) + d0 in [0, 0] + d1 in [0, 1] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1) -> () domain: - d0 in [0, 1) - d1 in [0, 2) + d0 in [0, 0] + d1 in [0, 1] )")))); } @@ -2116,24 +2116,24 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_NoPadding) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3)[s0, s1, s2] -> (d0, d1 + s0, d2 + s1, s2) domain: - d0 in [0, 1) - d1 in [0, 10) - d2 in [0, 6) - d3 in [0, 8) - s0 in [0, 3) - s1 in [0, 5) - s2 in [0, 4) + d0 in [0, 0] + d1 in [0, 9] + d2 in [0, 5] + d3 in [0, 7] + s0 in [0, 2] + s1 in [0, 4] + s2 in [0, 3] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3) domain: - d0 in [0, 1) - d1 in [0, 10) - d2 in [0, 6) - d3 in [0, 8) - s0 in [0, 3) - s1 in [0, 5) - s2 in [0, 4) + d0 in [0, 0] + d1 in [0, 9] + d2 in [0, 5] + d3 in [0, 7] + s0 in [0, 2] + s1 in [0, 4] + s2 in [0, 3] )")))); } @@ -2152,26 +2152,26 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_PaddingAndWindowStride) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3)[s0, s1, s2] -> (d0, d1 * 2 + s0 - 1, d2 * 2 + s1 - 2, s2) domain: - d0 in [0, 1) - d1 in [0, 6) - d2 in [0, 5) - d3 in [0, 8) - s0 in [0, 3) - s1 in [0, 5) - s2 in [0, 4) - d1 * 2 + s0 in [1, 13) - d2 * 2 + s1 in [2, 12) + d0 in [0, 0] + d1 in [0, 5] + d2 in [0, 4] + d3 in [0, 7] + s0 in [0, 2] + s1 in [0, 4] + s2 in [0, 3] + d1 * 2 + s0 in [1, 12] + d2 * 2 + s1 in [2, 11] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3) domain: - d0 in [0, 1) - d1 in [0, 6) - d2 in [0, 5) - d3 in [0, 8) - s0 in [0, 3) - s1 in [0, 5) - s2 in [0, 4) + d0 in [0, 0] + d1 in [0, 5] + d2 in [0, 4] + d3 in [0, 7] + s0 in [0, 2] + s1 in [0, 4] + s2 in [0, 3] )")))); } @@ -2190,26 +2190,26 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_LhsDilation) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3)[s0, s1, s2] -> (d0, (d1 + s0) floordiv 2, (d2 + s1) floordiv 2, s2) domain: - d0 in [0, 1) - d1 in [0, 21) - d2 in [0, 15) - d3 in [0, 8) - s0 in [0, 3) - s1 in [0, 5) - s2 in [0, 4) - (d1 + s0) mod 2 in [0, 1) - (d2 + s1) mod 2 in [0, 1) + d0 in [0, 0] + d1 in [0, 20] + d2 in [0, 14] + d3 in [0, 7] + s0 in [0, 2] + s1 in [0, 4] + s2 in [0, 3] + (d1 + s0) mod 2 in [0, 0] + (d2 + s1) mod 2 in [0, 0] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3) domain: - d0 in [0, 1) - d1 in [0, 21) - d2 in [0, 15) - d3 in [0, 8) - s0 in [0, 3) - s1 in [0, 5) - s2 in [0, 4) + d0 in [0, 0] + d1 in [0, 20] + d2 in [0, 14] + d3 in [0, 7] + s0 in [0, 2] + s1 in [0, 4] + s2 in [0, 3] )")))); } @@ -2228,24 +2228,24 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_RhsDilation) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3)[s0, s1, s2] -> (d0, d1 + s0 * 2, d2 + s1 * 2, s2) domain: - d0 in [0, 1) - d1 in [0, 8) - d2 in [0, 2) - d3 in [0, 8) - s0 in [0, 3) - s1 in [0, 5) - s2 in [0, 4) + d0 in [0, 0] + d1 in [0, 7] + d2 in [0, 1] + d3 in [0, 7] + s0 in [0, 2] + s1 in [0, 4] + s2 in [0, 3] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3) domain: - d0 in [0, 1) - d1 in [0, 8) - d2 in [0, 2) - d3 in [0, 8) - s0 in [0, 3) - s1 in [0, 5) - s2 in [0, 4) + d0 in [0, 0] + d1 in [0, 7] + d2 in [0, 1] + d3 in [0, 7] + s0 in [0, 2] + s1 in [0, 4] + s2 in [0, 3] )")))); } @@ -2264,24 +2264,24 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_FeatureGroups) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3)[s0, s1, s2] -> (d0, d1 + s0, d2 + s1, (d3 floordiv 8) * 4 + s2) domain: - d0 in [0, 1) - d1 in [0, 10) - d2 in [0, 6) - d3 in [0, 48) - s0 in [0, 3) - s1 in [0, 5) - s2 in [0, 4) + d0 in [0, 0] + d1 in [0, 9] + d2 in [0, 5] + d3 in [0, 47] + s0 in [0, 2] + s1 in [0, 4] + s2 in [0, 3] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3) domain: - d0 in [0, 1) - d1 in [0, 10) - d2 in [0, 6) - d3 in [0, 48) - s0 in [0, 3) - s1 in [0, 5) - s2 in [0, 4) + d0 in [0, 0] + d1 in [0, 9] + d2 in [0, 5] + d3 in [0, 47] + s0 in [0, 2] + s1 in [0, 4] + s2 in [0, 3] )")))); } @@ -2300,25 +2300,25 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_BatchGroups) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 + s3 * 2, d1 + s0, d2 + s1, s2) domain: - d0 in [0, 2) - d1 in [0, 10) - d2 in [0, 6) - d3 in [0, 21) - s0 in [0, 3) - s1 in [0, 5) - s2 in [0, 4) - s3 in [0, 7) + d0 in [0, 1] + d1 in [0, 9] + d2 in [0, 5] + d3 in [0, 20] + s0 in [0, 2] + s1 in [0, 4] + s2 in [0, 3] + s3 in [0, 6] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3) domain: - d0 in [0, 2) - d1 in [0, 10) - d2 in [0, 6) - d3 in [0, 21) - s0 in [0, 3) - s1 in [0, 5) - s2 in [0, 4) + d0 in [0, 1] + d1 in [0, 9] + d2 in [0, 5] + d3 in [0, 20] + s0 in [0, 2] + s1 in [0, 4] + s2 in [0, 3] )")))); } @@ -2335,10 +2335,10 @@ TEST_F(IndexingAnalysisTest, ReverseOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3) -> (d0, -d1 + 16, -d2 + 8, d3) domain: - d0 in [0, 1) - d1 in [0, 17) - d2 in [0, 9) - d3 in [0, 9) + d0 in [0, 0] + d1 in [0, 16] + d2 in [0, 8] + d3 in [0, 8] )")))); auto output_indexing = GetInputToOutputIndexing(root); @@ -2346,10 +2346,10 @@ TEST_F(IndexingAnalysisTest, ReverseOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3) -> (d0, -d1 + 16, -d2 + 8, d3) domain: - d0 in [0, 1) - d1 in [0, 17) - d2 in [0, 9) - d3 in [0, 9) + d0 in [0, 0] + d1 in [0, 16] + d2 in [0, 8] + d3 in [0, 8] )")))); } @@ -2373,8 +2373,8 @@ TEST_F(IndexingAnalysisTest, ReverseReshape) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 10) - d1 in [0, 11) + d0 in [0, 9] + d1 in [0, 10] )")))); } @@ -2392,9 +2392,9 @@ TEST_F(IndexingAnalysisTest, SliceOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2) -> (d0 + 5, d1 * 7 + 3, d2 * 2) domain: - d0 in [0, 5) - d1 in [0, 3) - d2 in [0, 25) + d0 in [0, 4] + d1 in [0, 2] + d2 in [0, 24] )")))); auto output_indexing = GetInputToOutputIndexing(root); EXPECT_THAT(output_indexing.indexing_maps, @@ -2405,11 +2405,11 @@ TEST_F(IndexingAnalysisTest, SliceOp) { d2 floordiv 2 ) domain: - d0 in [5, 10) - d1 in [3, 18) - d2 in [0, 49) - (d1 - 3) mod 7 in [0, 1) - d2 mod 2 in [0, 1) + d0 in [5, 9] + d1 in [3, 17] + d2 in [0, 48] + (d1 - 3) mod 7 in [0, 0] + d2 mod 2 in [0, 0] )")))); } @@ -2427,20 +2427,20 @@ TEST_F(IndexingAnalysisTest, TransposeOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3) -> (d0, d3, d1, d2) domain: - d0 in [0, 3) - d1 in [0, 6) - d2 in [0, 128) - d3 in [0, 12288) + d0 in [0, 2] + d1 in [0, 5] + d2 in [0, 127] + d3 in [0, 12287] )")))); auto output_indexing = GetInputToOutputIndexing(root); EXPECT_THAT(output_indexing.indexing_maps, ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3) -> (d0, d2, d3, d1) domain: - d0 in [0, 3) - d1 in [0, 12288) - d2 in [0, 6) - d3 in [0, 128) + d0 in [0, 2] + d1 in [0, 12287] + d2 in [0, 5] + d3 in [0, 127] )")))); } @@ -2456,10 +2456,10 @@ TEST_F(IndexingAnalysisTest, TransposeOp4D) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3) -> (d0, d3, d1, d2) domain: - d0 in [0, 3) - d1 in [0, 6) - d2 in [0, 128) - d3 in [0, 12288) + d0 in [0, 2] + d1 in [0, 5] + d2 in [0, 127] + d3 in [0, 12287] )")))); } @@ -2478,26 +2478,26 @@ TEST_F(IndexingAnalysisTest, DotOp) { ElementsAre(ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3, d4, d5)[s0, s1] -> (d2, d1, s1, d3, s0, d0) domain: - d0 in [0, 10) - d1 in [0, 38) - d2 in [0, 4) - d3 in [0, 11) - d4 in [0, 16) - d5 in [0, 22) - s0 in [0, 18) - s1 in [0, 17) + d0 in [0, 9] + d1 in [0, 37] + d2 in [0, 3] + d3 in [0, 10] + d4 in [0, 15] + d5 in [0, 21] + s0 in [0, 17] + s1 in [0, 16] )")), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3, d4, d5)[s0, s1] -> (s1, d0, d4, s0, d5, d1) domain: - d0 in [0, 10) - d1 in [0, 38) - d2 in [0, 4) - d3 in [0, 11) - d4 in [0, 16) - d5 in [0, 22) - s0 in [0, 18) - s1 in [0, 17) + d0 in [0, 9] + d1 in [0, 37] + d2 in [0, 3] + d3 in [0, 10] + d4 in [0, 15] + d5 in [0, 21] + s0 in [0, 17] + s1 in [0, 16] )")))); } @@ -2558,8 +2558,8 @@ TEST_F(IndexingAnalysisTest, FusionWithUnsupportedOp) { ElementsAre(UnorderedElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0 * 6, d1 * 2) domain: - d0 in [0, 4) - d1 in [0, 3) + d0 in [0, 3] + d1 in [0, 2] )")), ElementsAre(UndefinedMap()), ElementsAre(UndefinedMap()))); } @@ -2577,16 +2577,16 @@ TEST_F(IndexingAnalysisTest, TilingIndexing) { d0 mod 4 + s2 * 4 ) domain: - d0 in [0, 16) - d1 in [0, 1) - d2 in [0, 1) - d3 in [0, 8192) - d4 in [0, 1) - d5 in [0, 1) - s0 in [0, 8) - s1 in [0, 1) - s2 in [0, 4) - (d3 floordiv 64) * 8 + s0 in [0, 1022) + d0 in [0, 15] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 8191] + d4 in [0, 0] + d5 in [0, 0] + s0 in [0, 7] + s1 in [0, 0] + s2 in [0, 3] + (d3 floordiv 64) * 8 + s0 in [0, 1021] )")); } @@ -2619,8 +2619,8 @@ TEST_F(IndexingAnalysisTest, EpilogueIndexing) { MatchIndexingString(R"( (d0, d1) -> (d1 * 1000 + d0) domain: - d0 in [0, 1000) - d1 in [0, 1000) + d0 in [0, 999] + d1 in [0, 999] )")); } @@ -2649,8 +2649,8 @@ TEST_F(IndexingAnalysisTest, EpilogueIndexing_NoEpilogue) { MatchIndexingString(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 1000) - d1 in [0, 1000) + d0 in [0, 999] + d1 in [0, 999] )")); } @@ -2670,16 +2670,16 @@ TEST_F(IndexingAnalysisTest, BroadcastingElementwise) { operand id = 0 (d0, d1) -> () domain: - d0 in [0, 1000) - d1 in [0, 1000) + d0 in [0, 999] + d1 in [0, 999] operand id = 1 (d0, d1) -> (d0, d1) domain: - d0 in [0, 1000) - d1 in [0, 1000) + d0 in [0, 999] + d1 in [0, 999] operand id = 2 (d0, d1) -> (d0, d1) domain: - d0 in [0, 1000) - d1 in [0, 1000) + d0 in [0, 999] + d1 in [0, 999] )")); } diff --git a/xla/service/gpu/model/indexing_map.cc b/xla/service/gpu/model/indexing_map.cc index ea2abb91f3d82c..ccb1cae91f0d20 100644 --- a/xla/service/gpu/model/indexing_map.cc +++ b/xla/service/gpu/model/indexing_map.cc @@ -793,8 +793,7 @@ std::string Interval::ToString() const { } void Interval::Print(std::ostream& out) const { - // The interval is printed as a semi-open one because it is easier to read. - out << '[' << lower << ", " << upper + 1 << ")"; + out << '[' << lower << ", " << upper << "]"; } int64_t Interval::GetLoopTripCount() const { diff --git a/xla/service/gpu/model/indexing_map_test.cc b/xla/service/gpu/model/indexing_map_test.cc index e0c458bc32b69c..cc6501f51bc7f7 100644 --- a/xla/service/gpu/model/indexing_map_test.cc +++ b/xla/service/gpu/model/indexing_map_test.cc @@ -80,13 +80,13 @@ TEST_F(IndexingMapTest, RTVar) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1)[range, rt_0, rt_1] -> (d1, d0, range + rt_0, rt_0) domain: - d0 in [0, 100) - d1 in [0, 44) - range in [-99, 100) - rt_0 in [0, 3) + d0 in [0, 99] + d1 in [0, 43] + range in [-99, 99] + rt_0 in [0, 2] hlo: NULL () -> () - rt_1 in [0, 8) + rt_1 in [0, 7] hlo: NULL () -> () )")); @@ -128,10 +128,10 @@ TEST_F(IndexingMapTest, Composition_Permutation) { EXPECT_THAT(composed, MatchIndexingMap(R"( (d0)[s0, s1, s2] -> (s2, d0, s1, s0) domain: - d0 in [0, 4) - s0 in [0, 2) - s1 in [0, 2) - s2 in [0, 4) + d0 in [0, 3] + s0 in [0, 1] + s1 in [0, 1] + s2 in [0, 3] )")); } @@ -147,10 +147,10 @@ TEST_F(IndexingMapTest, Composition_RestrictedInterval) { EXPECT_THAT(composed, MatchIndexingMap(R"( (d0)[s0, s1, s2] -> (s2, d0, s1, s0) domain: - d0 in [0, 5) - s0 in [0, 7) - s1 in [0, 2) - s2 in [0, 6) + d0 in [0, 4] + s0 in [0, 6] + s1 in [0, 1] + s2 in [0, 5] )")); } @@ -174,26 +174,26 @@ TEST_F(IndexingMapTest, Composition_ProducerAndConsumerHaveConstraints) { EXPECT_THAT(composed, MatchIndexingMap(R"( (d0)[s0, s1, s2] -> (s2, d0, s1, s0) domain: - d0 in [0, 10) - s0 in [0, 70) - s1 in [0, 20) - s2 in [0, 8) - d0 + s2 in [0, 21) - d0 mod 8 in [0, 1) - s0 mod 3 in [1, 2) - s2 mod 4 in [0, 1) + d0 in [0, 9] + s0 in [0, 69] + s1 in [0, 19] + s2 in [0, 7] + d0 + s2 in [0, 20] + d0 mod 8 in [0, 0] + s0 mod 3 in [1, 1] + s2 mod 4 in [0, 0] )")); EXPECT_TRUE(composed.Simplify()); EXPECT_THAT(composed, MatchIndexingMap(R"( (d0)[s0, s1, s2] -> (s2, d0, s1, s0) domain: - d0 in [0, 9) - s0 in [1, 68) - s1 in [0, 20) - s2 in [0, 5) - d0 mod 8 in [0, 1) - s0 mod 3 in [1, 2) - s2 mod 4 in [0, 1) + d0 in [0, 8] + s0 in [1, 67] + s1 in [0, 19] + s2 in [0, 4] + d0 mod 8 in [0, 0] + s0 mod 3 in [1, 1] + s2 mod 4 in [0, 0] )")); } @@ -223,16 +223,16 @@ TEST_F(IndexingMapTest, Composition_RTVar) { EXPECT_THAT(composed.ToString(printer_), MatchIndexingString(R"( (d0, d1)[s, rt_0, rt_1, rt_2] -> (rt_0, d1 + rt_1, s + rt_2) domain: - d0 in [0, 1) - d1 in [0, 2) - s in [0, 32) - rt_0 in [0, 1) + d0 in [0, 0] + d1 in [0, 1] + s in [0, 31] + rt_0 in [0, 0] hlo: NULL () -> () - rt_1 in [0, 2) + rt_1 in [0, 1] hlo: NULL () -> () - rt_2 in [0, 227) + rt_2 in [0, 226] hlo: NULL () -> () )")); @@ -266,22 +266,22 @@ TEST_F(IndexingMapTest, Composition_OnlyRTVars) { (d0, d1)[ps_0, ps_1, cs_0, cs_1] -> (d0 + cs_0 * 2 + ps_0, d1 + cs_1 * 3 + ps_1 * 4) domain: - d0 in [0, 25) - d1 in [0, 16) - ps_0 in [0, 3) + d0 in [0, 24] + d1 in [0, 15] + ps_0 in [0, 2] hlo: NULL () -> () - ps_1 in [0, 2) + ps_1 in [0, 1] hlo: NULL () -> () - cs_0 in [0, 26) + cs_0 in [0, 25] hlo: NULL () -> () - cs_1 in [0, 17) + cs_1 in [0, 16] hlo: NULL () -> () - d0 + cs_0 * 2 in [0, 25) - d1 + cs_1 * 3 in [0, 16) + d0 + cs_0 * 2 in [0, 24] + d1 + cs_1 * 3 in [0, 15] )")); } @@ -298,12 +298,12 @@ TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintUsesDim) { EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0, d1)[s0, s1] -> (d1, s0, s1) domain: - d0 in [0, 50) - d1 in [0, 60) - s0 in [0, 70) - s1 in [0, 20) - d0 + s0 in [1, 101) - s0 mod 3 in [0, 1) + d0 in [0, 49] + d1 in [0, 59] + s0 in [0, 69] + s1 in [0, 19] + d0 + s0 in [1, 100] + s0 mod 3 in [0, 0] )")); } @@ -318,9 +318,9 @@ TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintUsesUnusedDim) { EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0)[s0, s1] -> (s0, d0, s1) domain: - d0 in [0, 60) - s0 in [0, 70) - s1 in [0, 20) + d0 in [0, 59] + s0 in [0, 69] + s1 in [0, 19] )")); } @@ -335,9 +335,9 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesOnlyUnusedSym) { EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0, d1)[s0] -> (d0, d1, s0) domain: - d0 in [0, 50) - d1 in [0, 60) - s0 in [0, 20) + d0 in [0, 49] + d1 in [0, 59] + s0 in [0, 19] )")); } @@ -356,12 +356,12 @@ TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintsWithManyDims) { EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0, d1)[s0, s1] -> (d0 + s0 * 4 + d1 - 42) domain: - d0 in [0, 2) - d1 in [0, 4) - s0 in [0, 32) - s1 in [0, 96) - d0 + s0 * 4 + d1 in [24, 460) - s0 + s1 in [0, 513) + d0 in [0, 1] + d1 in [0, 3] + s0 in [0, 31] + s1 in [0, 95] + d0 + s0 * 4 + d1 in [24, 459] + s0 + s1 in [0, 512] )")); EXPECT_THAT(ConvertToSTL(unused_vars), ::testing::ElementsAreArray( @@ -381,12 +381,12 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesSymbol) { EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0, d1)[s0, s1] -> (d1, d0, s1) domain: - d0 in [0, 50) - d1 in [0, 60) - s0 in [0, 70) - s1 in [0, 20) - s0 + s1 in [1, 101) - s0 mod 3 in [0, 1) + d0 in [0, 49] + d1 in [0, 59] + s0 in [0, 69] + s1 in [0, 19] + s0 + s1 in [1, 100] + s0 mod 3 in [0, 0] )")); } @@ -401,9 +401,9 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesOnlyUnusedSymbols) { EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0, d1)[s0] -> (d1, d0, s0) domain: - d0 in [0, 50) - d1 in [0, 60) - s0 in [0, 20) + d0 in [0, 49] + d1 in [0, 59] + s0 in [0, 19] )")); } @@ -415,7 +415,7 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintIsAConstantWithinRange) { EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0) -> (d0) domain: - d0 in [0, 50) + d0 in [0, 49] )")); } @@ -469,10 +469,10 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintsWithManySymbols) { EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0)[s0, s1] -> (d0 * 4 + s0 + s1 - 42) domain: - d0 in [0, 32) - s0 in [0, 2) - s1 in [0, 4) - d0 * 4 + s0 + s1 in [24, 460) + d0 in [0, 31] + s0 in [0, 1] + s1 in [0, 3] + d0 * 4 + s0 + s1 in [24, 459] )")); } @@ -493,12 +493,12 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintsWithRTVars) { EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0)[s0, s1] -> (d0 * 4 + s0 + s1 - 42) domain: - d0 in [0, 32) - s0 in [0, 2) - s1 in [0, 4) + d0 in [0, 31] + s0 in [0, 1] + s1 in [0, 3] hlo: NULL () -> () - d0 * 4 + s0 + s1 in [24, 460) + d0 * 4 + s0 + s1 in [24, 459] )")); } @@ -512,8 +512,8 @@ TEST_F(IndexingMapTest, ConstraintIntervalSimplification_Sum) { EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0) -> (d0) domain: - d0 in [0, 100) - d0 mod 8 in [45, 50) + d0 in [0, 99] + d0 mod 8 in [45, 49] )")); } @@ -530,9 +530,9 @@ TEST_F(IndexingMapTest, EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0)[s0, s1] -> (d0 * 6 + s0 * 3 + s1) domain: - d0 in [0, 100) - s0 in [0, 2) - s1 in [0, 3) + d0 in [0, 99] + s0 in [0, 1] + s1 in [0, 2] )")); } @@ -559,8 +559,8 @@ TEST_F(IndexingMapTest, ConstraintIntervalSimplification_Sum_GcdGreaterOne) { EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0)[s0] -> (d0 * 6 + s0 * 3) domain: - d0 in [0, 100) - s0 in [0, 2) + d0 in [0, 99] + s0 in [0, 1] )")); } @@ -575,7 +575,7 @@ TEST_F(IndexingMapTest, EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0) -> (d0) domain: - d0 in [40, 96) + d0 in [40, 95] )")); } @@ -591,8 +591,8 @@ TEST_F(IndexingMapTest, EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0)[s0] -> (d0) domain: - d0 in [0, 100) - s0 in [-33, -12) + d0 in [0, 99] + s0 in [-33, -13] )")); } @@ -608,8 +608,8 @@ TEST_F(IndexingMapTest, EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0)[s0] -> (d0) domain: - d0 in [0, 100) - s0 in [15, 36) + d0 in [0, 99] + s0 in [15, 35] )")); } @@ -624,7 +624,7 @@ TEST_F(IndexingMapTest, EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0) -> (d0) domain: - d0 in [2, 5) + d0 in [2, 4] )")); } @@ -640,8 +640,8 @@ TEST_F(IndexingMapTest, EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0)[s0] -> (d0) domain: - d0 in [0, 100) - s0 in [-3, -1) + d0 in [0, 99] + s0 in [-3, -2] )")); } @@ -657,8 +657,8 @@ TEST_F(IndexingMapTest, EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0)[s0] -> (d0) domain: - d0 in [0, 100) - s0 in [2, 4) + d0 in [0, 99] + s0 in [2, 3] )")); } @@ -680,12 +680,12 @@ TEST_F(IndexingMapTest, ConstraintMerge_Mod) { EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0)[s0, s1] -> (d0, s1, s0) domain: - d0 in [0, 4) - s0 in [-18, -5) - s1 in [1, 7) - d0 mod 3 in [0, 1) - s0 mod 6 in [0, 1) - s1 mod 5 in [1, 2) + d0 in [0, 3] + s0 in [-18, -6] + s1 in [1, 6] + d0 mod 3 in [0, 0] + s0 mod 6 in [0, 0] + s1 mod 5 in [1, 1] )")); } @@ -697,7 +697,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ConstantDims) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0) -> (5) domain: - d0 in [5, 6) + d0 in [5, 5] )")); } @@ -732,7 +732,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ModIsSub) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0) -> (d0 - 42) domain: - d0 in [53, 72) + d0 in [53, 71] )")); } @@ -743,7 +743,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ModIsAdd) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0) -> (d0 + 5) domain: - d0 in [-5, 0) + d0 in [-5, -1] )")); } @@ -765,8 +765,8 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SubIsMod) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0)[s0] -> (d0 + s0 mod 3) domain: - d0 in [0, 2) - s0 in [0, 4) + d0 in [0, 1] + s0 in [0, 3] )")); } @@ -779,8 +779,8 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SubIsModMultiplied) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0)[s0] -> (d0 + (s0 mod 3) * 4 + s0 * 3) domain: - d0 in [0, 2) - s0 in [0, 4) + d0 in [0, 1] + s0 in [0, 3] )")); } @@ -793,8 +793,8 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SubIsModSum) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0)[s0] -> (d0 + (s0 + 1) mod 3) domain: - d0 in [0, 2) - s0 in [0, 4) + d0 in [0, 1] + s0 in [0, 3] )")); } @@ -807,8 +807,8 @@ TEST_F(IndexingMapTest, EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 8) - d1 in [0, 16) + d0 in [0, 7] + d1 in [0, 15] )")); } @@ -825,9 +825,9 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsWithMultipliers) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1, d2) -> (d0, d1, d2) domain: - d0 in [0, 9) - d1 in [0, 9) - d2 in [0, 9) + d0 in [0, 8] + d1 in [0, 8] + d2 in [0, 8] )")); } @@ -844,9 +844,9 @@ TEST_F(IndexingMapTest, (d0, d1, d2) -> (d0 * 2 + (d1 * 4 + d2) floordiv 8, (d1 * 4 + d2) mod 8) domain: - d0 in [0, 10) - d1 in [0, 10) - d2 in [0, 10) + d0 in [0, 9] + d1 in [0, 9] + d2 in [0, 9] )")); } @@ -860,8 +860,8 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsWithReverse) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 8) - d1 in [0, 9) + d0 in [0, 7] + d1 in [0, 8] )")); } @@ -873,7 +873,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape) { EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0] -> (s0 * 128) - domain: s0 in [0, 128) + domain: s0 in [0, 127] )")); } @@ -886,8 +886,8 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape2) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1) -> (d0 * 128 + d1) domain: - d0 in [0, 1024) - d1 in [0, 128) + d0 in [0, 1023] + d1 in [0, 127] )")); } @@ -901,8 +901,8 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape3) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1) -> (d0 * 4 + d1 * 512) domain: - d0 in [0, 128) - d1 in [0, 3072) + d0 in [0, 127] + d1 in [0, 3071] )")); } @@ -915,7 +915,7 @@ TEST_F(IndexingMapTest, EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0) -> ((-d0) mod 2) domain: - d0 in [0, 128) + d0 in [0, 127] )")); } @@ -934,8 +934,8 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyBitcastAndBack) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1) -> (d0 * 512 + d1 * 4) domain: - d0 in [0, 3072) - d1 in [0, 128) + d0 in [0, 3071] + d1 in [0, 127] )")); } @@ -948,7 +948,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape_Regression) { EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0] -> (((s0 * 64) floordiv 715) * 715 + (s0 * 128) mod 715) - domain: s0 in [0, 128) + domain: s0 in [0, 127] )")); } @@ -962,7 +962,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsInSequence) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0] -> (s0) domain: - s0 in [0, 1234) + s0 in [0, 1233] )")); } @@ -974,8 +974,8 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivDiv) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0, s1] -> ((s0 * 128 + s1) floordiv 192) domain: - s0 in [0, 1234) - s1 in [0, 128) + s0 in [0, 1233] + s1 in [0, 127] )")); } @@ -987,7 +987,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivSumConstant) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0] -> ((s0 * 2 + 3) floordiv 6) domain: - s0 in [0, 1234) + s0 in [0, 1233] )")); } @@ -1022,10 +1022,10 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ExtractFromMod) { ((s0 * 114688 + s3 * 128 + s2) mod 5000) * 4 + s1 ) domain: - s0 in [0, 872) - s1 in [0, 4) - s2 in [0, 128) - s3 in [0, 896) + s0 in [0, 871] + s1 in [0, 3] + s2 in [0, 127] + s3 in [0, 895] )")); } @@ -1042,8 +1042,8 @@ TEST_F(IndexingMapTest, s0 * 4 + s1 floordiv 32 ) domain: - s0 in [0, 2) - s1 in [0, 128) + s0 in [0, 1] + s1 in [0, 127] )")); } @@ -1058,10 +1058,10 @@ TEST_F(IndexingMapTest, RescaleSymbols_Simple) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0)[s0, s1, s2] -> (s2, d0, s1, s0) domain: - d0 in [0, 4) - s0 in [0, 2) - s1 in [0, 2) - s2 in [0, 6) + d0 in [0, 3] + s0 in [0, 1] + s1 in [0, 1] + s2 in [0, 5] )")); } @@ -1078,10 +1078,10 @@ TEST_F(IndexingMapTest, RescaleSymbols_WithShift) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0)[s0, s1, s2] -> (s2, d0, s1, s0 * 6 + 3) domain: - d0 in [0, 4) - s0 in [0, 7) - s1 in [0, 2) - s2 in [0, 6) + d0 in [0, 3] + s0 in [0, 6] + s1 in [0, 1] + s2 in [0, 5] )")); } @@ -1098,10 +1098,10 @@ TEST_F(IndexingMapTest, RescaleSymbols_TwoModConstraints) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0)[s0, s1, s2] -> (s2, d0, s1, s0) domain: - d0 in [0, 4) - s0 in [0, 2) - s1 in [0, 2) - s2 in [0, 6) + d0 in [0, 3] + s0 in [0, 1] + s1 in [0, 1] + s2 in [0, 5] )")); } @@ -1118,11 +1118,11 @@ TEST_F(IndexingMapTest, RescaleSymbols_RescaledSymbolInOtherNonModConstraint) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0)[s0, s1, s2] -> (s2, d0, s1, s0 * 6 + 3) domain: - d0 in [0, 4) - s0 in [0, 2) - s1 in [0, 2) - s2 in [0, 6) - (s0 * 6 + 3) * s2 in [0, 29) + d0 in [0, 3] + s0 in [0, 1] + s1 in [0, 1] + s2 in [0, 5] + (s0 * 6 + 3) * s2 in [0, 28] )")); } @@ -1442,7 +1442,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_Iota) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0) -> (d0, d0) domain: - d0 in [0, 256) + d0 in [0, 255] )")); } @@ -1471,7 +1471,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_IotaAsConstant) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0) -> (d0, 7) domain: - d0 in [0, 256) + d0 in [0, 255] )")); } @@ -1502,8 +1502,8 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_ConstraintsGetUpdated) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0) -> (d0, d0) domain: - d0 in [0, 255) - d0 mod 2 in [0, 1) + d0 in [0, 254] + d0 mod 2 in [0, 0] )")); } @@ -1535,7 +1535,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_Broadcast) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0) -> (d0, 11) domain: - d0 in [0, 32) + d0 in [0, 31] )")); } @@ -1576,7 +1576,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_ChainedNoncomputeOps) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0) -> (d0, (d0 floordiv 12) * -4 + 8) domain: - d0 in [0, 36) + d0 in [0, 35] )")); } @@ -1609,8 +1609,8 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_PartialRTVarRemoval) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0)[s0] -> (d0, s0) domain: - d0 in [0, 24) - s0 in [0, 513) + d0 in [0, 23] + s0 in [0, 512] hlo: %constant = s64[12]{0} constant({...}) (d0) -> (d0 floordiv 2) )")); @@ -1646,7 +1646,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_Add) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0) -> (d0, d0 * 2 + 42) domain: - d0 in [0, 12) + d0 in [0, 11] )")); } @@ -1685,7 +1685,7 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_Multiply) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0) -> (d0, (-d0 + 11) * d0) domain: - d0 in [0, 12) + d0 in [0, 11] )")); } @@ -1721,8 +1721,8 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_PartiallyOptimizableAdd) { EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0)[s0] -> (d0, d0 * 2 + s0) domain: - d0 in [0, 12) - s0 in [0, 12) + d0 in [0, 11] + s0 in [0, 11] hlo: %constant = s64[12]{0} constant({...}) (d0) -> (d0) )")); diff --git a/xla/service/gpu/model/symbolic_tile_analysis_test.cc b/xla/service/gpu/model/symbolic_tile_analysis_test.cc index 0a359415a153d0..a9680f0f5fdb07 100644 --- a/xla/service/gpu/model/symbolic_tile_analysis_test.cc +++ b/xla/service/gpu/model/symbolic_tile_analysis_test.cc @@ -136,8 +136,8 @@ ENTRY main { /*tile_offsets_indexing=*/R"( (d0, d1) -> (d0, d1 * 10) domain: - d0 in [0, 2) - d1 in [0, 10) + d0 in [0, 1] + d1 in [0, 9] )")); auto p0_from_subtract0 = root->operand(0); @@ -149,8 +149,8 @@ ENTRY main { /*tile_offsets_indexing=*/R"( (d0, d1) -> (d0, d1 * 10) domain: - d0 in [0, 2) - d1 in [0, 10) + d0 in [0, 1] + d1 in [0, 9] )")); EXPECT_THAT(*p0_from_subtract1, MatchTiledHloInstruction( @@ -159,8 +159,8 @@ ENTRY main { /*tile_offsets_indexing=*/R"( (d0, d1) -> (d0, 0) domain: - d0 in [0, 2) - d1 in [0, 10) + d0 in [0, 1] + d1 in [0, 9] )")); } @@ -251,8 +251,8 @@ ENTRY main { /*tile_offsets_indexing=*/R"( (d0, d1) -> (d0, 0) domain: - d0 in [0, 2) - d1 in [0, 1) + d0 in [0, 1] + d1 in [0, 0] )")); } @@ -284,9 +284,9 @@ ENTRY main { /*tile_offsets_indexing=*/R"( (d0, d1, d2) -> (d0 * 2, d1 * 4, d2 * 2) domain: - d0 in [0, 2) - d1 in [0, 2) - d2 in [0, 8) + d0 in [0, 1] + d1 in [0, 1] + d2 in [0, 7] )")); EXPECT_THAT(*root->operand(0), @@ -295,9 +295,9 @@ ENTRY main { /*tile_offsets_indexing=*/R"( (d0, d1, d2) -> (d1 * 4, d2 * 2, d0 * 2) domain: - d0 in [0, 2) - d1 in [0, 2) - d2 in [0, 8) + d0 in [0, 1] + d1 in [0, 1] + d2 in [0, 7] )")); } @@ -333,8 +333,8 @@ ENTRY main { /*tile_offsets_indexing=*/R"( (d0, d1) -> (d0 * 2, d1 * 2) domain: - d0 in [0, 2) - d1 in [0, 4) + d0 in [0, 1] + d1 in [0, 3] )")); EXPECT_THAT(*p0_from_slice0, @@ -343,8 +343,8 @@ ENTRY main { /*tile_offsets_indexing=*/R"( (d0, d1) -> (d0 * 2, d1 * 2 + 2) domain: - d0 in [0, 2) - d1 in [0, 4) + d0 in [0, 1] + d1 in [0, 3] )")); EXPECT_THAT(*p0_from_slice1, @@ -353,8 +353,8 @@ ENTRY main { /*tile_offsets_indexing=*/R"( (d0, d1) -> (d0 * 2 + 3, d1 * 2 + 4) domain: - d0 in [0, 2) - d1 in [0, 4) + d0 in [0, 1] + d1 in [0, 3] )")); } @@ -472,10 +472,10 @@ ENTRY main { EXPECT_THAT(conjunction, SizeIs(2)); // We expect the constraints here to be - // 6 mod s0 in [0, 1) && 8 mod s1 in [0, 1) || - // 6 mod s0 in [0, 1) && s1 mod 8 in [0, 1) || - // 8 mod s1 in [0, 1) && s0 mod 6 in [0, 1) || - // s0 mod 6 in [0, 1) && s1 mod 8 in [0, 1) + // 6 mod s0 in [0, 0] && 8 mod s1 in [0, 0] || + // 6 mod s0 in [0, 0] && s1 mod 8 in [0, 0] || + // 8 mod s1 in [0, 0] && s0 mod 6 in [0, 0] || + // s0 mod 6 in [0, 0] && s1 mod 8 in [0, 0] // Tile sizes {6, 8} satisfy these constraints. std::vector possible_tile_parameters({6, 8}); EXPECT_THAT(analysis->ParametersSatisfyConstraints(possible_tile_parameters), @@ -626,7 +626,7 @@ ENTRY main { std::vector good_tilings, analysis.GetGoodTilings()); // The constraint on the 1st dimension is - // 6 mod s0 in [0, 1) || s0 mod 6 in [0, 1), + // 6 mod s0 in [0, 0] || s0 mod 6 in [0, 0], // and only 48, 1, and 2 fulfill it from the set of possible tile sizes // (1, 2, 4, 8, 16, 32, 48). // There is no constraint on the 2nd dimension. @@ -801,8 +801,8 @@ ENTRY main { /*tile_offsets_indexing=*/R"( (d0, d1) -> (d0, d1) domain: - d0 in [0, 65538) - d1 in [0, 32768) + d0 in [0, 65537] + d1 in [0, 32767] )")); } @@ -856,8 +856,8 @@ ENTRY main { /*tile_offsets_indexing=*/R"( (d0, d1) -> (0, d1, 0) domain: - d0 in [0, 1) - d1 in [0, 2) + d0 in [0, 0] + d1 in [0, 1] )")); EXPECT_THAT(*param_0_tile, MatchTiledHloInstruction( @@ -866,12 +866,12 @@ ENTRY main { /*tile_offsets_indexing=*/R"( (d0, d1)[s0, s1] -> (s0, d1, s1) domain: - d0 in [0, 1) - d1 in [0, 2) - s0 in [0, 2) + d0 in [0, 0] + d1 in [0, 1] + s0 in [0, 1] hlo: %of1 = s32[] parameter(1) (d0, d1, d2) -> () - s1 in [0, 227) + s1 in [0, 226] hlo: %of3 = s32[] parameter(3) (d0, d1, d2) -> () )")); diff --git a/xla/service/gpu/model/symbolic_tile_test.cc b/xla/service/gpu/model/symbolic_tile_test.cc index d056dbbd70b38f..1db55375c0cc84 100644 --- a/xla/service/gpu/model/symbolic_tile_test.cc +++ b/xla/service/gpu/model/symbolic_tile_test.cc @@ -121,7 +121,7 @@ TEST_F(SymbolicTileTest, size_map: (d0, d1) -> (1, (d0 + 5) floordiv 6, d0 - ((d0 - 1) floordiv 6) * 6, d1) stride_map: (d0, d1) -> (0, 1, 1, 1) constraints: - 6 mod d0 in [0, 1) || d0 mod 6 in [0, 1) + 6 mod d0 in [0, 0] || d0 mod 6 in [0, 0] )"))); } @@ -150,12 +150,12 @@ TEST_F(SymbolicTileTest, (((-d2 + 7) floordiv 6) * (((-d1 + 9) floordiv 8) * ((-((-d0 + 5) floordiv 4) + 1) * 48) + (-((-d1 + 9) floordiv 8) + 1) * 6) + -((-d2 + 7) floordiv 6) + 1, 1) - constraints: d0 in [1, 2) && d1 in [1, 2) || - d0 in [1, 2) && d2 in [1, 2) || - d0 in [1, 2) && d2 in [6, 7) || - d1 in [1, 2) && d2 in [1, 2) || - d1 in [8, 9) && d2 in [1, 2) || - d1 in [8, 9) && d2 in [6, 7) + constraints: d0 in [1, 1] && d1 in [1, 1] || + d0 in [1, 1] && d2 in [1, 1] || + d0 in [1, 1] && d2 in [6, 6] || + d1 in [1, 1] && d2 in [1, 1] || + d1 in [8, 8] && d2 in [1, 1] || + d1 in [8, 8] && d2 in [6, 6] )"))); // Capturing elements along dimensions 0, 1, and 2 makes the stride equal to @@ -406,10 +406,10 @@ TEST_F(SymbolicTileTest, CanPropagateTileThroughDynamicSlice) { size_map: (d0, d1, d2) -> (1, d1, d2) stride_map: (d0, d1, d2) -> (0, 1, 1) rt_vars: - s0 in [0, 2) + s0 in [0, 1] hlo: %of1 = s32[] parameter(1) (d0, d1, d2) -> () - s1 in [0, 227) + s1 in [0, 226] hlo: %of3 = s32[] parameter(3) (d0, d1, d2) -> () )"))); @@ -458,10 +458,10 @@ TEST_F(SymbolicTileTest, CanPropagateTileThroughDynamicUpdateSlice) { size_map: (d0, d1) -> (d0, d1) stride_map: (d0, d1) -> (1, 1) rt_vars: - s0 in [0, 16) + s0 in [0, 15] hlo: %of1 = s32[] parameter(2) (d0, d1) -> () - s1 in [0, 21) + s1 in [0, 20] hlo: %of2 = s32[] parameter(3) (d0, d1) -> () )"))); @@ -501,10 +501,10 @@ TEST_F(SymbolicTileTest, CanPropagateTileThroughGather) { size_map: (d0, d1, d2, d3) -> (d1, d2, d3) stride_map: (d0, d1, d2, d3) -> (1, 1, 1) rt_vars: - s0 in [0, 27) + s0 in [0, 26] hlo: %indices = s32[1806,2]{1,0} parameter(1) (d0, d1, d2, d3) -> (d0, 0) - s1 in [0, 69) + s1 in [0, 68] hlo: %indices = s32[1806,2]{1,0} parameter(1) (d0, d1, d2, d3) -> (d0, 1) )"))); @@ -694,10 +694,10 @@ TEST_F(SymbolicTileTest, CanCombineCompatibleConstraints) { size_map: (d0, d1) -> (1, (d0 + 5) floordiv 6, d0 - ((d0 - 1) floordiv 6) * 6, (d1 + 7) floordiv 8, d1 - ((d1 - 1) floordiv 8) * 8) stride_map: (d0, d1) -> (0, 1, 1, 1, 1) constraints: - 6 mod d0 in [0, 1) && 8 mod d1 in [0, 1) || - 6 mod d0 in [0, 1) && d1 mod 8 in [0, 1) || - 8 mod d1 in [0, 1) && d0 mod 6 in [0, 1) || - d0 mod 6 in [0, 1) && d1 mod 8 in [0, 1) + 6 mod d0 in [0, 0] && 8 mod d1 in [0, 0] || + 6 mod d0 in [0, 0] && d1 mod 8 in [0, 0] || + 8 mod d1 in [0, 0] && d0 mod 6 in [0, 0] || + d0 mod 6 in [0, 0] && d1 mod 8 in [0, 0] )"))); } @@ -720,7 +720,7 @@ TEST_F(SymbolicTileTest, offset_map: (d0, d1, d2) -> (0, 0) size_map: (d0, d1, d2) -> (d0 * d1, 50304) stride_map: (d0, d1, d2) -> (((-d1 + 2049) floordiv 2048) * ((-((-d0 + 5) floordiv 4) + 1) * 2048) + -((-d1 + 2049) floordiv 2048) + 1, 1) - constraints: d0 in [1, 2) || d1 in [1, 2) || d1 in [2048, 2049) + constraints: d0 in [1, 1] || d1 in [1, 1] || d1 in [2048, 2048] )"))); } @@ -801,7 +801,7 @@ TEST_F(ConstraintExpressionTest, PrettyPrintingTest) { constraints.Or(std::move(conjunction_1)); constraints.Or(std::move(conjunction_2)); EXPECT_THAT(constraints, MatchConstraintExpressionString( - "d0 in [0, 6) && d1 in [0, 6) || d2 in [0, 6)")); + "d0 in [0, 5] && d1 in [0, 5] || d2 in [0, 5]")); } TEST_F(ConstraintExpressionTest, @@ -809,11 +809,11 @@ TEST_F(ConstraintExpressionTest, ConstraintExpression constraints; constraints.And(GetConjointConstraints({{"d0", Interval{0, 5}}})); - EXPECT_THAT(constraints, MatchConstraintExpressionString("d0 in [0, 6)")); + EXPECT_THAT(constraints, MatchConstraintExpressionString("d0 in [0, 5]")); // Constraints are intersected. constraints.And(GetConjointConstraints({{"d0", Interval{3, 6}}})); - EXPECT_THAT(constraints, MatchConstraintExpressionString("d0 in [3, 6)")); + EXPECT_THAT(constraints, MatchConstraintExpressionString("d0 in [3, 5]")); // Empty intersection results in unsatisfiability. constraints.And(GetConjointConstraints({{"d0", Interval{7, 8}}})); @@ -862,7 +862,7 @@ TEST_F( constraints.Or(std::move(conjunction_1)); constraints.Or(std::move(conjunction_2)); EXPECT_THAT(constraints, - MatchConstraintExpressionString("d0 in [0, 6) || d1 in [0, 6)")); + MatchConstraintExpressionString("d0 in [0, 5] || d1 in [0, 5]")); // `conjunction_1` && `conjunction_3` is an unsatisfiable constraint. Taking // the conjunction of the existing constraint expression with `conjunction_3` @@ -873,7 +873,7 @@ TEST_F( constraints.And(std::move(conjunction_3)); EXPECT_THAT(constraints, - MatchConstraintExpressionString("d0 in [6, 7) && d1 in [0, 6)")); + MatchConstraintExpressionString("d0 in [6, 6] && d1 in [0, 5]")); // But becomes unsatisfiable if we eliminate the last remaining constraint by // constructing another unsatisfiable conjunction. @@ -1102,7 +1102,7 @@ TEST_F(ConstraintExpressionTest, constraints.Simplify(); EXPECT_THAT(constraints, - MatchConstraintExpressionString("d0 in [0, 2) && d1 in [0, 2)")); + MatchConstraintExpressionString("d0 in [0, 1] && d1 in [0, 1]")); } TEST_F(ConstraintExpressionTest, @@ -1119,7 +1119,7 @@ TEST_F(ConstraintExpressionTest, constraints.Simplify(); - EXPECT_THAT(constraints, MatchConstraintExpressionString("d0 in [0, 2)")); + EXPECT_THAT(constraints, MatchConstraintExpressionString("d0 in [0, 1]")); } TEST_F(ConstraintExpressionTest, SimplifyKeepsAlwaysSatisfiedUnchanged) { @@ -1156,11 +1156,11 @@ TEST_F(ConstraintExpressionTest, SimplifyRemovesRedundantConstraints) { constraints.Simplify(); - // We could simplify those contraints even further to `d0 in [0, 1)` by + // We could simplify those contraints even further to `d0 in [0, 0]` by // checking that one conjunction is a subset of the other, but we don't do // that yet. EXPECT_THAT(constraints, MatchConstraintExpressionString( - "d0 in [0, 1) || d0 in [0, 1) && d1 in [1, 2)")); + "d0 in [0, 0] || d0 in [0, 0] && d1 in [1, 1]")); } } // namespace From 1c560480a1b34e220c192cbb8f2d244ee5304abb Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Thu, 25 Jul 2024 04:22:15 -0700 Subject: [PATCH 142/376] [XLA:GPU] Reserve elements in hash set. PiperOrigin-RevId: 655905499 --- xla/service/gpu/model/symbolic_tile_analysis.cc | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/xla/service/gpu/model/symbolic_tile_analysis.cc b/xla/service/gpu/model/symbolic_tile_analysis.cc index ea69eadb079bac..6aed0c540ba910 100644 --- a/xla/service/gpu/model/symbolic_tile_analysis.cc +++ b/xla/service/gpu/model/symbolic_tile_analysis.cc @@ -191,6 +191,11 @@ class OrderedUniquePtrValueHashSet { return {*it, inserted}; } + void Reserve(int64_t n) { + hash_set_.reserve(n); + data_.reserve(n); + } + // Moves data out of the set. std::vector> ExtractData() { return std::move(data_); } @@ -492,6 +497,11 @@ SymbolicTileAnalysis::ComputeTiledHloInstructions( OrderedUniquePtrValueHashSet tiled_hlo_instructions_set; absl::flat_hash_map symbolic_to_tiled_hlo_map; + // The actual number of TiledHloInstructions can be smaller than the number of + // SymbolicTiledHloInstructions, because some instruction will be + // deduplicated, but we reserve to the upper bound to avoid reallocations and + // additional hash calculations. + tiled_hlo_instructions_set.Reserve(symbolic_tiled_hlo_instructions_.size()); std::function( const SymbolicTiledHloInstruction*)> From 656859cf4eccf7a1d42825f8126d52500635d6aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Bana=C5=9B?= Date: Thu, 25 Jul 2024 04:39:14 -0700 Subject: [PATCH 143/376] [XLA:CPU] Support `reshape` op in thunks runtime. In most cases `reshape` op is rewritten by `ReshapeDecomposer` pass. When it is not, current runtime defaults to elemental IR emitter. This CL aligns thunks runtime behavior with the current one. Also added a test case for reshape that runs without HLO passes. PiperOrigin-RevId: 655908939 --- xla/service/cpu/thunk_emitter.cc | 1 + xla/tests/BUILD | 2 ++ xla/tests/reshape_test.cc | 17 +++++++++++++++++ 3 files changed, 20 insertions(+) diff --git a/xla/service/cpu/thunk_emitter.cc b/xla/service/cpu/thunk_emitter.cc index b8194d82fd6f96..4f3d45e4b47307 100644 --- a/xla/service/cpu/thunk_emitter.cc +++ b/xla/service/cpu/thunk_emitter.cc @@ -225,6 +225,7 @@ absl::StatusOr ThunkEmitter::EmitHloInstruction( case HloOpcode::kReal: case HloOpcode::kReducePrecision: case HloOpcode::kRemainder: + case HloOpcode::kReshape: case HloOpcode::kReverse: case HloOpcode::kRoundNearestAfz: case HloOpcode::kRoundNearestEven: diff --git a/xla/tests/BUILD b/xla/tests/BUILD index 094f3d1ddf4e9e..49ed0744e11310 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -2099,11 +2099,13 @@ xla_test( tags = ["test_xla_cpu_thunks"], deps = [ ":client_library_test_base", + ":hlo_test_base", ":literal_test_util", ":test_macros_header", ":xla_internal_test_main", "//xla:array2d", "//xla:array4d", + "//xla:error_spec", "//xla:literal_util", "//xla:reference_util", "//xla:shape_util", diff --git a/xla/tests/reshape_test.cc b/xla/tests/reshape_test.cc index d25518e16e3437..9e3c09dd12ffc0 100644 --- a/xla/tests/reshape_test.cc +++ b/xla/tests/reshape_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/status/statusor.h" @@ -26,6 +27,7 @@ limitations under the License. #include "xla/client/local_client.h" #include "xla/client/xla_builder.h" #include "xla/client/xla_computation.h" +#include "xla/error_spec.h" #include "xla/layout_util.h" #include "xla/literal_util.h" #include "xla/reference_util.h" @@ -33,6 +35,7 @@ limitations under the License. #include "xla/status_macros.h" #include "xla/test.h" #include "xla/tests/client_library_test_base.h" +#include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/xla_data.pb.h" @@ -1021,5 +1024,19 @@ INSTANTIATE_TEST_CASE_P(ReshapeTestInstance, ReshapeTest, ::testing::ValuesIn(std::vector{false})); #endif +using ReshapeHloTest = HloTestBase; + +TEST_F(ReshapeHloTest, NoHloPasses) { + const std::string hlo_string = R"( + HloModule Bug, is_scheduled=true + + ENTRY entry { + %p0 = u32[1,35]{1,0} parameter(0) + %reshape.4 = u32[35]{0} reshape(u32[1,35]{1,0} %p0) + } + )"; + EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_string, ErrorSpec{0.01, 0.01})); +} + } // namespace } // namespace xla From 147b9bf88371114fc861f2fd282fa37d53d0711a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 25 Jul 2024 04:50:31 -0700 Subject: [PATCH 144/376] Automated Code Change PiperOrigin-RevId: 655911413 --- xla/pjrt/cpu/cpu_topology.h | 4 ---- xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.h | 1 - 2 files changed, 5 deletions(-) diff --git a/xla/pjrt/cpu/cpu_topology.h b/xla/pjrt/cpu/cpu_topology.h index ade4d7a018a8fb..eb337325758788 100644 --- a/xla/pjrt/cpu/cpu_topology.h +++ b/xla/pjrt/cpu/cpu_topology.h @@ -69,10 +69,6 @@ inline int UnpackCpuProcessIndex(PjRtGlobalDeviceId global_device_id) { return global_device_id.value() / kMaxCpuDevicesPerProcess; } -inline int UnpackCpuDeviceId(PjRtGlobalDeviceId global_device_id) { - return global_device_id.value() % kMaxCpuDevicesPerProcess; -} - } // namespace xla #endif // XLA_PJRT_CPU_CPU_TOPOLOGY_H_ diff --git a/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.h b/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.h index 04a8fa851a1fe5..4bfc1c57aed269 100644 --- a/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.h +++ b/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.h @@ -76,7 +76,6 @@ class MaybeOwningCpuMemory { void* data() const { return buf_; } size_t size() const { return size_; } - bool owns_data() const { return data_ != nullptr; } private: void* buf_ = nullptr; // Non-owning data pointer. From d286ad15f26d8f9cc660ba3ee3a53a5148d46d33 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Thu, 25 Jul 2024 04:57:31 -0700 Subject: [PATCH 145/376] Fix floordiv simplification bug. The inner divisor can't be reused after new_dividend was changed, since some ad-hoc simplifications may have been applied. PiperOrigin-RevId: 655912799 --- xla/service/gpu/model/indexing_map.cc | 17 ++++++++--------- xla/service/gpu/model/indexing_map_test.cc | 15 +++++++++++++++ 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/xla/service/gpu/model/indexing_map.cc b/xla/service/gpu/model/indexing_map.cc index ccb1cae91f0d20..f0d34e8871e8f2 100644 --- a/xla/service/gpu/model/indexing_map.cc +++ b/xla/service/gpu/model/indexing_map.cc @@ -293,15 +293,6 @@ AffineExpr AffineExprSimplifier::SimplifySumDiv(AffineExpr dividend, return expr; }); - std::optional inner_divisor = std::nullopt; - int num_inner_divisors = 0; - VisitSummands(new_dividend, [&](AffineExpr summand) { - if (auto divisor = GetConstantRhs(summand, AffineExprKind::FloorDiv)) { - inner_divisor = divisor; - ++num_inner_divisors; - } - }); - // Split `new_dividend` into `multiplied * multiplier_gcd + not_multiplied`. auto [multiplied, multiplier_gcd, not_multiplied] = SplitSumByGcd(new_dividend); @@ -338,6 +329,14 @@ AffineExpr AffineExprSimplifier::SimplifySumDiv(AffineExpr dividend, // If a0 is 16 and a1 is 2, the result is `(5 + 0) / 6 = 0`, whereas the // rewritten form `(a0 + a1) / 18` evaluates to 1. This can only happen when // there is more than one division. + std::optional inner_divisor = std::nullopt; + int num_inner_divisors = 0; + VisitSummands(new_dividend, [&](AffineExpr summand) { + if (auto divisor = GetConstantRhs(summand, AffineExprKind::FloorDiv)) { + inner_divisor = divisor; + ++num_inner_divisors; + } + }); if (num_inner_divisors == 1) { new_dividend = MapSummands(new_dividend, [&](AffineExpr summand) { if (auto inner_divisor = diff --git a/xla/service/gpu/model/indexing_map_test.cc b/xla/service/gpu/model/indexing_map_test.cc index cc6501f51bc7f7..3624a01eb44b6b 100644 --- a/xla/service/gpu/model/indexing_map_test.cc +++ b/xla/service/gpu/model/indexing_map_test.cc @@ -724,6 +724,21 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SumOrderRegression2) { EXPECT_FALSE(indexing_map.Simplify()); } +TEST_F(IndexingMapTest, AffineMapSimplification_FloorDivRegression) { + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap( + "(d0, d1) -> (((d0 floordiv 3) * 3 + d1 floordiv 2) floordiv 6)", + &mlir_context_), + {12, 6}, {}); + EXPECT_TRUE(indexing_map.Simplify()); + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + (d0, d1) -> (d0 floordiv 6) + domain: + d0 in [0, 11] + d1 in [0, 5] + )")); +} + TEST_F(IndexingMapTest, AffineMapSimplification_ModIsSub) { IndexingMap indexing_map( ParseAffineMap("(d0) -> (d0 mod 42)", &mlir_context_), {{53, 71}}, {}, From 540e996ca52ec9ae91d07177b8b083dddc5340cb Mon Sep 17 00:00:00 2001 From: Dirk Hornung Date: Thu, 25 Jul 2024 05:04:18 -0700 Subject: [PATCH 146/376] Use kernel_index from backend_config to select custom fusion kernel. PiperOrigin-RevId: 655914480 --- xla/service/gpu/backend_configs.proto | 5 ++ .../gpu/custom_kernel_fusion_rewriter.cc | 1 + .../gpu/custom_kernel_fusion_rewriter_test.cc | 4 +- .../gpu/dynamic_slice_fusion_rewriter_test.cc | 46 +++++++++---------- xla/service/gpu/fusions/custom.cc | 15 ++---- .../gpu/kernels/cutlass_gemm_fusion_test.cc | 18 ++++---- 6 files changed, 45 insertions(+), 44 deletions(-) diff --git a/xla/service/gpu/backend_configs.proto b/xla/service/gpu/backend_configs.proto index c21eb34a4ad3ab..1a048c7e40b5eb 100644 --- a/xla/service/gpu/backend_configs.proto +++ b/xla/service/gpu/backend_configs.proto @@ -148,6 +148,11 @@ message ReificationCost { // fusion computation). message CustomFusionConfig { string name = 1; + + // When a custom fusion has multiple kernels, this field specifies which + // kernel to use. Default value is to select the first kernel, i.e. + // kernel_index = 0. + int32 kernel_index = 2; } message CuDnnFusionConfig { diff --git a/xla/service/gpu/custom_kernel_fusion_rewriter.cc b/xla/service/gpu/custom_kernel_fusion_rewriter.cc index cf80ab5fcb3e4a..814ccf05003804 100644 --- a/xla/service/gpu/custom_kernel_fusion_rewriter.cc +++ b/xla/service/gpu/custom_kernel_fusion_rewriter.cc @@ -168,6 +168,7 @@ static absl::StatusOr CreateFusionInstruction( *gpu_config.mutable_fusion_backend_config(); backend_config.set_kind("__custom_fusion"); *backend_config.mutable_custom_fusion_config() = match.config(); + backend_config.mutable_custom_fusion_config()->set_kernel_index(0); TF_RETURN_IF_ERROR(fusion->set_backend_config(std::move(gpu_config))); // If we don't have workspace we can return constructed fusion instruction. diff --git a/xla/service/gpu/custom_kernel_fusion_rewriter_test.cc b/xla/service/gpu/custom_kernel_fusion_rewriter_test.cc index ac0d1464f3612d..f2c824cac7e1f7 100644 --- a/xla/service/gpu/custom_kernel_fusion_rewriter_test.cc +++ b/xla/service/gpu/custom_kernel_fusion_rewriter_test.cc @@ -79,7 +79,7 @@ TEST_F(CustomKernelFusionRewriterTest, SimpleGemm) { ; CHECK: kind=kCustom, calls=%simple_gemm, ; CHECK: backend_config={ ; CHECK: "kind":"__custom_fusion", - ; CHECK: "custom_fusion_config":{"name":"simple_gemm"} + ; CHECK: "custom_fusion_config":{"name":"simple_gemm","kernel_index":0} ; CHECK: } ; CHECK: } )"; @@ -121,7 +121,7 @@ TEST_F(CustomKernelFusionRewriterTest, SimpleGemmWithWorkspace) { ; CHECK: kind=kCustom, calls=%simple_gemm, ; CHECK: backend_config={ ; CHECK: "kind":"__custom_fusion", - ; CHECK: "custom_fusion_config":{"name":"simple_gemm"} + ; CHECK: "custom_fusion_config":{"name":"simple_gemm","kernel_index":0} ; CHECK: } ; CHECK: ROOT {{.*}} get-tuple-element([[FUSION]]), index=0 ; CHECK: } diff --git a/xla/service/gpu/dynamic_slice_fusion_rewriter_test.cc b/xla/service/gpu/dynamic_slice_fusion_rewriter_test.cc index 8a7f63297100d3..3d3eef1e4a3687 100644 --- a/xla/service/gpu/dynamic_slice_fusion_rewriter_test.cc +++ b/xla/service/gpu/dynamic_slice_fusion_rewriter_test.cc @@ -92,7 +92,7 @@ TEST_F(DynamicSliceFusionRewriterTest, SimpleGemm) { ; CHECK: kind=kCustom, calls=%address-computation, ; CHECK: backend_config={ ; CHECK: "kind":"__custom_fusion", - ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0} ; CHECK: } ; CHECK: } )"; @@ -156,7 +156,7 @@ TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmWithWorkspace) { ; CHECK: kind=kCustom, calls=%address-computation, ; CHECK: backend_config={ ; CHECK: "kind":"__custom_fusion", - ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0} ; CHECK: } ; CHECK: } )"; @@ -221,7 +221,7 @@ TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmWorkspaceIgnored) { ; CHECK: kind=kCustom, calls=%address-computation, ; CHECK: backend_config={ ; CHECK: "kind":"__custom_fusion", - ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0} ; CHECK: } ; CHECK: ROOT [[DOT_MAIN:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[FUSION]]), index=0 ; CHECK: } @@ -283,7 +283,7 @@ TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmNotRoot) { ; CHECK: kind=kCustom, calls=%address-computation, ; CHECK: backend_config={ ; CHECK: "kind":"__custom_fusion", - ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0} ; CHECK: } ; CHECK: ROOT {{.*}} = f16[8,8]{1,0} add([[FUSION]], [[FUSION]]) ; CHECK: } @@ -347,7 +347,7 @@ TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmOperandHasMultipleUsers) { ; CHECK-DAG: kind=kCustom, calls=%address-computation, ; CHECK-DAG: backend_config={ ; CHECK-DAG: "kind":"__custom_fusion", - ; CHECK-DAG: "custom_fusion_config":{"name":"address_computation"} + ; CHECK-DAG: "custom_fusion_config":{"name":"address_computation","kernel_index":0} ; CHECK-DAG: } ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]} ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]]) @@ -496,7 +496,7 @@ TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmSlicingNotParameter) { ; CHECK: kind=kCustom, calls=%address-computation, ; CHECK: backend_config={ ; CHECK: "kind":"__custom_fusion", - ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0} ; CHECK: } ; CHECK: ROOT {{.*}} = f16[8,8]{1,0} add([[FUSION]], [[FUSION]]) ; CHECK: } @@ -657,7 +657,7 @@ TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmDuplicateOperand) { ; CHECK: kind=kCustom, calls=%address-computation, ; CHECK: backend_config={ ; CHECK: "kind":"__custom_fusion", - ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0} ; CHECK: } ; CHECK: } )"; @@ -719,7 +719,7 @@ TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmReverseOperandOrder) { ; CHECK: kind=kCustom, calls=%address-computation, ; CHECK: backend_config={ ; CHECK: "kind":"__custom_fusion", - ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0} ; CHECK: } ; CHECK: } )"; @@ -781,7 +781,7 @@ TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmReverseOperandOrder2) { ; CHECK: kind=kCustom, calls=%address-computation, ; CHECK: backend_config={ ; CHECK: "kind":"__custom_fusion", - ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0} ; CHECK: } ; CHECK: } )"; @@ -844,7 +844,7 @@ TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmOperandAliasingOutput) { ; CHECK: kind=kCustom, calls=%address-computation, ; CHECK: backend_config={ ; CHECK: "kind":"__custom_fusion", - ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0} ; CHECK: } ; CHECK: } )"; @@ -901,7 +901,7 @@ TEST_F(DynamicSliceFusionRewriterTest, SimpleGemmOperandsFromSameSlice) { ; CHECK: kind=kCustom, calls=%address-computation, ; CHECK: backend_config={ ; CHECK: "kind":"__custom_fusion", - ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0} ; CHECK: } ; CHECK: } )"; @@ -963,7 +963,7 @@ TEST_F(DynamicSliceFusionRewriterTest, SimpleCustomCall) { ; CHECK: kind=kCustom, calls=%address-computation, ; CHECK: backend_config={ ; CHECK: "kind":"__custom_fusion", - ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0} ; CHECK: } ; CHECK: } )"; @@ -1016,7 +1016,7 @@ TEST_F(DynamicSliceFusionRewriterTest, SimpleCustomCallLegacy) { ; CHECK: kind=kCustom, calls=%address-computation, ; CHECK: backend_config={ ; CHECK: "kind":"__custom_fusion", - ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0} ; CHECK: } ; CHECK: } )"; @@ -1077,7 +1077,7 @@ TEST_F(DynamicSliceFusionRewriterTest, TupleSliceCustomCallLegacy) { ; CHECK: kind=kCustom, calls=%address-computation, ; CHECK: backend_config={ ; CHECK: "kind":"__custom_fusion", - ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0} ; CHECK: } ; CHECK: } )"; @@ -1157,7 +1157,7 @@ TEST_F(DynamicSliceFusionRewriterTest, TupledOutputCustomCallLegacy) { ; CHECK: kind=kCustom, calls=%address-computation, ; CHECK: backend_config={ ; CHECK: "kind":"__custom_fusion", - ; CHECK: "custom_fusion_config":{"name":"address_computation"} + ; CHECK: "custom_fusion_config":{"name":"address_computation","kernel_index":0} ; CHECK: } ; CHECK-DAG: [[GTE6:%[^ ]+]] = f32[1024]{0} get-tuple-element([[FUSION]]), index=2 ; CHECK-DAG: [[GTE7:%[^ ]+]] = (f32[128]{0}, f32[256]{0}) get-tuple-element([[FUSION]]), index=1 @@ -1254,7 +1254,7 @@ TEST_F(DynamicSliceFusionRewriterTest, DynamicSimpleGemm) { ; CHECK: kind=kCustom, calls=%address-computation, ; CHECK: backend_config={ ; CHECK: "kind":"__custom_fusion", - ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation"} + ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0} ; CHECK: } ; CHECK: } )"; @@ -1323,7 +1323,7 @@ TEST_F(DynamicSliceFusionRewriterTest, DynamicSimpleGemmWithWorkspace) { ; CHECK: kind=kCustom, calls=%address-computation, ; CHECK: backend_config={ ; CHECK: "kind":"__custom_fusion", - ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation"} + ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0} ; CHECK: } ; CHECK: } )"; @@ -1392,7 +1392,7 @@ TEST_F(DynamicSliceFusionRewriterTest, DynamicSimpleGemmWorkspaceIgnored) { ; CHECK: kind=kCustom, calls=%address-computation, ; CHECK: backend_config={ ; CHECK: "kind":"__custom_fusion", - ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation"} + ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0} ; CHECK: } ; CHECK: ROOT [[DOT_MAIN:%[^ ]+]] = f16[8,8]{1,0} get-tuple-element([[FUSION]]), index=0 ; CHECK: } @@ -1458,7 +1458,7 @@ TEST_F(DynamicSliceFusionRewriterTest, DynamicSimpleGemmNotRoot) { ; CHECK: kind=kCustom, calls=%address-computation, ; CHECK: backend_config={ ; CHECK: "kind":"__custom_fusion", - ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation"} + ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0} ; CHECK: } ; CHECK: ROOT {{.*}} = f16[8,8]{1,0} add([[FUSION]], [[FUSION]]) ; CHECK: } @@ -1522,7 +1522,7 @@ TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemm) { ; CHECK: kind=kCustom, calls=%address-computation, ; CHECK: backend_config={ ; CHECK: "kind":"__custom_fusion", - ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation"} + ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0} ; CHECK: } ; CHECK: } )"; @@ -1593,7 +1593,7 @@ TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmNotRoot) { ; CHECK: kind=kCustom, calls=%address-computation, ; CHECK: backend_config={ ; CHECK: "kind":"__custom_fusion", - ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation"} + ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0} ; CHECK: } ; CHECK: ROOT {{.*}} = f16[4,8,8]{2,1,0} log([[FUSION]]) ; CHECK: } @@ -1672,7 +1672,7 @@ TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmWithWorkspace) { ; CHECK: kind=kCustom, calls=%address-computation, ; CHECK: backend_config={ ; CHECK: "kind":"__custom_fusion", - ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation"} + ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0} ; CHECK: } ; CHECK: [[DUS_MAIN:%[^ ]+]] = f16[4,8,8]{2,1,0} get-tuple-element([[FUSION]]), index=0 ; CHECK: [[WORKSPACE_MAIN:%[^ ]+]] = s8[256]{0} get-tuple-element([[FUSION]]), index=1 @@ -1742,7 +1742,7 @@ TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmWorkspaceIgnored) { ; CHECK: kind=kCustom, calls=%address-computation, ; CHECK: backend_config={ ; CHECK: "kind":"__custom_fusion", - ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation"} + ; CHECK: "custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0} ; CHECK: } ; CHECK: ROOT [[DOT_MAIN:%[^ ]+]] = f16[4,8,8]{2,1,0} get-tuple-element([[FUSION]]), index=0 ; CHECK: } diff --git a/xla/service/gpu/fusions/custom.cc b/xla/service/gpu/fusions/custom.cc index 6e1ea915631701..30c92b9387f53a 100644 --- a/xla/service/gpu/fusions/custom.cc +++ b/xla/service/gpu/fusions/custom.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/algorithm/container.h" @@ -765,7 +764,7 @@ absl::StatusOr CustomFusion::Emit( fusion.backend_config()); const FusionBackendConfig& backend_config = gpu_config.fusion_backend_config(); - const auto& config = backend_config.custom_fusion_config(); + const CustomFusionConfig& config = backend_config.custom_fusion_config(); VLOG(3) << "Lower HLO fusion to a custom fusion " << config.name(); @@ -794,14 +793,10 @@ absl::StatusOr CustomFusion::Emit( " returned empty custom kernels for a fused computation")); } - // TODO(ezhulenev): Add support for auto tuning to select the best kernel. - if (kernels.size() != 1) { - return absl::InternalError("Expected exactly one custom kernel"); - } - - TF_ASSIGN_OR_RETURN( - auto thunk, BuildCustomKernelThunkForFusion(ir_emitter_context, fusion, - std::move(kernels[0]))); + TF_ASSIGN_OR_RETURN(auto thunk, + BuildCustomKernelThunkForFusion( + ir_emitter_context, fusion, + std::move(kernels[config.kernel_index()]))); FusionEmissionResult result; result.thunks.push_back(std::move(thunk)); diff --git a/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc b/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc index 1fec453765c4d8..a6488e9602045b 100644 --- a/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc +++ b/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc @@ -81,7 +81,7 @@ TEST_F(CutlassFusionTest, RowMajorGemm) { ; CHECK: kind=kCustom, calls=%cutlass_gemm, ; CHECK: backend_config={ ; CHECK: "kind":"__custom_fusion", - ; CHECK: "custom_fusion_config":{"name":"cutlass_gemm"} + ; CHECK: "custom_fusion_config":{"name":"cutlass_gemm","kernel_index":0} ; CHECK: } ; CHECK: } )"; @@ -121,7 +121,7 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithUpcast) { ; CHECK: kind=kCustom, calls=%cutlass_gemm_with_upcast, ; CHECK: backend_config={ ; CHECK: "kind":"__custom_fusion", - ; CHECK: "custom_fusion_config":{"name":"cutlass_gemm_with_upcast"} + ; CHECK: "custom_fusion_config":{"name":"cutlass_gemm_with_upcast","kernel_index":0} ; CHECK: } ; CHECK: } )"; @@ -170,7 +170,7 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithDynamicUpdateSlice) { ; CHECK: backend_config={ ; CHECK: "kind":"__custom_fusion", ; CHECK: "custom_fusion_config":{ - ; CHECK: "name":"cutlass_gemm_with_dynamic_update_slice" + ; CHECK: "name":"cutlass_gemm_with_dynamic_update_slice","kernel_index":0 ; CHECK: } ; CHECK: } ; CHECK: } @@ -224,7 +224,7 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithDynamicUpdateSliceMultipleUses) { ; CHECK: backend_config={ ; CHECK: "kind":"__custom_fusion", ; CHECK: "custom_fusion_config":{ - ; CHECK: "name":"cutlass_gemm_with_dynamic_update_slice" + ; CHECK: "name":"cutlass_gemm_with_dynamic_update_slice","kernel_index":0 ; CHECK: } ; CHECK: } ; CHECK: [[SLICE:%[^ ]+]] = f32[1,2,2]{2,1,0} dynamic-slice( @@ -275,7 +275,7 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithDynamicUpdateSliceWithoutBitcast) { ; CHECK: backend_config={ ; CHECK: "kind":"__custom_fusion", ; CHECK: "custom_fusion_config":{ - ; CHECK: "name":"cutlass_gemm_with_dynamic_update_slice" + ; CHECK: "name":"cutlass_gemm_with_dynamic_update_slice","kernel_index":0 ; CHECK: } ; CHECK: } ; CHECK: } @@ -322,7 +322,7 @@ TEST_F(CutlassFusionTest, RowMajorGemmKernel) { arg0 = f32[100,784]{1,0} parameter(0) arg1 = f32[784,10]{1,0} parameter(1) ROOT _ = f32[100,10]{1,0} fusion(arg0, arg1), kind=kCustom, calls=cutlass_gemm, - backend_config={"fusion_backend_config":{kind: "__custom_fusion", custom_fusion_config: {"name":"cutlass_gemm"}}} + backend_config={"fusion_backend_config":{kind: "__custom_fusion", custom_fusion_config: {"name":"cutlass_gemm", "kernel_index":0}}} })"; EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_cublas, hlo_text_custom_fusion, @@ -362,7 +362,7 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithUpcastKernel) { p0 = bf16[16,32]{1,0} parameter(0) p1 = s8[32,8]{1,0} parameter(1) ROOT _ = bf16[16,8]{1,0} fusion(p0, p1), kind=kCustom, calls=cutlass_gemm_with_upcast, - backend_config={"fusion_backend_config":{kind: "__custom_fusion", custom_fusion_config: {"name":"cutlass_gemm_with_upcast"}}} + backend_config={"fusion_backend_config":{kind: "__custom_fusion", custom_fusion_config: {"name":"cutlass_gemm_with_upcast", "kernel_index":0}}} })"; EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_cublas, hlo_text_custom_fusion, @@ -419,7 +419,7 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithDynamicUpdateSliceKernel) { p3 = s32[] parameter(3) r.0 = (bf16[2,8,8]{2,1,0}, u8[1024]{0}) fusion(p1, p0, p2, p3), kind=kCustom, calls=%cutlass_gemm, - backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"cutlass_gemm_with_dynamic_update_slice"}}} + backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"cutlass_gemm_with_dynamic_update_slice", "kernel_index":0}}} ROOT %get-tuple-element = bf16[2,8,8]{2,1,0} get-tuple-element(r.0), index=0 })"; @@ -492,7 +492,7 @@ TEST_F(CutlassFusionTest, p3 = s32[] parameter(3) r.0 = (bf16[16,8]{1,0}, u8[1024]{0}) fusion(p1, p0, p2, p3), kind=kCustom, calls=%cutlass_gemm, - backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"cutlass_gemm_with_dynamic_update_slice"}}} + backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"cutlass_gemm_with_dynamic_update_slice", "kernel_index":0}}} ROOT %get-tuple-element = bf16[16,8]{1,0} get-tuple-element(r.0), index=0 })"; From 9d31fa14a70bd43e10817fa3a5f48b25f98c6da5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 25 Jul 2024 05:11:43 -0700 Subject: [PATCH 147/376] [XLA:GPU] Comment properly sparse-local-load-to-llvm pass PiperOrigin-RevId: 655916198 --- .../gpu/fusions/triton/sparse_extensions.cc | 123 ++++++++++-------- 1 file changed, 66 insertions(+), 57 deletions(-) diff --git a/xla/service/gpu/fusions/triton/sparse_extensions.cc b/xla/service/gpu/fusions/triton/sparse_extensions.cc index 037631975e3486..8b2f1aba7ee14d 100644 --- a/xla/service/gpu/fusions/triton/sparse_extensions.cc +++ b/xla/service/gpu/fusions/triton/sparse_extensions.cc @@ -362,19 +362,17 @@ class SparseLocalLoadToLLVM LogicalResult matchAndRewrite( triton::gpu::LocalLoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - MemDescType srcTy = op.getSrc().getType(); - RankedTensorType dstTy = op.getType(); - Attribute srcLayout = srcTy.getEncoding(); - Attribute dstLayout = dstTy.getEncoding(); - if (isa(srcLayout) && - isa(dstLayout)) { - return lowerSharedToSparseMeta(op, adaptor, rewriter); - } - return failure(); + MemDescType src_ty = op.getSrc().getType(); + if (!isa(src_ty.getEncoding())) + return failure(); + RankedTensorType dst_ty = op.getType(); + if (!isa(dst_ty.getEncoding())) + return failure(); + return lowerSharedToSparseMeta(op, adaptor, rewriter); } private: - // shared -> sparse dot meta + // lowering metadata (local_load: shared -> sparse dot meta) to LLVM LogicalResult lowerSharedToSparseMeta( triton::gpu::LocalLoadOp op, triton::gpu::LocalLoadOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -384,75 +382,86 @@ class SparseLocalLoadToLLVM // such values in a register (32-bit). // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#sparse-matrix-storage constexpr int kTileSize = 16; - constexpr int kThreadsInGroup = 4; - constexpr int kMetadataElementsPerPackedValue = 8; // 8 x 2-bit = 16-bit - constexpr int kMetadataLineOffset = kThreadsPerWarp / kThreadsInGroup; + constexpr int kMetaElementsBitSize = 2; + // Metadata elements are packed into 16-bits values. + constexpr int kMetaElementsPerPackedValue = 16 / kMetaElementsBitSize; + constexpr int kColumnsPerCtaTile = kTileSize / kMetaElementsPerPackedValue; auto loc = op.getLoc(); - Value tensor = op.getSrc(); - auto sparseEncoding = cast( + auto load_sparse_encoding = cast( cast(op.getResult().getType()).getEncoding()); - auto llvmElemTy = getTypeConverter()->convertType( - cast(op.getSrc().getType()).getElementType()); - auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), - llvmElemTy, rewriter); // Calculate tile size as number of mask elements (4xi4). - NvidiaMmaEncodingAttr mmaLayout = - cast(sparseEncoding.getParent()); - SmallVector warpsPerCTA = mmaLayout.getWarpsPerCTA(); - SmallVector shapePerCTATile = { - kTileSize * warpsPerCTA[0], - kTileSize / kMetadataElementsPerPackedValue}; - Value strideM = smemObj.strides[0]; - Value strideK = smemObj.strides[1]; + NvidiaMmaEncodingAttr mma_layout = + cast(load_sparse_encoding.getParent()); + SmallVector warps_per_cta = mma_layout.getWarpsPerCTA(); // Calculate offset in the tile for the current thread. - Value threadsPerWarp = i32_val(kThreadsPerWarp); - Value thread = getThreadId(rewriter, loc); - Value warpId = udiv(thread, threadsPerWarp); - Value warpGroupId; - if (mmaLayout.isHopper()) { - warpGroupId = urem(warpId, i32_val(warpsPerCTA[0])); + Value threads_per_warp = i32_val(kThreadsPerWarp); + Value thread_id = getThreadId(rewriter, loc); + Value warp_id = udiv(thread_id, threads_per_warp); + Value warp_group_id; + if (mma_layout.isHopper()) { + // Hopper MMA instructions force a warp order of [0, 1]. See docs: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-wgmma-mma-async-m64nnk8 + warp_group_id = urem(warp_id, i32_val(warps_per_cta[0])); } else { - assert(mmaLayout.isAmpere()); - warpGroupId = udiv(warpId, i32_val(warpsPerCTA[1])); + assert(mma_layout.isAmpere() && + "SparseDot is only supported on Ampere and Hopper"); + warp_group_id = udiv(warp_id, i32_val(warps_per_cta[1])); } - Value laneId = urem(thread, threadsPerWarp); - Value laneGroupId = udiv(laneId, i32_val(kThreadsInGroup)); - Value columnId = urem(laneId, i32_val(shapePerCTATile[1])); - Value rowId = add(mul(warpGroupId, i32_val(kTileSize)), laneGroupId); + // Calculate row and column id, based on mma.sp.sync.aligned.m16n8k32: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#sparse-mma-metadata-16832-f16bf16. + // column-id takes into consideration that we pack elements for metadata. + constexpr int kThreadsInGroup = 4; + constexpr int kMetadataLineOffset = kThreadsPerWarp / kThreadsInGroup; + Value lane_id = urem(thread_id, threads_per_warp); + Value lane_group_id = udiv(lane_id, i32_val(kThreadsInGroup)); + Value row_id = add(mul(warp_group_id, i32_val(kTileSize)), lane_group_id); + SmallVector shape_per_cta_tile = {kTileSize * warps_per_cta[0], + kColumnsPerCtaTile}; + Value column_id = urem(lane_id, i32_val(shape_per_cta_tile[1])); // Calculate number of tile repetitions. + Value tensor = op.getSrc(); auto shape = cast(tensor.getType()).getShape(); - int repM = shape[0] / shapePerCTATile[0]; - int repK = shape[1] / shapePerCTATile[1]; - assert(repM > 0 && repK > 0); + int rep_m = shape[0] / shape_per_cta_tile[0]; + int rep_k = shape[1] / shape_per_cta_tile[1]; + assert(rep_m > 0 && rep_k > 0); // Load sparse metadata from shared memory. + auto elem_ty = getTypeConverter()->convertType( + cast(tensor.getType()).getElementType()); + auto s_mem_obj = LLVM::getSharedMemoryObjectFromStruct( + loc, adaptor.getSrc(), elem_ty, rewriter); + Value stride_m = s_mem_obj.strides[0]; + Value stride_k = s_mem_obj.strides[1]; MLIRContext *ctx = tensor.getContext(); - Type ptrTy = ptr_ty(ctx, 3); - Value base = gep(ptrTy, i16_ty, smemObj.base, i32_val(0)); + Type ptr_ty = ptr_ty(ctx, 3); + Value base = gep(ptr_ty, i16_ty, s_mem_obj.base, i32_val(0)); SmallVector values; - for (int k = 0; k < repK; ++k) { - for (int m = 0; m < repM; ++m) { - Value row = add(rowId, i32_val(m * shapePerCTATile[0])); - Value column = add(columnId, i32_val(k * shapePerCTATile[1])); - Value offset1 = add(mul(row, strideM), mul(column, strideK)); - Value offset2 = - add(offset1, mul(i32_val(kMetadataLineOffset), strideM)); - Value lower = load(i16_ty, gep(ptrTy, i16_ty, base, offset1)); - Value upper = load(i16_ty, gep(ptrTy, i16_ty, base, offset2)); + for (int k = 0; k < rep_k; ++k) { + for (int m = 0; m < rep_m; ++m) { + // Each thread processes two different rows. + Value row_lower = add(row_id, i32_val(m * shape_per_cta_tile[0])); + Value row_upper = add(row_lower, i32_val(kMetadataLineOffset)); + Value column = add(column_id, i32_val(k * shape_per_cta_tile[1])); + Value offset_lower = + add(mul(row_lower, stride_m), mul(column, stride_k)); + Value offset_upper = + add(mul(row_upper, stride_m), mul(column, stride_k)); + Value lower = load(i16_ty, gep(ptr_ty, i16_ty, base, offset_lower)); + Value upper = load(i16_ty, gep(ptr_ty, i16_ty, base, offset_upper)); values.push_back(lower); values.push_back(upper); } } // Pack resulting values as LLVM struct. - Type structTy = struct_ty(SmallVector(values.size(), i16_ty)); + Type struct_ty = struct_ty(SmallVector(values.size(), i16_ty)); Value res = - packLLElements(loc, getTypeConverter(), values, rewriter, structTy); + packLLElements(loc, getTypeConverter(), values, rewriter, struct_ty); rewriter.replaceOp(op, res); return success(); @@ -489,8 +498,8 @@ class SparseLocalLoadToLLVMPass // we write the local load op to LLVM to have barriers in the right place. // See b/351986109. ModuleAllocation allocation(getOperation()); - ModuleMembarAnalysis membarPass(&allocation); - membarPass.run(); + ModuleMembarAnalysis membar_pass(&allocation); + membar_pass.run(); MLIRContext *context = &getContext(); ConversionTarget target(*context); From e8732e01b7e6bfcc793585f3d48ad72d45473624 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Bana=C5=9B?= Date: Thu, 25 Jul 2024 05:21:41 -0700 Subject: [PATCH 148/376] Fix `llvm_compiler_test` debug compilation. When trying to build this test for LLDB debugging purposes, there is a linker error: `ld: error: duplicate symbol: main`. This error is only observed with `--dynamic_mode=off` bazel flag. Removing gunit from dependencies fixes that issue. PiperOrigin-RevId: 655918499 --- xla/tests/BUILD | 2 +- xla/tests/llvm_compiler_test.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/xla/tests/BUILD b/xla/tests/BUILD index 49ed0744e11310..8bea9b455720d8 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -2642,10 +2642,10 @@ xla_test( "//xla/service:llvm_compiler", "//xla/stream_executor", "@com_google_absl//absl/status", - "@com_google_googletest//:gtest_main", "@llvm-project//llvm:Core", "@tsl//tsl/platform:casts", "@tsl//tsl/platform:env", + "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], ) diff --git a/xla/tests/llvm_compiler_test.cc b/xla/tests/llvm_compiler_test.cc index 238482b650c3d6..bf0e52d59f55f3 100644 --- a/xla/tests/llvm_compiler_test.cc +++ b/xla/tests/llvm_compiler_test.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include -#include #include "absl/status/status.h" #include "llvm/IR/Module.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -33,6 +32,7 @@ limitations under the License. #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/casts.h" +#include "tsl/platform/test.h" #include "tsl/platform/threadpool.h" namespace xla { From af3af5c0717ea1aa8e49a9eaf48ba49bb88e7ca1 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Thu, 25 Jul 2024 05:38:28 -0700 Subject: [PATCH 149/376] [XLA:GPU][MLIR-based emitters] Add a pattern that can shrink/refine bounds. Right now, it can refine ranges of dims/symbols based on the ranges of the induction variables. This is needed to update apply_indexing ops after the peeling is done. PiperOrigin-RevId: 655922249 --- xla/service/gpu/fusions/mlir/passes.h | 5 + .../gpu/fusions/mlir/simplify_affine.cc | 8 ++ .../gpu/fusions/mlir/simplify_arith.cc | 132 ++++++++++++------ .../fusions/mlir/tests/simplify_arith.mlir | 53 +++++++ .../gpu/fusions/reduction_mlir_test.cc | 2 +- 5 files changed, 157 insertions(+), 43 deletions(-) diff --git a/xla/service/gpu/fusions/mlir/passes.h b/xla/service/gpu/fusions/mlir/passes.h index 74481990af0021..fb42134231a9dd 100644 --- a/xla/service/gpu/fusions/mlir/passes.h +++ b/xla/service/gpu/fusions/mlir/passes.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" #include "xla/service/gpu/model/indexing_map.h" @@ -31,6 +32,10 @@ namespace gpu { // Returns the range of a given value, if it can be statically determined. std::optional GetRange(mlir::Value value); +// Returns the range for the induction variable, if it can be statically +// determined. +std::optional GetIVRange(mlir::Value iv); + std::unique_ptr CreateEraseDeadFunctionsPass(); std::unique_ptr CreateExpandFloatOpsPass(bool pre_ampere); std::unique_ptr CreateConvertPureCallOpsPass(); diff --git a/xla/service/gpu/fusions/mlir/simplify_affine.cc b/xla/service/gpu/fusions/mlir/simplify_affine.cc index 4e78dea5704089..7b234998860d37 100644 --- a/xla/service/gpu/fusions/mlir/simplify_affine.cc +++ b/xla/service/gpu/fusions/mlir/simplify_affine.cc @@ -341,7 +341,15 @@ std::optional GetRange(mlir::Value value) { if (auto func_op = mlir::dyn_cast(parent)) { return attr_to_range(func_op.getArgAttr(bbarg.getArgNumber(), "xla.range")); } + return GetIVRange(value); +} +std::optional GetIVRange(mlir::Value iv) { + auto bbarg = mlir::dyn_cast(iv); + if (!bbarg) { + return std::nullopt; + } + auto parent = bbarg.getParentBlock()->getParentOp(); if (auto for_op = mlir::dyn_cast(parent)) { llvm::APInt lb, ub; if (mlir::matchPattern(for_op.getLowerBound(), mlir::m_ConstantInt(&lb)) && diff --git a/xla/service/gpu/fusions/mlir/simplify_arith.cc b/xla/service/gpu/fusions/mlir/simplify_arith.cc index 833d09bcdaf161..77b1d7c40ea290 100644 --- a/xla/service/gpu/fusions/mlir/simplify_arith.cc +++ b/xla/service/gpu/fusions/mlir/simplify_arith.cc @@ -12,14 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include + #include -#include #include #include #include #include -#include #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -31,44 +29,50 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/passes.h" #include "xla/service/gpu/model/indexing_map.h" namespace xla { namespace gpu { +namespace { #define GEN_PASS_DEF_SIMPLIFYARITHPASS #include "xla/service/gpu/fusions/mlir/passes.h.inc" -namespace { +using mlir::LogicalResult; +using mlir::OpRewritePattern; +using mlir::PatternRewriter; +using mlir::arith::CmpIOp; +using mlir::arith::CmpIPredicate; -Interval::ComparisonResult EvaluateCmpI(mlir::arith::CmpIPredicate pred, - Interval lhs, Interval rhs) { +Interval::ComparisonResult EvaluateCmpI(CmpIPredicate pred, Interval lhs, + Interval rhs) { switch (pred) { - case mlir::arith::CmpIPredicate::eq: + case CmpIPredicate::eq: return lhs.Eq(rhs); - case mlir::arith::CmpIPredicate::ne: + case CmpIPredicate::ne: return lhs.Ne(rhs); - case mlir::arith::CmpIPredicate::slt: - case mlir::arith::CmpIPredicate::ult: + case CmpIPredicate::slt: + case CmpIPredicate::ult: return lhs.Lt(rhs); - case mlir::arith::CmpIPredicate::sle: - case mlir::arith::CmpIPredicate::ule: + case CmpIPredicate::sle: + case CmpIPredicate::ule: return lhs.Le(rhs); - case mlir::arith::CmpIPredicate::sgt: - case mlir::arith::CmpIPredicate::ugt: + case CmpIPredicate::sgt: + case CmpIPredicate::ugt: return lhs.Gt(rhs); - case mlir::arith::CmpIPredicate::sge: - case mlir::arith::CmpIPredicate::uge: + case CmpIPredicate::sge: + case CmpIPredicate::uge: return lhs.Ge(rhs); } } -struct RewriteCmpI : mlir::OpRewritePattern { +struct RewriteCmpI : OpRewritePattern { using OpRewritePattern::OpRewritePattern; - mlir::LogicalResult matchAndRewrite( - mlir::arith::CmpIOp op, mlir::PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(CmpIOp op, + PatternRewriter& rewriter) const override { auto rhs = GetRange(op.getRhs()); auto lhs = GetRange(op.getLhs()); if (!lhs || !rhs) { @@ -85,11 +89,11 @@ struct RewriteCmpI : mlir::OpRewritePattern { } }; -struct RewriteMaxSi : mlir::OpRewritePattern { +struct RewriteMaxSi : OpRewritePattern { using OpRewritePattern::OpRewritePattern; - mlir::LogicalResult matchAndRewrite( - mlir::arith::MaxSIOp op, mlir::PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(mlir::arith::MaxSIOp op, + PatternRewriter& rewriter) const override { auto lhs = GetRange(op.getLhs()); auto rhs = GetRange(op.getRhs()); if (!lhs || !rhs) { @@ -106,11 +110,11 @@ struct RewriteMaxSi : mlir::OpRewritePattern { } }; -struct RewriteMinSi : mlir::OpRewritePattern { +struct RewriteMinSi : OpRewritePattern { using OpRewritePattern::OpRewritePattern; - mlir::LogicalResult matchAndRewrite( - mlir::arith::MinSIOp op, mlir::PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(mlir::arith::MinSIOp op, + PatternRewriter& rewriter) const override { auto lhs = GetRange(op.getLhs()); auto rhs = GetRange(op.getRhs()); if (!lhs || !rhs) { @@ -157,11 +161,11 @@ mlir::Value FindNarrowestValueInChain(mlir::Value value) { // can be rewritten to shuffle-trunc-ext. If there is another copy of the // pattern afterwards, we can push the truncs/exts further down. template -struct RewriteTruncBitExt : mlir::OpRewritePattern { - using mlir::OpRewritePattern::OpRewritePattern; +struct RewriteTruncBitExt : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - mlir::LogicalResult matchAndRewrite( - Op op, mlir::PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(Op op, + PatternRewriter& rewriter) const override { mlir::Value lhs = FindNarrowestValueInChain(op.getLhs()); mlir::Value rhs = FindNarrowestValueInChain(op.getRhs()); @@ -198,12 +202,11 @@ struct RewriteTruncBitExt : mlir::OpRewritePattern { // Rewrites trunc-ext-shuffle to shuffle-trunc-ext. This pattern is designed to // work together with RewriteTruncBitExt to optimize pred reductions. -struct RewriteTruncExtShuffle - : public mlir::OpRewritePattern { +struct RewriteTruncExtShuffle : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - mlir::LogicalResult matchAndRewrite( - mlir::gpu::ShuffleOp op, mlir::PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(mlir::gpu::ShuffleOp op, + PatternRewriter& rewriter) const override { auto ext = op.getOperand(0).getDefiningOp(); if (!ext) { return rewriter.notifyMatchFailure(op, "no ext"); @@ -268,19 +271,64 @@ void AnnotateRanges(mlir::func::FuncOp func) { }); } +// Pattern to refine the bounds of an indexing map if some of its operands are +// bound, e.g. loop induction variables. +struct RefineConstraints : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op, + PatternRewriter& rewriter) const override { + // Right now, we only handle loop induction variables, but other rules might + // be added. + IndexingMap indexing_map = indexing_op.getIndexingMap(); + int64_t dim_count = indexing_map.GetDimensionCount(); + bool updated_bounds = false; + for (mlir::OpOperand& operand : indexing_op->getOpOperands()) { + auto range = GetIVRange(operand.get()); + if (!range) { + continue; + } + auto operand_id = operand.getOperandNumber(); + Interval& current_interval = + operand_id < dim_count + ? indexing_map.GetMutableDimensionBound(operand_id) + : indexing_map.GetMutableSymbolBound(operand_id - dim_count); + if (!range->Contains(current_interval)) { + current_interval = current_interval.Intersect(*range); + updated_bounds = true; + } + } + if (!updated_bounds) { + return rewriter.notifyMatchFailure(indexing_op, "No bounds to refine"); + } + indexing_map.Simplify(); + rewriter.replaceOpWithNewOp( + indexing_op, indexing_op.getOperands(), indexing_map); + return mlir::success(); + } +}; + class SimplifyArithPass : public impl::SimplifyArithPassBase { public: void runOnOperation() override { - mlir::RewritePatternSet patterns(&getContext()); - AnnotateRanges(getOperation()); - patterns.add(&getContext()); - patterns - .add, - RewriteTruncBitExt, RewriteTruncExtShuffle>( - &getContext()); - if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + auto ctx = &getContext(); + auto func = getOperation(); + mlir::RewritePatternSet patterns(ctx); + AnnotateRanges(func); + // clang-format off + patterns.add< + RefineConstraints, + RewriteCmpI, + RewriteMaxSi, + RewriteMinSi, + RewriteTruncBitExt, + RewriteTruncBitExt, + RewriteTruncExtShuffle + >(ctx); + // clang-format on + if (mlir::failed( + mlir::applyPatternsAndFoldGreedily(func, std::move(patterns)))) { signalPassFailure(); } } diff --git a/xla/service/gpu/fusions/mlir/tests/simplify_arith.mlir b/xla/service/gpu/fusions/mlir/tests/simplify_arith.mlir index 50b81dc5e2fadf..ee2e0ddbe29035 100644 --- a/xla/service/gpu/fusions/mlir/tests/simplify_arith.mlir +++ b/xla/service/gpu/fusions/mlir/tests/simplify_arith.mlir @@ -237,3 +237,56 @@ module { // CHECK-NEXT: extui // CHECK-NEXT: ori // CHECK-NEXT: return + +// ----- + +func.func @refine_constraints(%tensor: tensor<100xf32>) -> tensor<100xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %c42_f32 = arith.constant 42.0 : f32 + %loop = scf.for %i = %c0 to %c3 step %c1 + iter_args(%in_ = %tensor) -> (tensor<100xf32>) { + %0 = xla_gpu.apply_indexing affine_map<(d0) -> (d0 mod 4)> (%i in [0, 9]) + %updated = tensor.insert %c42_f32 into %in_[%0] : tensor<100xf32> + scf.yield %updated :tensor<100xf32> + } + func.return %loop : tensor<100xf32> +} +// CHECK-LABEL: func.func @refine_constraints +// CHECK: %[[CST:.*]] = arith.constant 4.2 +// CHECK: scf.for +// CHECK: tensor.insert %[[CST]] + + +// ----- + +#map = affine_map<(d0, d1)[s0, s1] -> (((d0 * 4 + d1 * 512 + s1) floordiv 9 + s0 * 32768) mod 2400000)> +#map1 = affine_map<(d0, d1)[s0] -> ((d0 * 4 + d1 * 512 + s0) mod 9)> +func.func @refine_constraints_for_symbol(%arg0: tensor<2400000x9xf32>, + %arg1: tensor<2400000x9xf32>) -> tensor<2400000x9xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c73 = arith.constant 73 : index + %c42_f32 = arith.constant 42.0 : f32 + %th_x = gpu.thread_id x {xla.range = [0 : index, 127 : index]} + %bl_x = gpu.block_id x {xla.range = [0 : index, 575 : index]} + %0 = scf.for %i = %c0 to %c73 step %c1 iter_args(%arg3 = %arg1) + -> (tensor<2400000x9xf32>) { + %2 = scf.for %j = %c0 to %c4 step %c1 iter_args(%arg5 = %arg3) + -> (tensor<2400000x9xf32>) { + %3 = xla_gpu.apply_indexing #map(%th_x in [0, 127], %bl_x in [0, 575]) + [%i in [0, 73], %j in [0, 3]] + %4 = xla_gpu.apply_indexing #map1(%th_x in [0, 127], %bl_x in [0, 575]) + [%j in [0, 3]] + %inserted = tensor.insert %c42_f32 into %arg5[%3, %4] + : tensor<2400000x9xf32> + scf.yield %inserted : tensor<2400000x9xf32> + } + scf.yield %2 : tensor<2400000x9xf32> + } + return %0 : tensor<2400000x9xf32> +} +// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> ((d0 * 4 + d1 * 512 + s1) floordiv 9 + s0 * 32768)> +// CHECK-LABEL: func.func @refine_constraints_for_symbol diff --git a/xla/service/gpu/fusions/reduction_mlir_test.cc b/xla/service/gpu/fusions/reduction_mlir_test.cc index c583088a245beb..6ba7431530309e 100644 --- a/xla/service/gpu/fusions/reduction_mlir_test.cc +++ b/xla/service/gpu/fusions/reduction_mlir_test.cc @@ -408,7 +408,7 @@ TEST_F(MlirRowReductionTest, NonPowerOfTwoRowReduction) { // CHECK: %[[FULL_TILES:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C4]] step %[[C1]] // CHECK-NEXT: scf.for %[[J:.*]] = %[[C0]] to %[[C2]] step %[[C1]] // CHECK-NOT: scf.if - // CHECK: xla_gpu.apply_indexing #[[MAP1]](%[[J]] in [0, 1], %thread_id_x in [0, 255])[%[[I]] in [0, 4]] + // CHECK: xla_gpu.apply_indexing #[[MAP1]](%[[J]] in [0, 1], %thread_id_x in [0, 255])[%[[I]] in [0, 3]] // CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%{{.*}} = %[[FULL_TILES]]) // CHECK: scf.if // CHECK: xla_gpu.apply_indexing #[[MAP2]](%[[J]] in [0, 1], %thread_id_x in [0, 255]) From de94b1e2116b5aba98a290da8644af6a64678baf Mon Sep 17 00:00:00 2001 From: Mohammed Anany Date: Thu, 25 Jul 2024 06:32:51 -0700 Subject: [PATCH 150/376] [XLA:GPU][Triton] Properly emit float to int conversions in ir_emitter_triton.cc PiperOrigin-RevId: 655934406 --- .../fusions/triton/triton_fusion_emitter.cc | 38 +++++++++++++++++- ...riton_fusion_emitter_device_legacy_test.cc | 40 +++++++++++++++++++ 2 files changed, 76 insertions(+), 2 deletions(-) diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc index 5d3fa949307bfe..20fa19c0a9fac6 100644 --- a/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc +++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc @@ -340,12 +340,46 @@ Value Cast(ImplicitLocOpBuilder& b, Value value, Type dst_element_ty) { } // float => int if (src_fp_element_ty && mlir::isa(dst_element_ty)) { - // TODO(b/266862493): Support unsigned integer types. if (dst_element_ty.isInteger(1)) { return b.create(ma::CmpFPredicate::UNE, value, ZerosLike(b, value)); } - return b.create(dst_ty, value); + // TODO(b/266862493): Support unsigned integer types. + // The current logic handles signed integer types only. Additional handling + // is needed for unsigned integer types. + auto cst_int = [&](int64_t x) { + if (auto src_shaped_ty = mlir::dyn_cast(src_ty)) { + return CreateConst(b, dst_element_ty, x, src_shaped_ty.getShape()); + } else { + return CreateConst(b, dst_element_ty, x); + } + }; + auto cst_float = [&](int64_t x) { + if (auto src_shaped_ty = mlir::dyn_cast(src_ty)) { + return CreateConst(b, src_fp_element_ty, x, src_shaped_ty.getShape()); + } else { + return CreateConst(b, src_fp_element_ty, x); + } + }; + auto fptosi = b.create(dst_ty, value); + int64_t min = llvm::minIntN(dst_element_ty.getIntOrFloatBitWidth()); + int64_t max = llvm::maxIntN(dst_element_ty.getIntOrFloatBitWidth()); + + // value <= static_cast(INT_MIN) ? INT_MIN : ... + auto clamped = b.create( + b.create(mlir::arith::CmpFPredicate::OLE, value, + cst_float(min)), + cst_int(min), fptosi); + // value >= static_cast(INT_MAX) ? INT_MAX : ... + clamped = b.create( + b.create(mlir::arith::CmpFPredicate::OGE, value, + cst_float(max)), + cst_int(max), clamped); + // isnan(value) ? 0 : ... + return b.create( + b.create(mlir::arith::CmpFPredicate::UNO, value, + value), + cst_int(0), clamped); } LOG(FATAL) << "Type conversion not supported: " diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc index 3eec1c4e934946..e88f9d784124a3 100644 --- a/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc +++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc @@ -1183,6 +1183,46 @@ ENTRY e { )"); } +TEST_F(TritonTest, FloatToSignedIntConversion) { + const std::string kHloText = R"( +HloModule t, is_scheduled=true + +triton_gemm_r { + p_0 = s8[32,32]{1,0} parameter(0) + p_1 = f16[32,32]{1,0} parameter(1) + cvt_1 = s8[32,32]{1,0} convert(p_1) + ROOT r.1 = f32[32,32]{1,0} dot(p_0, cvt_1), + lhs_contracting_dims={1}, rhs_contracting_dims={1} +} + +ENTRY e { + p_0 = s8[32,32]{1,0} parameter(0) + p_1 = f16[32,32]{1,0} parameter(1) + ROOT triton_gemm_r = f32[32,32]{1,0} fusion(p_0, p_1), kind=kCustom, + calls=triton_gemm_r, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":32,"block_n":32,"block_k":32, + "split_k":1,"num_stages":1,"num_warps":4, + "num_ctas":1}}} +})"; + TF_EXPECT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_gemm_r", R"( +CHECK: tt.func @triton_fn +CHECK-DAG: %[[ZERO:.*]] = arith.constant dense<0> +CHECK-DAG: %[[FMIN:.*]] = arith.constant dense<-1.280000e+02> +CHECK-DAG: %[[IMIN:.*]] = arith.constant dense<-128> +CHECK-DAG: %[[FMAX:.*]] = arith.constant dense<1.270000e+02> +CHECK-DAG: %[[IMAX:.*]] = arith.constant dense<127> +CHECK: %[[FPTOSI:.*]] = arith.fptosi %[[IN:.*]] : +CHECK: %[[CMP1:.*]] = arith.cmpf ole, %[[IN]], %[[FMIN]] +CHECK: %[[RES1:.*]] = arith.select %[[CMP1]], %[[IMIN]], %[[FPTOSI]] +CHECK: %[[CMP2:.*]] = arith.cmpf oge, %[[IN]], %[[FMAX]] +CHECK: %[[RES2:.*]] = arith.select %[[CMP2]], %[[IMAX]], %[[RES1]] +CHECK: %[[CMP3:.*]] = arith.cmpf uno, %[[IN]], %[[IN]] +CHECK: %[[RES3:.*]] = arith.select %[[CMP3]], %[[ZERO]], %[[RES2]] +})")); +} + TEST_F(TritonGemmTestWithoutTritonGemmAny, SkipF32F32) { if (std::holds_alternative(GpuComputeComp())) { GTEST_SKIP() << "GEMM padding requirements for ROCM not included yet."; From eb70d5ad4a086d8f33d1f085681e59ade0ed7f73 Mon Sep 17 00:00:00 2001 From: Greg Olechwierowicz Date: Thu, 25 Jul 2024 06:45:22 -0700 Subject: [PATCH 151/376] [XLA:GPU][NFC] Verify no changes to the instruction name in HloVerifier post-scheduling. This PR: 0. Adds scheduling_name to OpMetadata. This field is set in a separate pass in the subsequent CL. 1. Adds parsing of the new scheduling_name field to hlo_parser. 2. Adds verification logic to HloVerifier. (disabled by default) Currently this is NFC as nothing sets scheduling_name. PiperOrigin-RevId: 655936916 --- xla/hlo/ir/hlo_instruction.h | 3 ++ xla/hlo/ir/hlo_op_metadata.cc | 4 ++ xla/service/gpu/gpu_compiler.cc | 2 + xla/service/hlo_parser.cc | 6 +++ xla/service/hlo_parser_test.cc | 2 +- xla/service/hlo_verifier.cc | 26 +++++++++++ xla/service/hlo_verifier.h | 10 +++++ xla/service/hlo_verifier_test.cc | 75 ++++++++++++++++++++++++++++++++ xla/xla_data.proto | 3 ++ 9 files changed, 130 insertions(+), 1 deletion(-) diff --git a/xla/hlo/ir/hlo_instruction.h b/xla/hlo/ir/hlo_instruction.h index 0846593156003f..337e9ff534eb84 100644 --- a/xla/hlo/ir/hlo_instruction.h +++ b/xla/hlo/ir/hlo_instruction.h @@ -2198,6 +2198,9 @@ class HloInstruction { void set_metadata_preserve_layout(bool preserve_layout) { metadata_->set_preserve_layout(preserve_layout); } + void set_metadata_scheduling_name(const std::string& name) { + metadata_->set_scheduling_name(name); + } const OpMetadata& metadata() const { return *metadata_; } // Set/get the computation containing this instruction. set_parent should only diff --git a/xla/hlo/ir/hlo_op_metadata.cc b/xla/hlo/ir/hlo_op_metadata.cc index b45f23ba96d432..5370eed46e3765 100644 --- a/xla/hlo/ir/hlo_op_metadata.cc +++ b/xla/hlo/ir/hlo_op_metadata.cc @@ -61,6 +61,10 @@ std::string OpMetadataToString(const OpMetadata& metadata, bool only_op_name) { if (metadata.preserve_layout()) { result.push_back(absl::StrCat("preserve_layout=true")); } + if (!metadata.scheduling_name().empty()) { + result.push_back( + absl::StrCat("scheduling_name=\"", metadata.scheduling_name(), "\"")); + } return absl::StrJoin(result, " "); } diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index 38751937725929..2304593b29af46 100644 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -2479,6 +2479,8 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines( pipeline.AddPass(); } + AddHloVerifier(&main_pipeline, + HloVerifierOpts{}.VerifyInstructionNameUnchanged()); return main_pipeline.Run(module).status(); } diff --git a/xla/service/hlo_parser.cc b/xla/service/hlo_parser.cc index 892683f0386042..ab144fa6eb34da 100644 --- a/xla/service/hlo_parser.cc +++ b/xla/service/hlo_parser.cc @@ -6235,6 +6235,7 @@ bool HloParserImpl::ParseMetadata(OpMetadata* metadata) { optional> profile_type; optional deduplicated_name; optional preserve_layout; + optional scheduling_name; attrs["op_type"] = {/*required=*/false, AttrTy::kString, &op_type}; attrs["op_name"] = {/*required=*/false, AttrTy::kString, &op_name}; attrs["source_file"] = {/*required=*/false, AttrTy::kString, &source_file}; @@ -6245,6 +6246,8 @@ bool HloParserImpl::ParseMetadata(OpMetadata* metadata) { &deduplicated_name}; attrs["preserve_layout"] = {/*required=*/false, AttrTy::kBool, &preserve_layout}; + attrs["scheduling_name"] = {/*required=*/false, AttrTy::kString, + &scheduling_name}; if (!ParseSubAttributes(attrs)) { return false; } @@ -6276,6 +6279,9 @@ bool HloParserImpl::ParseMetadata(OpMetadata* metadata) { } else { metadata->set_preserve_layout(false); } + if (scheduling_name) { + metadata->set_scheduling_name(*scheduling_name); + } return true; } diff --git a/xla/service/hlo_parser_test.cc b/xla/service/hlo_parser_test.cc index 054a188a3f7f45..1f50e26133b9e3 100644 --- a/xla/service/hlo_parser_test.cc +++ b/xla/service/hlo_parser_test.cc @@ -1375,7 +1375,7 @@ R"(HloModule test, entry_computation_layout={(f32[100]{0})->u32[100]{0}} ENTRY %test (p: f32[100]) -> u32[100] { %p = f32[100]{0} parameter(0) - ROOT %root = u32[100]{0} bitcast-convert(f32[100]{0} %p), metadata={op_type="a" op_name="b" source_file="c" source_line=1 profile_type={1} deduplicated_name="d"} + ROOT %root = u32[100]{0} bitcast-convert(f32[100]{0} %p), metadata={op_type="a" op_name="b" source_file="c" source_line=1 profile_type={1} deduplicated_name="d" scheduling_name="foo"} } )" diff --git a/xla/service/hlo_verifier.cc b/xla/service/hlo_verifier.cc index 58d00559a8b738..ebf03de9e146b0 100644 --- a/xla/service/hlo_verifier.cc +++ b/xla/service/hlo_verifier.cc @@ -2077,6 +2077,30 @@ std::string ComputationsToString( }); } +absl::Status VerifyInstructionNameUnchanged(const HloModule& module, + const HloVerifierOpts& opts) { + if (!opts.verify_instruction_name_unchanged) { + return absl::OkStatus(); + } + for (auto* comp : module.computations()) { + for (auto* inst : comp->instructions()) { + if (inst->metadata().scheduling_name().empty()) { + continue; + } + // We do not enforce the invariant when the instruction has been cloned + // explicitly via .clone or .remat suffix. + if (inst->metadata().scheduling_name() != inst->name() && + (!absl::StrContains(inst->name(), ".remat") && + !absl::StrContains(inst->name(), ".clone"))) { + return absl::FailedPreconditionError(absl::StrCat( + "Expected instruction name to remain the same. Was '", + inst->metadata().scheduling_name(), "' is '", inst->name(), "'.")); + } + } + } + return absl::OkStatus(); +} + // Verifies various invariants about the structure of the HLO: // // (1) each instruction is non-null and has a non-null parent() set to the @@ -3001,6 +3025,8 @@ absl::StatusOr HloVerifier::Run( TF_RETURN_IF_ERROR(VerifyHloStructure(module)); TF_RETURN_IF_ERROR(VerifyAsynchronousInstructionPairs(*module)); TF_RETURN_IF_ERROR(VerifyChannels(*module)); + TF_RETURN_IF_ERROR(VerifyInstructionNameUnchanged( + *module, target_metadata_->GetVerifierOpts())); std::unique_ptr shape_verifier = target_metadata_->GetVerifier(); diff --git a/xla/service/hlo_verifier.h b/xla/service/hlo_verifier.h index b2135e31e75a10..2343ad5fcaad74 100644 --- a/xla/service/hlo_verifier.h +++ b/xla/service/hlo_verifier.h @@ -96,6 +96,11 @@ struct HloVerifierOpts { return std::move(*this); } + HloVerifierOpts&& VerifyInstructionNameUnchanged() { + verify_instruction_name_unchanged = true; + return std::move(*this); + } + bool IsLayoutSensitive() const { return layout_sensitive; } bool AllowMixedPrecision() const { return allow_mixed_precision; } @@ -139,6 +144,11 @@ struct HloVerifierOpts { // Whether unbounded dynamic sizes should be allowed for shapes. bool allow_unbounded_dynamism = false; + // Check whether instruction has been renamed. + // Should enforce no function renames unless the name instruction has been + // cloned (".clone" suffix) or rematted (".remat"); + bool verify_instruction_name_unchanged = false; + HloPredicate instruction_can_change_layout; // Returns a target-specific shape size. diff --git a/xla/service/hlo_verifier_test.cc b/xla/service/hlo_verifier_test.cc index e648e1bdf49f60..874077fce3e6b2 100644 --- a/xla/service/hlo_verifier_test.cc +++ b/xla/service/hlo_verifier_test.cc @@ -2503,6 +2503,81 @@ ENTRY computation { .status()); } +TEST_F(HloVerifierTest, VerifyInstructionNameChanged) { + const char* const hlo = R"( +HloModule module + +ENTRY computation { + p0 = f32[32] parameter(0), metadata={scheduling_name="p0"} + p1 = f32[32] parameter(1), metadata={scheduling_name="p1"} + ROOT add0 = f32[32] add(p0,p1), metadata={scheduling_name="add_changed"} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo)); + auto status = HloVerifier{HloVerifierOpts{}.VerifyInstructionNameUnchanged()} + .Run(module.get()) + .status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.message(), + HasSubstr("Expected instruction name to remain the same.")); +} + +TEST_F(HloVerifierTest, VerifyInstructionNameUnchanged) { + const char* const hlo = R"( +HloModule module + +ENTRY computation { + p0 = f32[32] parameter(0), metadata={scheduling_name="p0"} + p1 = f32[32] parameter(1), metadata={scheduling_name="p1"} + ROOT add0 = f32[32] add(p0,p1), metadata={scheduling_name="add0"} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo)); + TF_ASSERT_OK(HloVerifier{HloVerifierOpts{}.VerifyInstructionNameUnchanged()} + .Run(module.get()) + .status()); +} + +TEST_F(HloVerifierTest, VerifyInstructionNameSchedulingNameNotPresent) { + const char* const hlo = R"( +HloModule module + +ENTRY computation { + p0 = f32[32] parameter(0) + p1 = f32[32] parameter(1) + ROOT add0 = f32[32] add(p0,p1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo)); + TF_ASSERT_OK(HloVerifier{HloVerifierOpts{}.VerifyInstructionNameUnchanged()} + .Run(module.get()) + .status()); +} + +TEST_F(HloVerifierTest, VerifyInstructionNameChangedOkWithRematAndClones) { + const char* const hlo = R"( +HloModule module + +ENTRY computation { + p0 = f32[32] parameter(0), metadata={scheduling_name="p0"} + p1 = f32[32] parameter(1), metadata={scheduling_name="p1"} + add0.remat = f32[32] add(p0,p1), metadata={scheduling_name="add0"} + ROOT add1.clone = f32[32] add(add0.remat, p0), metadata={scheduling_name="add1"} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo)); + auto status = HloVerifier{HloVerifierOpts{}.VerifyInstructionNameUnchanged()} + .Run(module.get()) + .status(); + TF_ASSERT_OK(HloVerifier{HloVerifierOpts{}.VerifyInstructionNameUnchanged()} + .Run(module.get()) + .status()); +} + TEST_F(HloVerifierTest, ReshapeIsNotBitcast) { const char* const hlo = R"( HloModule Module diff --git a/xla/xla_data.proto b/xla/xla_data.proto index 113c9cebc55abe..4c7d47b1bf66b9 100644 --- a/xla/xla_data.proto +++ b/xla/xla_data.proto @@ -439,6 +439,9 @@ message OpMetadata { // Ids are 1-based to keep 0 value as representation of non-set property. int32 stack_frame_id = 15; + // Instruction name available upon scheduling. + string scheduling_name = 16; + reserved 14; } From 6320c821ca9e58e036b14dae15cbe41a8fa94ef2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Bana=C5=9B?= Date: Thu, 25 Jul 2024 07:03:58 -0700 Subject: [PATCH 152/376] [XLA:CPU] Test `dynamic-reshape` op in thunks runtime. `dynamic-reshape` op is already supported in thunks runtime, because it is rewritten as other ops. This CL adds tests covering dynamic reshape basic functionality, and turns on related tests. PiperOrigin-RevId: 655941670 --- xla/service/BUILD | 1 + xla/tests/BUILD | 15 +++ xla/tests/dynamic_reshape_test.cc | 169 ++++++++++++++++++++++++++++++ 3 files changed, 185 insertions(+) create mode 100644 xla/tests/dynamic_reshape_test.cc diff --git a/xla/service/BUILD b/xla/service/BUILD index 8775d81790d1a0..ed7e8303f3b5f6 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -4205,6 +4205,7 @@ cc_library( xla_test( name = "dynamic_padder_test", srcs = ["dynamic_padder_test.cc"], + tags = ["test_xla_cpu_thunks"], deps = [ ":algebraic_simplifier", ":dynamic_dimension_inference", diff --git a/xla/tests/BUILD b/xla/tests/BUILD index 8bea9b455720d8..643d44201a8b30 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -2121,6 +2121,21 @@ xla_test( ], ) +xla_test( + name = "dynamic_reshape_test", + srcs = ["dynamic_reshape_test.cc"], + disabled_backends = ["interpreter"], + tags = ["test_xla_cpu_thunks"], + deps = [ + ":hlo_test_base", + ":xla_internal_test_main", # fixdeps: keep + "//xla:literal", + "//xla:literal_util", + "//xla:test", + "@tsl//tsl/platform:statusor", + ], +) + xla_test( name = "reverse_test", srcs = ["reverse_test.cc"], diff --git a/xla/tests/dynamic_reshape_test.cc b/xla/tests/dynamic_reshape_test.cc new file mode 100644 index 00000000000000..efa4cfb213a659 --- /dev/null +++ b/xla/tests/dynamic_reshape_test.cc @@ -0,0 +1,169 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/test.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tests/test_macros.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace { + +class DynamicReshapeTest : public HloTestBase {}; + +// TODO(b/355402228): Enable this test once the bug is fixed. +TEST_F(DynamicReshapeTest, DISABLED_ON_GPU(SingleDynamicDimension)) { + constexpr const char* kModuleStr = R"( + HloModule DynamicReshapeTest.SingleDynamicDimension + + ENTRY main { + param = s32[2, 3, 3] parameter(0) + two = s32[] parameter(1) + param_padded = s32[2, <=3, 3] set-dimension-size(param, two), + dimensions={1} + nine = s32[] parameter(2) + ROOT reshaped = s32[<=18] dynamic-reshape(param_padded, nine) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + Literal arg0 = LiteralUtil::CreateR3( + {{{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}, + {{9, 10, 11}, {12, 13, 14}, {15, 16, 17}}}); + Literal arg1 = LiteralUtil::CreateR0(2); + Literal arg2 = LiteralUtil::CreateR0(9); + + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {&arg0, &arg1, &arg2})); + + Literal expected = + LiteralUtil::CreateR1({0, 1, 2, 3, 4, 5, 9, 10, 11}); + EXPECT_EQ(result, expected); +} + +// TODO(b/355402228): Enable this test once the bug is fixed. +TEST_F(DynamicReshapeTest, DISABLED_ON_GPU(DoubleDynamicDimensions)) { + constexpr const char* kModuleStr = R"( + HloModule DynamicReshapeTest.DoubleDynamicDimensions + + ENTRY main { + param = s32[2, 3, 3] parameter(0) + two = s32[] parameter(1) + param_padded_partial = s32[2, <=3, 3] set-dimension-size(param, two), + dimensions={1} + param_padded = s32[2, <=3, <=3] set-dimension-size(param_padded_partial, + two), dimensions={2} + eight = s32[] parameter(2) + ROOT reshaped = s32[<=18] dynamic-reshape(param_padded, eight) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + Literal arg0 = LiteralUtil::CreateR3( + {{{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}, + {{9, 10, 11}, {12, 13, 14}, {15, 16, 17}}}); + Literal arg1 = LiteralUtil::CreateR0(2); + Literal arg2 = LiteralUtil::CreateR0(8); + + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {&arg0, &arg1, &arg2})); + + Literal expected = + LiteralUtil::CreateR1({0, 1, 3, 4, 9, 10, 12, 13}); + EXPECT_EQ(result, expected); +} + +TEST_F(DynamicReshapeTest, OutputDoubleDynamicDimensions) { + constexpr const char* kModuleStr = R"( + HloModule DynamicReshapeTest.OutputDoubleDynamicDimensions + + ENTRY main { + param = s32[18] parameter(0) + eight = s32[] parameter(1) + param_dynamic = s32[<=18] set-dimension-size(param, eight), dimensions={0} + two = s32[] parameter(2) + ROOT reshaped = s32[2, <=3, <=3] dynamic-reshape(param_dynamic, two, two, + two) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + Literal arg0 = LiteralUtil::CreateR1( + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}); + Literal arg1 = LiteralUtil::CreateR0(8); + Literal arg2 = LiteralUtil::CreateR0(2); + + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {&arg0, &arg1, &arg2})); + + Literal expected = + LiteralUtil::CreateR3({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}}); + EXPECT_EQ(result, expected); +} + +// TODO(b/355402228): Enable this test once the bug is fixed. +TEST_F(DynamicReshapeTest, DISABLED_ON_GPU(Complicated)) { + constexpr const char* kModuleStr = R"( + HloModule DynamicReshapeTest.Complicated + + ENTRY main { + param = s32[3, 4, 4] parameter(0) + two = s32[] parameter(1) + param_dynamic = s32[<=3, 4, 4] set-dimension-size(param, two), + dimensions={0} + three = s32[] parameter(2) + param_dynamic1 = s32[<=3, <=4, 4] set-dimension-size( + param_dynamic, three), dimensions={1} + param_dynamic2 = s32[<=3, <=4, <=4] set-dimension-size( + param_dynamic1, three), dimensions={2} + six = s32[] parameter(3) + + // Static reshape is from [3, 4, 4] to [6, 8]. + // Dynamic reshape is from [2, 3, 3] to [3, 6]. + ROOT reshaped = s32[<=6, <=8] dynamic-reshape(param_dynamic2, three, six) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + Literal arg0 = LiteralUtil::CreateR3( + {{{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}, {12, 13, 14, 15}}, + {{16, 17, 18, 19}, {20, 21, 22, 23}, {24, 25, 26, 27}, {28, 29, 30, 31}}, + {{32, 33, 34, 35}, + {36, 37, 38, 39}, + {40, 41, 42, 43}, + {44, 45, 46, 47}}}); + Literal arg1 = LiteralUtil::CreateR0(2); + Literal arg2 = LiteralUtil::CreateR0(3); + Literal arg3 = LiteralUtil::CreateR0(6); + + TF_ASSERT_OK_AND_ASSIGN( + auto result, Execute(std::move(module), {&arg0, &arg1, &arg2, &arg3})); + + Literal expected = LiteralUtil::CreateR2( + {{0, 1, 2, 4, 5, 6}, {8, 9, 10, 16, 17, 18}, {20, 21, 22, 24, 25, 26}}); + EXPECT_EQ(result, expected); +} + +} // namespace +} // namespace xla From 43106c31740b4fda4ad4f35ca38efd8e6647a995 Mon Sep 17 00:00:00 2001 From: Mohammed Anany Date: Thu, 25 Jul 2024 07:40:32 -0700 Subject: [PATCH 153/376] [Triton] Add a test for BF16 to FP8 conversion in Triton fusion emitter. PiperOrigin-RevId: 655950800 --- ...riton_fusion_emitter_device_legacy_test.cc | 33 +++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc index e88f9d784124a3..9a2d2d6dbb7520 100644 --- a/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc +++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc @@ -4742,8 +4742,7 @@ CHECK-NOT: inputPrecision = tf32 } TEST_F(TritonTest, Fp8LoweringIsSupportedPostHopper) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::HOPPER)) { + if (!GetCudaComputeCapability().IsAtLeastHopper()) { GTEST_SKIP() << "Doesn't pass on pre-Hopper GPUs."; } const std::string hlo_text = R"( @@ -4776,6 +4775,36 @@ CHECK: tt.dot {{.*}}{maxNumImpreciseAcc = 2147483647 : i32} : tensor<128x64xf8E4 EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1.0, /*arel=*/1e-3})); } +TEST_F(TritonTest, BF16ToFP8EndToEnd) { + if (!GetCudaComputeCapability().IsAtLeastHopper()) { + GTEST_SKIP() << "Doesn't pass on pre-Hopper GPUs."; + } + + const std::string hlo_text = R"( +HloModule t + +triton_dot { + parameter_0 = bf16[32,32]{1,0} parameter(0) + parameter_1 = f8e4m3fn[32,32]{1,0} parameter(1) + convert = f8e4m3fn[32,32]{1,0} convert(parameter_0) + ROOT dot = f32[32,32]{1,0} dot(convert, parameter_1), + lhs_contracting_dims={1}, rhs_contracting_dims={1} +} + +ENTRY main { + parameter_0 = bf16[32,32]{1,0} parameter(0) + parameter_1 = f8e4m3fn[32,32]{1,0} parameter(1) + ROOT gemm_fusion_dot = f32[32,32]{1,0} fusion(parameter_0, parameter_1), + kind=kCustom, calls=triton_dot, + backend_config={ + "fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config": + {"block_m":"32","block_n":"32","block_k":"32","split_k":"1", + "num_stages":"1","num_warps":"4","num_ctas":"1"}}} +})"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1.0, /*arel=*/1e-3})); +} + // Test PreventMmaV3LoopUnrolling pass in order to keep compile time low. // See b/344841434. TEST_F(TritonGemmTest, TestPreventMMAV3LoopUnrolling) { From 330a06ba2b7079c15d6436eafa7ec1406fd0a479 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Thu, 25 Jul 2024 07:43:17 -0700 Subject: [PATCH 154/376] [XLA_GPU][MLIR-based emitters] Add a pass to flatten tensors. Right now it only supports tensor.extract, tensor.insert, xla_gpu.atomic_rmw, func.func and func.return, scf.for, scf.if and scf.yield. PiperOrigin-RevId: 655951526 --- xla/service/gpu/fusions/mlir/BUILD | 1 + .../gpu/fusions/mlir/flatten_tensors.cc | 452 ++++++++++++++++++ .../gpu/fusions/mlir/ir/xla_gpu_ops.cc | 3 +- xla/service/gpu/fusions/mlir/passes.h | 1 + xla/service/gpu/fusions/mlir/passes.td | 15 + .../fusions/mlir/tests/flatten_tensors.mlir | 140 ++++++ 6 files changed, 610 insertions(+), 2 deletions(-) create mode 100644 xla/service/gpu/fusions/mlir/flatten_tensors.cc create mode 100644 xla/service/gpu/fusions/mlir/tests/flatten_tensors.mlir diff --git a/xla/service/gpu/fusions/mlir/BUILD b/xla/service/gpu/fusions/mlir/BUILD index b44d7173538d98..9b603576488db4 100644 --- a/xla/service/gpu/fusions/mlir/BUILD +++ b/xla/service/gpu/fusions/mlir/BUILD @@ -284,6 +284,7 @@ cc_library( "convert_xla_gpu_pure_call_ops.cc", "erase_dead_functions.cc", "expand_float_ops.cc", + "flatten_tensors.cc", "lower_tensors.cc", "lower_to_llvm.cc", "lower_xla_gpu_to_scf.cc", diff --git a/xla/service/gpu/fusions/mlir/flatten_tensors.cc b/xla/service/gpu/fusions/mlir/flatten_tensors.cc new file mode 100644 index 00000000000000..99a7ecb7c57113 --- /dev/null +++ b/xla/service/gpu/fusions/mlir/flatten_tensors.cc @@ -0,0 +1,452 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeRange.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "xla/layout_util.h" +#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/shape_util.h" + +namespace xla { +namespace gpu { +namespace { + +#define GEN_PASS_DEF_FLATTENTENSORSPASS +#include "xla/service/gpu/fusions/mlir/passes.h.inc" + +using mlir::Location; +using mlir::LogicalResult; +using mlir::MLIRContext; +using mlir::OpRewritePattern; +using mlir::PatternRewriter; +using mlir::RankedTensorType; +using mlir::SmallVector; +using mlir::Type; +using mlir::TypedValue; +using mlir::TypeRange; +using mlir::UnrealizedConversionCastOp; +using mlir::Value; +using mlir::ValueRange; +using mlir::func::FuncOp; +using mlir::func::ReturnOp; +using mlir::scf::ForOp; +using mlir::scf::IfOp; +using mlir::tensor::ExtractOp; +using mlir::tensor::InsertOp; + +RankedTensorType GetFlattenedType(RankedTensorType tensor_type) { + return RankedTensorType::get({tensor_type.getNumElements()}, + tensor_type.getElementType()); +} + +bool HasOnlyFlatTensorsOrScalars(TypeRange types) { + return llvm::all_of(types, [](Type ty) { + auto tensor_type = mlir::dyn_cast(ty); + if (!tensor_type) return true; + return tensor_type.getRank() < 2; + }); +} + +struct RewriteFunctionSignatures : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(FuncOp op, + PatternRewriter& rewriter) const override { + auto input_types = op.getFunctionType().getInputs(); + auto result_types = op.getFunctionType().getResults(); + if (HasOnlyFlatTensorsOrScalars(input_types) && + HasOnlyFlatTensorsOrScalars(result_types)) { + return rewriter.notifyMatchFailure(op, "nothing to flatten"); + } + + auto loc = op.getLoc(); + mlir::Block* entry_block = &op.getBody().front(); + SmallVector new_result_types; + SmallVector new_results; + + // If some results are tensors, we need to flatten them. + auto terminator = entry_block->getTerminator(); + rewriter.setInsertionPoint(terminator); + + for (Value result : terminator->getOperands()) { + auto tensor_type = mlir::dyn_cast(result.getType()); + if (!tensor_type) { + new_result_types.push_back(result.getType()); + new_results.push_back(result); + continue; + } + auto new_result_type = GetFlattenedType(tensor_type); + new_result_types.push_back(new_result_type); + + Value result_1d = + rewriter + .create(loc, new_result_type, result) + .getResult(0); + new_results.push_back(result_1d); + } + rewriter.replaceOpWithNewOp(terminator, new_results); + + // Cast all function arguments to the original type. + SmallVector new_operand_types(input_types); + rewriter.setInsertionPointToStart(entry_block); + for (auto&& [index, operand_type] : llvm::enumerate(new_operand_types)) { + if (auto tensor_type = mlir::dyn_cast(operand_type)) { + if (tensor_type.getRank() > 1) { + mlir::BlockArgument func_argument = op.getArgument(index); + auto cast_to_orig_type = rewriter.create( + loc, operand_type, func_argument); + func_argument.replaceAllUsesExcept(cast_to_orig_type.getResult(0), + cast_to_orig_type); + operand_type = GetFlattenedType(tensor_type); + } + } + } + // Replace the function arguments with the new types. + for (auto [arg, arg_type] : + llvm::zip(entry_block->getArguments(), new_operand_types)) { + arg.setType(arg_type); + } + // Update function signature. + op.setType(rewriter.getFunctionType(new_operand_types, new_result_types)); + return mlir::success(); + } +}; + +// Returns the linearized index, if the rank is greater than 1. Otherwise, +// returns nullptr. +Value LinearizeIndex(TypedValue tensor, + ValueRange indices, PatternRewriter& rewriter) { + if (tensor.getType().getRank() < 2) { + return nullptr; + } + auto byte_shape = ShapeUtil::MakeShape(U8, tensor.getType().getShape()); + if (auto encoding = tensor.getType().getEncoding()) { + *byte_shape.mutable_layout() = LayoutUtil::MakeLayout(llvm::to_vector( + mlir::cast(encoding).getValues())); + } + auto linear_shape = + ShapeUtil::MakeShape(U8, {ShapeUtil::ElementsIn(byte_shape)}); + auto linearized_map = + GetBitcastMap(byte_shape, linear_shape, tensor.getContext()); + mlir::SmallVector result; + rewriter.createOrFold(result, tensor.getLoc(), indices, + ValueRange{}, linearized_map); + return result.front(); +} + +struct RewriteTensorExtract : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ExtractOp op, + PatternRewriter& rewriter) const override { + auto tensor = op.getTensor(); + auto tensor_type = tensor.getType(); + auto linear_index = LinearizeIndex(tensor, op.getIndices(), rewriter); + if (linear_index == nullptr) { + return rewriter.notifyMatchFailure(op, "the tensor is already flat"); + } + auto tensor_1D = rewriter + .create( + op.getLoc(), GetFlattenedType(tensor_type), tensor) + .getResult(0); + rewriter.replaceOpWithNewOp(op, tensor_1D, linear_index); + return mlir::success(); + } +}; + +struct RewriteTensorInsert : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(InsertOp op, + PatternRewriter& rewriter) const override { + auto tensor = op.getDest(); + auto tensor_type = tensor.getType(); + auto linear_index = LinearizeIndex(tensor, op.getIndices(), rewriter); + if (linear_index == nullptr) { + return rewriter.notifyMatchFailure(op, "the tensor is already flat"); + } + mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); + auto tensor_1D = b.create( + GetFlattenedType(tensor_type), tensor) + .getResult(0); + auto new_insert = + b.create(op.getScalar(), tensor_1D, linear_index); + auto cast_to_orig_type = b.create( + tensor_type, new_insert.getResult()); + rewriter.replaceOp(op, cast_to_orig_type.getResult(0)); + return mlir::success(); + } +}; + +struct RewriteAtomicRMW : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AtomicRMWOp op, + PatternRewriter& rewriter) const override { + auto tensor = op.getInput(); + auto tensor_type = tensor.getType(); + auto linear_index = LinearizeIndex(tensor, op.getIndices(), rewriter); + if (linear_index == nullptr) { + return rewriter.notifyMatchFailure(op, "the tensor is already flat"); + } + mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); + auto tensor_1D = b.create( + GetFlattenedType(tensor_type), tensor) + .getResult(0); + auto new_atomic_rmw = b.create(tensor_1D, linear_index); + rewriter.inlineRegionBefore(op.getRegion(), + &new_atomic_rmw.getRegion().front()); + auto cast_to_orig_type = b.create( + tensor_type, new_atomic_rmw.getResult()); + rewriter.replaceOp(op, cast_to_orig_type.getResult(0)); + return mlir::success(); + } +}; + +// Checks that the value is produced by an unrealized conversion cast from 1D +// tensor to ND. Returns the 1D tensor if so. +std::optional GetDelinearizedTensor(Value value) { + auto tensor_type = mlir::dyn_cast(value.getType()); + if (!tensor_type || tensor_type.getRank() < 2) { + return std::nullopt; + } + auto cast = value.getDefiningOp(); + if (!cast || cast->getNumResults() != 1 || cast->getNumOperands() != 1) { + return std::nullopt; + } + auto type_before_linearization = + mlir::dyn_cast(cast->getOperand(0).getType()); + if (!type_before_linearization || type_before_linearization.getRank() != 1) { + return std::nullopt; + } + return cast->getOperand(0); +} + +struct RewriteForOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ForOp op, + PatternRewriter& rewriter) const override { + llvm::SmallBitVector args_to_update(op.getNumResults(), false); + mlir::SmallVector new_init_args; + new_init_args.reserve(op.getNumResults()); + for (auto [index, arg] : llvm::enumerate(op.getInitArgs())) { + auto type_before_linearization = GetDelinearizedTensor(arg); + if (!type_before_linearization.has_value()) { + new_init_args.push_back(arg); + continue; + } + new_init_args.push_back(*type_before_linearization); + args_to_update.set(index); + } + if (args_to_update.none()) { + return rewriter.notifyMatchFailure(op, "no args to update"); + } + // Create new ForOp with updated init args. + Location loc = op.getLoc(); + auto new_for_op = + rewriter.create(loc, op.getLowerBound(), op.getUpperBound(), + op.getStep(), new_init_args); + new_for_op->setAttrs(op->getAttrs()); + + // Insert casts for the block arguments. + mlir::Block* new_body = new_for_op.getBody(); + mlir::Block* old_body = op.getBody(); + rewriter.setInsertionPoint(new_body, new_body->begin()); + SmallVector updated_block_args{new_body->getArguments().begin(), + new_body->getArguments().end()}; + for (auto [index, arg] : + llvm::enumerate(new_body->getArguments().drop_front())) { + if (!args_to_update.test(index)) continue; + updated_block_args[index + 1] = + rewriter + .create( + loc, old_body->getArgument(index + 1).getType(), arg) + .getResult(0); + } + + // Move the body of the old ForOp to the new one. + rewriter.mergeBlocks(old_body, new_body, updated_block_args); + + // Update the terminator. + auto new_terminator = + mlir::cast(new_body->getTerminator()); + rewriter.setInsertionPoint(new_terminator); + for (auto&& [index, yielded_value] : + llvm::enumerate(new_terminator.getResultsMutable())) { + if (!args_to_update.test(index)) continue; + yielded_value.assign( + rewriter + .create( + loc, new_init_args[index].getType(), yielded_value.get()) + .getResult(0)); + } + + // Cast back the results. + rewriter.setInsertionPointAfter(new_for_op); + SmallVector new_results(new_for_op.getResults()); + for (auto&& [index, result] : llvm::enumerate(new_results)) { + if (!args_to_update.test(index)) continue; + result = rewriter + .create( + loc, op->getResult(index).getType(), result) + .getResult(0); + } + rewriter.replaceOp(op, new_results); + return mlir::failure(); + } +}; + +struct RewriteIfOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(IfOp op, + PatternRewriter& rewriter) const override { + auto result_types = op.getResultTypes(); + if (HasOnlyFlatTensorsOrScalars(result_types)) { + return rewriter.notifyMatchFailure(op, "nothing to flatten"); + } + mlir::scf::YieldOp then_yield = op.thenYield(); + SmallVector new_result_types; + new_result_types.reserve(then_yield.getNumOperands()); + bool found_cast = false; + for (auto& result : then_yield->getOpOperands()) { + auto delinearized_tensor = GetDelinearizedTensor(result.get()); + if (!delinearized_tensor.has_value()) { + new_result_types.push_back(result.get().getType()); + continue; + } + new_result_types.push_back(delinearized_tensor->getType()); + result.set(*delinearized_tensor); + found_cast = true; + } + if (!found_cast) { + return rewriter.notifyMatchFailure(op, "no cast found"); + } + Location loc = op.getLoc(); + // Update the else branch if present. + bool has_else_region = !op.getElseRegion().empty(); + if (has_else_region) { + mlir::scf::YieldOp else_yield = op.elseYield(); + mlir::OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(else_yield); + for (auto&& [result, type] : + llvm::zip(else_yield->getOpOperands(), new_result_types)) { + if (result.get().getType() == type) continue; + result.set( + rewriter.create(loc, type, result.get()) + .getResult(0)); + } + } + // Create new IfOp and move the old op's regions to the new one. + auto new_if_op = rewriter.create(loc, new_result_types, + op.getCondition(), has_else_region); + rewriter.inlineRegionBefore(op.getThenRegion(), + &new_if_op.getThenRegion().back()); + rewriter.eraseBlock(&new_if_op.getThenRegion().back()); + if (has_else_region) { + rewriter.inlineRegionBefore(op.getElseRegion(), + &new_if_op.getElseRegion().back()); + rewriter.eraseBlock(&new_if_op.getElseRegion().back()); + } + + // Update the results. + rewriter.setInsertionPointAfter(new_if_op); + SmallVector new_results(new_if_op.getResults()); + for (auto&& [index, result] : llvm::enumerate(new_results)) { + Type old_type = op->getResult(index).getType(); + if (result.getType() == old_type) continue; + result = + rewriter.create(loc, old_type, result) + .getResult(0); + } + rewriter.replaceOp(op, new_results); + return mlir::success(); + } +}; + +class FlattenTensorsPass + : public impl::FlattenTensorsPassBase { + public: + void runOnOperation() override { + mlir::ModuleOp module = getOperation(); + MLIRContext* mlir_context = &getContext(); + mlir::RewritePatternSet patterns(mlir_context); + // clang-format off + patterns.add< + RewriteAtomicRMW, + RewriteForOp, + RewriteFunctionSignatures, + RewriteIfOp, + RewriteTensorExtract, + RewriteTensorInsert + >(mlir_context); + // clang-format on + ApplyIndexingOp::getCanonicalizationPatterns(patterns, mlir_context); + if (mlir::failed( + mlir::applyPatternsAndFoldGreedily(module, std::move(patterns)))) { + signalPassFailure(); + return; + } + // Check if there are no unrealized_conversion_casts. + bool module_has_casts = module + .walk([](UnrealizedConversionCastOp op) { + return mlir::WalkResult::interrupt(); + }) + .wasInterrupted(); + if (module_has_casts) { + llvm::outs() << "FlattenTensorsPass failed to converge"; + signalPassFailure(); + return; + } + } +}; + +} // namespace + +std::unique_ptr CreateFlattenTensorsPass() { + return std::make_unique(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc b/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc index 6f6662686553ea..c7a0575a0ef087 100644 --- a/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc +++ b/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc @@ -613,10 +613,9 @@ struct FoldApplyIndexingResults new_exprs.reserve(num_results); SmallVector new_values; new_values.reserve(num_results); - Value zero = rewriter.create(loc, 0); for (mlir::OpResult opresult : indexing_op->getOpResults()) { if (opresult.use_empty()) { - new_values.push_back(zero); + new_values.push_back(rewriter.create(loc, 0)); continue; } diff --git a/xla/service/gpu/fusions/mlir/passes.h b/xla/service/gpu/fusions/mlir/passes.h index fb42134231a9dd..bb0f1d44380018 100644 --- a/xla/service/gpu/fusions/mlir/passes.h +++ b/xla/service/gpu/fusions/mlir/passes.h @@ -39,6 +39,7 @@ std::optional GetIVRange(mlir::Value iv); std::unique_ptr CreateEraseDeadFunctionsPass(); std::unique_ptr CreateExpandFloatOpsPass(bool pre_ampere); std::unique_ptr CreateConvertPureCallOpsPass(); +std::unique_ptr CreateFlattenTensorsPass(); std::unique_ptr CreateLowerTensorsPass( bool is_amd_gpu = false, const std::string& gpu_arch = "6.0"); std::unique_ptr CreateLowerToLLVMPass(); diff --git a/xla/service/gpu/fusions/mlir/passes.td b/xla/service/gpu/fusions/mlir/passes.td index 66d86e51fcea14..6785670581d68e 100644 --- a/xla/service/gpu/fusions/mlir/passes.td +++ b/xla/service/gpu/fusions/mlir/passes.td @@ -49,6 +49,21 @@ def ConvertPureCallOpsPass let constructor = "CreateConvertPureCallOpsPass()"; } +def FlattenTensorsPass : Pass<"xla-gpu-flatten-tensors", "mlir::ModuleOp"> { + let summary = "Flatten tensors."; + + let description = [{ + Linearizes all tensors loads and stores. + }]; + + let dependentDialects = [ + "mlir::func::FuncDialect", + "mlir::tensor::TensorDialect", + "xla::gpu::XlaGpuDialect", + ]; + let constructor = "CreateFlattenTensorsPass()"; +} + def LowerTensorsPass : Pass<"xla-gpu-lower-tensors", "mlir::ModuleOp"> { let summary = "Lowers tensors to llvm pointers and loads/stores."; diff --git a/xla/service/gpu/fusions/mlir/tests/flatten_tensors.mlir b/xla/service/gpu/fusions/mlir/tests/flatten_tensors.mlir new file mode 100644 index 00000000000000..ee2c2ae9e9553d --- /dev/null +++ b/xla/service/gpu/fusions/mlir/tests/flatten_tensors.mlir @@ -0,0 +1,140 @@ +// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-flatten-tensors \ +// RUN: --verify-diagnostics | FileCheck %s + +func.func @tensor_extract( + %arg0: tensor<2x3xf32, dense<[0, 1]> : tensor<2xi64>>, + %arg1: index, %arg2: index) -> f32 { + %v = tensor.extract %arg0[%arg1, %arg2] + : tensor<2x3xf32, dense<[0, 1]> : tensor<2xi64>> + func.return %v : f32 +} +// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d1 * 2 + d0)> + +// CHECK-LABEL: func.func @tensor_extract( +// CHECK-SAME: %[[SRC:.*]]: tensor<6xf32>, +// CHECK-SAME: %[[I:.*]]: index, %[[J:.*]]: index) -> f32 { +// CHECK: %[[INDEX:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]] +// CHECK-SAME: in [0, 1], %[[J]] in [0, 2]) +// CHECK: tensor.extract %[[SRC]][%[[INDEX]]] : tensor<6xf32> + +// ----- + +func.func @tensor_insert( + %arg0: tensor<10x24xcomplex>) -> tensor<10x24xcomplex> { + %c1 = arith.constant 1 : index + %real = arith.constant 3.0 : f32 + %imag = arith.constant 2.0 : f32 + %complex = complex.create %real, %imag : complex + %out = tensor.insert %complex into %arg0[%c1, %c1] : tensor<10x24xcomplex> + func.return %out : tensor<10x24xcomplex> +} +// CHECK-LABEL: func.func @tensor_insert( +// CHECK-SAME: %[[TENSOR:.*]]: tensor<240xcomplex>) -> tensor<240xcomplex> { +// CHECK: %[[INDEX:.*]] = arith.constant 25 +// CHECK: %[[COMPLEX:.*]] = complex.create +// CHECK: tensor.insert %[[COMPLEX]] into %[[TENSOR]][%[[INDEX]]] +// CHECK-SAME: : tensor<240xcomplex> + +// ----- + +func.func @atomic_rmw(%in: tensor<2x4xf32>, %i: index, %j: index) + -> (tensor<2x4xf32>) { + %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf32> { + ^bb0(%current : f32): + %c42 = arith.constant 1.0 : f32 + %add = arith.minimumf %current, %c42 : f32 + xla_gpu.yield %add : f32 + } + return %ret : tensor<2x4xf32> +} +// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d0 * 4 + d1)> + +// CHECK-LABEL: func.func @atomic_rmw( +// CHECK-SAME: %[[TENSOR:.*]]: tensor<8xf32>, %[[I:.*]]: index, +// CHECK-SAME: %[[J:.*]]: index) -> tensor<8xf32> { +// CHECK: %[[INDEX:.*]] = xla_gpu.apply_indexing #[[$MAP]] +// CHECK-SAME: (%[[I]] in [0, 1], %[[J]] in [0, 3]) +// CHECK: xla_gpu.atomic_rmw %[[TENSOR]][%[[INDEX]]] : tensor<8xf32> + +// ----- + +func.func @for_loop(%t0: tensor<32x1024xf32>, %t1: tensor<64x8x4xf32>) + -> (tensor<32x1024xf32>, tensor<64x8x4xf32>, f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c0_f32 = arith.constant 0.0 : f32 + %for:2 = scf.for %i = %c0 to %c64 step %c32 iter_args(%t0_ = %t0, %t1_ = %t1) + -> (tensor<32x1024xf32>, tensor<64x8x4xf32>) { + %update0 = tensor.insert %c0_f32 into %t0_[%c1, %i] : tensor<32x1024xf32> + %update1 = tensor.insert %c0_f32 into %t1_[%i, %c1, %c1] : tensor<64x8x4xf32> + scf.yield %update0, %update1 : tensor<32x1024xf32>, tensor<64x8x4xf32> + } {some_attr} + return %for#0, %for#1, %c0_f32 : tensor<32x1024xf32>, tensor<64x8x4xf32>, f32 +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 + 1024)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0 * 32 + 5)> +// CHECK-LABEL: func.func @for_loop( +// CHECK-SAME: %[[T0:.*]]: tensor<32768xf32>, +// CHECK-SAME: %[[T1:.*]]: tensor<2048xf32>) -> (tensor<32768xf32>, tensor<2048xf32>, f32) { + +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index +// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index +// CHECK-DAG: %[[F32:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[FOR:.*]]:2 = scf.for %[[I:.*]] = %[[C0]] to %[[C64]] +// CHECK-SAME: step %[[C32]] +// CHECK-SAME: iter_args(%[[T0_:.*]] = %[[T0]], %[[T1_:.*]] = %[[T1]]) +// CHECK: %[[IND0:.*]] = xla_gpu.apply_indexing #[[$MAP0]](%[[I]] in [0, 1023]) +// CHECK: %[[UPD0:.*]] = tensor.insert %[[F32]] into %[[T0_]][%[[IND0]]] +// CHECK: %[[IND1:.*]] = xla_gpu.apply_indexing #[[$MAP1]](%[[I]] in [0, 63]) +// CHECK: %[[UPD1:.*]] = tensor.insert %[[F32]] into %[[T1_]][%[[IND1]]] +// CHECK: scf.yield %[[UPD0]], %[[UPD1]] : tensor<32768xf32>, tensor<2048xf32> + +// ----- + +#map = affine_map<(d0, d1) -> ((d1 * 128 + d0) floordiv 36)> +#map1 = affine_map<(d0, d1) -> (((d1 * 128 + d0) floordiv 9) mod 4)> +#map2 = affine_map<(d0, d1) -> ((d1 * 128 + d0) mod 9)> +func.func @if_op(%arg0: tensor<4000x4x9xf32>, %arg1: tensor<1400x1xi32>, + %arg2: tensor<1400x1x4x9xf32>, %arg3: tensor<4000x4x9xf32>) + -> tensor<4000x4x9xf32> { + %c0 = arith.constant 0 : index + %c3999 = arith.constant 3999 : index + %th_x = gpu.thread_id x {xla.range = [0 : index, 127 : index]} + %bl_x = gpu.block_id x {xla.range = [0 : index, 393749 : index]} + %0 = xla_gpu.apply_indexing #map(%th_x in [0, 127], %bl_x in [0, 393749]) + %extracted = tensor.extract %arg1[%0, %c0] : tensor<1400x1xi32> + %1 = arith.index_cast %extracted : i32 to index + %2 = arith.cmpi ule, %1, %c3999 : index + %3 = scf.if %2 -> (tensor<4000x4x9xf32>) { + %4 = xla_gpu.apply_indexing #map1(%th_x in [0, 127], %bl_x in [0, 393749]) + %5 = xla_gpu.apply_indexing #map2(%th_x in [0, 127], %bl_x in [0, 393749]) + %elem = tensor.extract %arg2[%0, %c0, %4, %5] : tensor<1400x1x4x9xf32> + %atomic_rmw = xla_gpu.atomic_rmw %arg3[%1, %4, %5] : tensor<4000x4x9xf32> { + ^bb0(%arg4: f32): + %6 = arith.addf %arg4, %elem : f32 + xla_gpu.yield %6 : f32 + } + scf.yield %atomic_rmw : tensor<4000x4x9xf32> + } else { + scf.yield %arg3 : tensor<4000x4x9xf32> + } + return %3 : tensor<4000x4x9xf32> +} +// CHECK-LABEL: func.func @if_op +// CHECK-NOT: builtin.unrealized_conversion_cast +// CHECK: scf.if %{{.*}} -> (tensor<144000xf32>) { +// CHECK-COUNT-2: scf.yield %{{.*}} : tensor<144000xf32> +// CHECK: return %{{.*}} : tensor<144000xf32> + +// ----- + +func.func @dangling_cast(%arg0: tensor<6xf32>, %arg1: index) -> i32 { + %v = tensor.extract %arg0[%arg1] : tensor<6xf32> + %cast = builtin.unrealized_conversion_cast %v : f32 to i32 + func.return %cast : i32 +} +// CHECK: FlattenTensorsPass failed to converge From ce77156c4caa0a0c39803e38884bb1461f599b56 Mon Sep 17 00:00:00 2001 From: TJ Xu Date: Thu, 25 Jul 2024 07:44:20 -0700 Subject: [PATCH 155/376] PR #15292: [NVIDIA GPU] Disable post layout assignment collective pipeliner by default Imported from GitHub PR https://github.com/openxla/xla/pull/15292 This is to address the reduce-scatter pipelining issue after enabling post layout assignment collective pipeline by default. Copybara import of the project: -- cd654152be1a3e4b08c5c8d8a5a460d06d0c1113 by TJ Xu : Disable post layout assignment collective pipeliner by default Merging this change closes #15292 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/15292 from Tixxx:tixxx/disable_post_layout_pipeliner cd654152be1a3e4b08c5c8d8a5a460d06d0c1113 PiperOrigin-RevId: 655951744 --- xla/debug_options_flags.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc index 6046ff64ed6e3a..9b71d26ac6217a 100644 --- a/xla/debug_options_flags.cc +++ b/xla/debug_options_flags.cc @@ -171,7 +171,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_enable_pipelined_reduce_scatter(false); opts.set_xla_gpu_enable_pipelined_p2p(false); - opts.set_xla_gpu_run_post_layout_collective_pipeliner(true); + opts.set_xla_gpu_run_post_layout_collective_pipeliner(false); opts.set_xla_gpu_collective_permute_decomposer_threshold( std::numeric_limits::max()); From b918e8c6890574720e1fd231d3222dca9e555e93 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 25 Jul 2024 08:07:29 -0700 Subject: [PATCH 156/376] [xla:cpu] Consistently use NotFound for not found symbols and custom calls PiperOrigin-RevId: 655958081 --- xla/python/xla_client_test.py | 2 +- xla/service/cpu/cpu_executable.cc | 2 +- xla/service/cpu/runtime/custom_call_thunk.cc | 11 ++--------- xla/service/cpu/runtime_handle_ffi_call.cc | 2 +- xla/tests/custom_call_test.cc | 5 ----- 5 files changed, 5 insertions(+), 17 deletions(-) diff --git a/xla/python/xla_client_test.py b/xla/python/xla_client_test.py index bc108056e7e9d3..b84e094b1d841b 100644 --- a/xla/python/xla_client_test.py +++ b/xla/python/xla_client_test.py @@ -458,7 +458,7 @@ def testCustomCallWithUnifiedApiUnknownTarget(self): .API_VERSION_STATUS_RETURNING_UNIFIED, ) with self.assertRaisesRegex( - xla_client.XlaRuntimeError, expected_regex="INVALID_ARGUMENT" + xla_client.XlaRuntimeError, expected_regex="NOT_FOUND" ): self._Execute(c, arguments=()) diff --git a/xla/service/cpu/cpu_executable.cc b/xla/service/cpu/cpu_executable.cc index 6dde3faa3ea40b..b1a8685a84d72a 100644 --- a/xla/service/cpu/cpu_executable.cc +++ b/xla/service/cpu/cpu_executable.cc @@ -147,7 +147,7 @@ absl::StatusOr> CpuExecutable::Create( // We expect to find the symbol provided with entry_function_name; otherwise // this is an internal error. if (!sym) { - return absl::InvalidArgumentError( + return absl::NotFoundError( absl::StrCat("Symbol ", entry_function_name, " not found.")); } // getAddress can do work under the hood in the jit, so it needs to be diff --git a/xla/service/cpu/runtime/custom_call_thunk.cc b/xla/service/cpu/runtime/custom_call_thunk.cc index 217ac1d995539f..1161673db8764b 100644 --- a/xla/service/cpu/runtime/custom_call_thunk.cc +++ b/xla/service/cpu/runtime/custom_call_thunk.cc @@ -157,11 +157,7 @@ tsl::AsyncValueRef CustomCallThunk::CallTypedFFI( // Find the registered FFI handler for this target. auto handler = ffi::FindHandler(target_name_, "Host"); if (!handler.ok()) { - // Overwrite the returned error code (kNotFound) to kInternal to match the - // original CPU implementation. - // TODO(penporn): Change this to kUnimplemented to match the GPU backend - // when thunks is the only runtime for CPU. - return Internal( + return NotFound( "No registered implementation for FFI custom call to %s for Host", target_name_); } @@ -225,10 +221,7 @@ tsl::AsyncValueRef CustomCallThunk::CallUntypedAPI( void* call_target = CustomCallTargetRegistry::Global()->Lookup(target_name_, "Host"); if (!call_target) { - // Use kInternal to match the original CPU implementation. - // TODO(penporn): Change this to kUnimplemented to match the GPU backend - // when thunks is the only runtime for CPU. - return Internal( + return NotFound( "No registered implementation for untyped custom call to %s for Host", target_name_); } diff --git a/xla/service/cpu/runtime_handle_ffi_call.cc b/xla/service/cpu/runtime_handle_ffi_call.cc index 722aba6b6dbdd8..3c9c8d1e3d3eb6 100644 --- a/xla/service/cpu/runtime_handle_ffi_call.cc +++ b/xla/service/cpu/runtime_handle_ffi_call.cc @@ -108,7 +108,7 @@ static absl::Status BuildAndCallFfi( ffi::FindHandler(target_name, "Host"); if (!registration.ok()) { - return absl::UnimplementedError( + return absl::NotFoundError( absl::StrCat("No registered implementation for custom call to ", target_name, " for Host.")); } diff --git a/xla/tests/custom_call_test.cc b/xla/tests/custom_call_test.cc index 2a8e72c2bd34d8..2d0f370c664a4b 100644 --- a/xla/tests/custom_call_test.cc +++ b/xla/tests/custom_call_test.cc @@ -925,11 +925,6 @@ XLA_TEST_F(FfiCustomCallTest, FfiUnknownTarget) { module->AddEntryComputation(builder.Build()); auto status = Execute(std::move(module), {}).status(); - // NOTE: In the current CPU implementation, the 'kInternal' status code is - // returned when the target is not found. This behavior differs from that of - // the GPU, which returns 'kUnimplemented' in such case. When the CPU adopts - // the thunks runtime, the status code will be unified across both backends. - EXPECT_EQ(status.code(), absl::StatusCode::kInternal); EXPECT_THAT(status.message(), HasSubstr("No registered implementation")); } From 60e0497d9b2e40180073975a273b675650ae7603 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 25 Jul 2024 08:13:54 -0700 Subject: [PATCH 157/376] Refactor the heartbeat lambda function in coordination service agent into a private member function. PiperOrigin-RevId: 655959872 --- .../distributed_runtime/coordination/BUILD | 1 + .../coordination_service_agent.cc | 123 +++++++++--------- 2 files changed, 66 insertions(+), 58 deletions(-) diff --git a/xla/tsl/distributed_runtime/coordination/BUILD b/xla/tsl/distributed_runtime/coordination/BUILD index be2419631ff962..3528a89a5103fa 100644 --- a/xla/tsl/distributed_runtime/coordination/BUILD +++ b/xla/tsl/distributed_runtime/coordination/BUILD @@ -143,6 +143,7 @@ tsl_gpu_library( "//xla/tsl/distributed_runtime:call_options", "//xla/tsl/framework:cancellation", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc b/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc index ef3178ff9ff41e..2681333105991f 100644 --- a/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc +++ b/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc @@ -29,6 +29,7 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" +#include "absl/functional/bind_front.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" @@ -140,6 +141,8 @@ class CoordinationServiceAgentImpl : public CoordinationServiceAgent { private: absl::Status ShutdownInternal(); + // Starts sending heartbeats to the coordination service. + void StartSendingHeartbeats(); Env* env_ = nullptr; // Not owned. const uint64_t incarnation_id_ = random::New64(); @@ -308,66 +311,70 @@ absl::Status CoordinationServiceAgentImpl::Connect() { } LOG(INFO) << "Coordination agent has successfully connected."; - heartbeat_thread_.reset( - env_->StartThread(ThreadOptions(), kHeartbeatThread, [this]() -> void { - HeartbeatRequest request; - *request.mutable_source_task() = task_; - request.set_incarnation(incarnation_id_); - HeartbeatResponse response; - const int64_t heartbeat_interval_ms = - configs_.heartbeat_timeout_in_ms() > 0 - ? configs_.heartbeat_timeout_in_ms() / 2 - : absl::ToInt64Milliseconds(kDefaultHeartbeatTimeout) / 2; - CallOptions call_opts; - call_opts.SetTimeout(heartbeat_interval_ms); - - while (true) { - absl::Status status; - absl::Notification n; - // Heartbeat RPC implementation automatically retries to tolerate - // transient network failures. - VLOG(10) << "HeartbeatRequest: " << request.DebugString(); - leader_client_->HeartbeatAsync(&call_opts, &request, &response, - [&](absl::Status s) { - status = s; - n.Notify(); - }); - n.WaitForNotification(); - VLOG(10) << "HeartbeatResponse: " << status; - if (!status.ok()) { - // Ignore heartbeat errors and exit thread if shutting down. For - // example, the agent may send a heartbeat right after Shutdown() - // started, but before StopHeartbeat() and end of Shutdown(). This - // results in an unexpected heartbeat error. - // Waiting for a second allows us to identify if errors are due to - // inflight heartbeats sent during shutdown and can be ignored. - absl::SleepFor(absl::Seconds(1)); - { - absl::MutexLock l(&heartbeat_thread_shutdown_mu_); + heartbeat_thread_.reset(env_->StartThread( + ThreadOptions(), kHeartbeatThread, + absl::bind_front(&CoordinationServiceAgentImpl::StartSendingHeartbeats, + this))); + return absl::OkStatus(); +} - if (shutting_down_) { - return; - } - } - SetError(status); - } else if (response.leader_incarnation() != leader_incarnation_) { - SetError(MakeCoordinationError( - absl::AbortedError("Leader incarnation ID mismatch: the " - "coordination leader has restarted."))); - } - // Send next heartbeat after an interval. - { - absl::MutexLock l(&heartbeat_thread_shutdown_mu_); - heartbeat_thread_cv_.WaitWithTimeout( - &heartbeat_thread_shutdown_mu_, - absl::Milliseconds(heartbeat_interval_ms)); - if (shutting_down_) { - return; - } - } +void CoordinationServiceAgentImpl::StartSendingHeartbeats() { + HeartbeatRequest request; + *request.mutable_source_task() = task_; + request.set_incarnation(incarnation_id_); + HeartbeatResponse response; + const int64_t heartbeat_interval_ms = + configs_.heartbeat_timeout_in_ms() > 0 + ? configs_.heartbeat_timeout_in_ms() / 2 + : absl::ToInt64Milliseconds(kDefaultHeartbeatTimeout) / 2; + CallOptions call_opts; + call_opts.SetTimeout(heartbeat_interval_ms); + + while (true) { + absl::Status status; + absl::Notification n; + // Heartbeat RPC implementation automatically retries to tolerate + // transient network failures. + VLOG(10) << "HeartbeatRequest: " << request.DebugString(); + leader_client_->HeartbeatAsync(&call_opts, &request, &response, + [&](absl::Status s) { + status = s; + n.Notify(); + }); + n.WaitForNotification(); + VLOG(10) << "HeartbeatResponse: " << status; + if (!status.ok()) { + // Ignore heartbeat errors and exit thread if shutting down. For + // example, the agent may send a heartbeat right after Shutdown() + // started, but before StopHeartbeat() and end of Shutdown(). This + // results in an unexpected heartbeat error. + // Waiting for a second allows us to identify if errors are due to + // inflight heartbeats sent during shutdown and can be ignored. + absl::SleepFor(absl::Seconds(1)); + { + absl::MutexLock l(&heartbeat_thread_shutdown_mu_); + + if (shutting_down_) { + return; } - })); - return absl::OkStatus(); + } + SetError(status); + } else if (response.leader_incarnation() != leader_incarnation_) { + SetError(MakeCoordinationError( + absl::AbortedError("Leader incarnation ID mismatch: the " + "coordination leader has restarted."))); + } + // Send next heartbeat after an interval. + { + absl::MutexLock l(&heartbeat_thread_shutdown_mu_); + heartbeat_thread_cv_.WaitWithTimeout( + &heartbeat_thread_shutdown_mu_, + absl::Milliseconds(heartbeat_interval_ms)); + if (shutting_down_) { + return; + } + } + } } absl::Status CoordinationServiceAgentImpl::WaitForAllTasks( From 0dbe10f4e6d7173c85f064c4cf5aa004c70471bc Mon Sep 17 00:00:00 2001 From: Greg Olechwierowicz Date: Thu, 25 Jul 2024 09:24:13 -0700 Subject: [PATCH 158/376] [XLA:GPU] Fix error message in profile guided latency estimator. Print actually found instructions, not missing instructions twice. PiperOrigin-RevId: 655979840 --- xla/service/profile_guided_latency_estimator.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xla/service/profile_guided_latency_estimator.cc b/xla/service/profile_guided_latency_estimator.cc index b17d762b01fae2..d8e20f2445c4a4 100644 --- a/xla/service/profile_guided_latency_estimator.cc +++ b/xla/service/profile_guided_latency_estimator.cc @@ -188,7 +188,7 @@ absl::Status ProfileGuidedLatencyEstimator::CheckAccuracy( ProfileStatisticsAggregator::Statistics stats = aggregator_->GetStats(); size_t missing_instructions_count = stats.missing_instructions.size(); if (missing_instructions_count > 0) { - LOG(ERROR) << "Found " << missing_instructions_count + LOG(ERROR) << "Found " << stats.found_instructions_count << " instructions from the profile."; LOG(ERROR) << "Missing " << missing_instructions_count << " instructions from the profile."; From b4eefdd50915c42b12616f78ed1efc1d2a5bddfa Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 25 Jul 2024 09:28:24 -0700 Subject: [PATCH 159/376] Simplify and generalize the strategy generation code for dot ops. Rather than explicitly generating strategies corresponding to different sets of dimensions being sharded, we now generate strategies in a more principled and general manner. PiperOrigin-RevId: 655980955 --- .../auto_sharding_dot_handler.cc | 672 +++++++----------- .../auto_sharding/auto_sharding_test.cc | 34 +- 2 files changed, 256 insertions(+), 450 deletions(-) diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc index 36d03717e9ca00..dbefe93f8fd04f 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -31,6 +32,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/array.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding.h" @@ -166,12 +168,46 @@ class HandlerBase { } } + // Given a set of tensor dims, and a set of mesh dims, enumerates all mappings + // where a subset of all tensor dims is mapped to a subset of mesh dims, such + // that each tensor dim is mapped to at most mesh dim, and no two tensor dims + // are mapped to the same mesh dim. + // TODO(b/226977360): We might need to generalize this to also allow cases + // where a tensor dim can be mapped to multiple mesh dims. + void EnumerateGeneral(std::function split_func, + int tensor_rank, int current_tensor_dim, + const absl::flat_hash_set& unassigned_mesh_dims, + const DimMap& current_dim_map) { + if (current_tensor_dim == tensor_rank) { + split_func(current_dim_map); + return; + } + // current_tensor_dim is unsharded + EnumerateGeneral(split_func, tensor_rank, current_tensor_dim + 1, + unassigned_mesh_dims, current_dim_map); + // current_tensor_dim is sharded across one of the remaining mesh dims + for (int mesh_dim : unassigned_mesh_dims) { + DimMap updated_dim_map = current_dim_map; + updated_dim_map[current_tensor_dim] = mesh_dim; + absl::flat_hash_set updated_unassigned_mesh_dims = + unassigned_mesh_dims; + updated_unassigned_mesh_dims.erase( + updated_unassigned_mesh_dims.find(mesh_dim)); + EnumerateGeneral(split_func, tensor_rank, current_tensor_dim + 1, + updated_unassigned_mesh_dims, updated_dim_map); + } + } + // Enumerates *half* of the combinations (if inner & outer dims are the same). void EnumerateHalf(std::function split_func, size_t num_outer_dims = 2, size_t num_inner_dims = 2) { Enumerate(split_func, num_outer_dims, num_inner_dims, true); } + // Sorts strategies in the increasing order of their memory costs. Anecdotal + // experience suggests that such a sorted list of strategies works better + void SortStrategies(); + std::unique_ptr& strategy_group_; StrategyMap& strategy_map_; const HloInstruction* ins_; @@ -212,33 +248,11 @@ class DotHandler : public HandlerBase { ~DotHandler() override = default; - void SplitLhsSpaceRhsSpace(); - - void SplitLhsSpaceOnly(); - - void SplitRhsSpaceOnly(); - - void SplitLhsSpaceBothContract(); - - void SplitRhsSpaceBothContract(); - - void SplitOneBatchDim(); - - void SplitTwoBatchDims(); - - void SplitBatchDimLhsSpace(); - - void SplitBatchDimRhsSpace(); - - void SplitBatchDimBothContract(); + std::string GenerateNameForDotSharding(const DimMap& output_dim_map, + const DimMap& lhs_dim_map); - void SplitBothContractTwoDims(); - - void RecomputeSplitBothContract(); - - void Add1DDataParallel(); - - void Add1DBatchSplit(); + void GenerateDotShardingStrategiesFromOutputSharding( + const DimMap& output_dim_map); void AppendAllGatherWindowedEinsumStrategyForOperand( int operand_num, const std::string& name, const DimMap& lhs_dim_map, @@ -256,10 +270,13 @@ class DotHandler : public HandlerBase { bool is_dot_; int64_t space_base_dim_; tsl::protobuf::RepeatedField lhs_space_dims_, rhs_space_dims_; + tsl::protobuf::RepeatedField out_lhs_space_dims_, + out_rhs_space_dims_; tsl::protobuf::RepeatedField lhs_con_dims_; tsl::protobuf::RepeatedField rhs_con_dims_; tsl::protobuf::RepeatedField lhs_batch_dims_; tsl::protobuf::RepeatedField rhs_batch_dims_; + std::vector out_batch_dims_; }; class ConvHandler : public HandlerBase { @@ -439,6 +456,17 @@ std::optional HandlerBase::GetShardingFromUser( return ins_clone->sharding(); } +void HandlerBase::SortStrategies() { + absl::c_sort(strategy_group_->strategies, + [](const ShardingStrategy& s1, const ShardingStrategy& s2) { + if (s1.memory_cost == s2.memory_cost) { + return s1.name < s2.name; + } else { + return s1.memory_cost < s2.memory_cost; + } + }); +} + /************** DotHandler function definitions **************/ DotHandler::DotHandler(std::unique_ptr& strategy_group, @@ -458,9 +486,18 @@ DotHandler::DotHandler(std::unique_ptr& strategy_group, lhs_con_dims_(ins->dot_dimension_numbers().lhs_contracting_dimensions()), rhs_con_dims_(ins->dot_dimension_numbers().rhs_contracting_dimensions()), lhs_batch_dims_(ins->dot_dimension_numbers().lhs_batch_dimensions()), - rhs_batch_dims_(ins->dot_dimension_numbers().rhs_batch_dimensions()) { + rhs_batch_dims_(ins->dot_dimension_numbers().rhs_batch_dimensions()), + out_batch_dims_( + ins->dot_dimension_numbers().rhs_batch_dimensions().size()) { std::tie(lhs_space_dims_, rhs_space_dims_) = GetSpaceDims(lhs_->shape(), rhs_->shape(), ins->dot_dimension_numbers()); + for (int64_t i = 0; i < lhs_space_dims_.size(); ++i) { + out_lhs_space_dims_.Add(space_base_dim_ + i); + } + for (int64_t i = 0; i < rhs_space_dims_.size(); ++i) { + out_rhs_space_dims_.Add(space_base_dim_ + lhs_space_dims_.size() + i); + } + std::iota(out_batch_dims_.begin(), out_batch_dims_.end(), 0); CHECK_EQ(lhs_con_dims_.size(), rhs_con_dims_.size()); CHECK_EQ(lhs_batch_dims_.size(), rhs_batch_dims_.size()); } @@ -484,6 +521,7 @@ DotHandler::DotHandler( for (auto dim_idx : conv_as_dot_dims.batch_dims) { if (dim_idx.lhs >= 0) lhs_batch_dims_.Add(dim_idx.lhs); if (dim_idx.rhs >= 0) rhs_batch_dims_.Add(dim_idx.rhs); + if (dim_idx.output >= 0) out_batch_dims_.push_back(dim_idx.output); } for (auto dim_idx : conv_as_dot_dims.contracting_dims) { @@ -493,365 +531,207 @@ DotHandler::DotHandler( for (auto dim_idx : conv_as_dot_dims.lhs_non_contracting_dims) { if (dim_idx.lhs >= 0) lhs_space_dims_.Add(dim_idx.lhs); + if (dim_idx.output >= 0) out_lhs_space_dims_.Add(dim_idx.output); } for (auto dim_idx : conv_as_dot_dims.rhs_non_contracting_dims) { if (dim_idx.rhs >= 0) rhs_space_dims_.Add(dim_idx.rhs); + if (dim_idx.output >= 0) out_rhs_space_dims_.Add(dim_idx.output); } } -void DotHandler::SplitLhsSpaceRhsSpace() { - auto func = [this](const Enumeration& e) { - const DimMap lhs_dim_map = {{lhs_space_dims_[e.i], e.mesh_dims[0]}}; - const DimMap rhs_dim_map = {{rhs_space_dims_[e.j], e.mesh_dims[1]}}; - std::string name = - absl::StrFormat("SS = SR x RS @ {%s}", absl::StrJoin(e.mesh_dims, ",")); - - std::optional out_dim_map = std::nullopt; - if (is_dot_) { - out_dim_map = DimMap{ - {space_base_dim_ + e.i, e.mesh_dims[0]}, - {space_base_dim_ + static_cast(lhs_space_dims_.size()) + e.j, - e.mesh_dims[1]}}; +std::string DotHandler::GenerateNameForDotSharding(const DimMap& output_dim_map, + const DimMap& lhs_dim_map) { + std::string name; + + auto append_shardings_for_dims = [&name](absl::Span out_dims, + const DimMap& dim_map, + absl::string_view identifier) { + for (size_t i = 0; i < out_dims.size(); ++i) { + int output_batch_dim = out_dims[i]; + int mesh_dim = -1; + auto it = dim_map.find(output_batch_dim); + if (it != dim_map.end() && it->second >= 0) { + mesh_dim = it->second; + } + absl::StrAppend(&name, identifier, mesh_dim); } - MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_); }; - Enumerate(func, lhs_space_dims_.size(), rhs_space_dims_.size()); -} -void DotHandler::SplitLhsSpaceOnly() { - auto func = [this](const Enumeration& e) { - const DimMap lhs_dim_map = {{lhs_space_dims_[e.i], e.mesh_dims[0]}, - {lhs_space_dims_[e.j], e.mesh_dims[1]}}; - std::string name = absl::StrFormat("SSR = SSR x RR @ {%s}", - absl::StrJoin(e.mesh_dims, ",")); - std::optional out_dim_map = std::nullopt; - if (is_dot_) { - out_dim_map = DimMap{{space_base_dim_ + e.i, e.mesh_dims[0]}, - {space_base_dim_ + e.j, e.mesh_dims[1]}}; + // Output batch dims + append_shardings_for_dims(out_batch_dims_, output_dim_map, + /*identifier=*/"b"); + // LHS space dims + append_shardings_for_dims(out_lhs_space_dims_, output_dim_map, + /*identifier=*/"ls"); + // RHS space dims + append_shardings_for_dims(out_rhs_space_dims_, output_dim_map, + /*identifier=*/"rs"); + // Contraction dims + append_shardings_for_dims(lhs_con_dims_, lhs_dim_map, + /*identifier=*/"r"); + + bool contraction_dim_sharded = false; + for (size_t i = 0; i < lhs_con_dims_.size(); ++i) { + if (auto it = lhs_dim_map.find(lhs_con_dims_[i]); + it != lhs_dim_map.end() && it->second >= 0) { + contraction_dim_sharded = + contraction_dim_sharded || (device_mesh_.dim(it->second) > 1); } - MaybeAppend(name, lhs_dim_map, {}, out_dim_map, device_mesh_); - }; - EnumerateHalf(func, lhs_space_dims_.size(), lhs_space_dims_.size()); -} + } -void DotHandler::SplitRhsSpaceOnly() { - auto func = [this](const Enumeration& e) { - const DimMap rhs_dim_map = {{rhs_space_dims_[e.i], e.mesh_dims[0]}, - {rhs_space_dims_[e.j], e.mesh_dims[1]}}; - std::string name = absl::StrFormat("RSS = RR x RSS @ {%s}", - absl::StrJoin(e.mesh_dims, ",")); - std::optional out_dim_map = std::nullopt; - if (is_dot_) { - out_dim_map = DimMap{ - {space_base_dim_ + static_cast(lhs_space_dims_.size()) + e.i, - e.mesh_dims[0]}, - {space_base_dim_ + static_cast(lhs_space_dims_.size()) + e.j, - e.mesh_dims[1]}}; - } - MaybeAppend(name, {}, rhs_dim_map, out_dim_map, device_mesh_); - }; - EnumerateHalf(func, rhs_space_dims_.size(), rhs_space_dims_.size()); + if (contraction_dim_sharded) { + absl::StrAppend(&name, "|allreduce"); + } + return name; } -void DotHandler::SplitLhsSpaceBothContract() { - auto func = [this](const Enumeration& e) { - if (device_mesh_.dim(e.mesh_dims[0]) <= 1 || - device_mesh_.dim(e.mesh_dims[1]) <= 1) - return; - std::string name = - absl::StrFormat("SR = SS x SR @ {%s} (allreduce @ %d)", - absl::StrJoin(e.mesh_dims, ","), e.mesh_dims[1]); - const DimMap lhs_dim_map = {{lhs_space_dims_[e.i], e.mesh_dims[0]}, - {lhs_con_dims_[e.j], e.mesh_dims[1]}}; - const DimMap rhs_dim_map = {{rhs_con_dims_[e.j], e.mesh_dims[1]}}; - std::optional out_dim_map = std::nullopt; - if (is_dot_) { - out_dim_map = DimMap{{space_base_dim_ + e.i, e.mesh_dims[0]}}; +bool IsFullyReplicatedSharding(const DimMap& dim_map, + const Array& device_mesh) { + if (dim_map.empty()) { + return true; + } + for (const auto& [_, mesh_dim] : dim_map) { + if (device_mesh.dim(mesh_dim) > 1) { + return false; } - - auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { - double memory_cost = - ByteSizeOfShapeWithSharding(ins_->shape(), output_spec); - return cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[1]); - }; - MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_, 0, - communication_cost_fn); - }; - Enumerate(func, lhs_space_dims_.size(), lhs_con_dims_.size()); + } + return true; } -void DotHandler::SplitRhsSpaceBothContract() { - auto func = [this](const Enumeration& e) { - if (device_mesh_.dim(e.mesh_dims[0]) <= 1) return; - std::string name = - absl::StrFormat("RS = RS x SS @ {%s} (allreduce @ %d)", - absl::StrJoin(e.mesh_dims, ","), e.mesh_dims[0]); - const DimMap rhs_dim_map = {{rhs_space_dims_[e.i], e.mesh_dims[1]}, - {rhs_con_dims_[e.j], e.mesh_dims[0]}}; - const DimMap lhs_dim_map = {{lhs_con_dims_[e.j], e.mesh_dims[0]}}; - std::optional out_dim_map = std::nullopt; - if (is_dot_) { - out_dim_map = DimMap{ - {space_base_dim_ + static_cast(lhs_space_dims_.size()) + e.i, - e.mesh_dims[1]}}; - } - auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { - double memory_cost = - ByteSizeOfShapeWithSharding(ins_->shape(), output_spec); - return cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[0]); - }; - MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_, 0, - communication_cost_fn); - }; - Enumerate(func, rhs_space_dims_.size(), lhs_con_dims_.size()); +bool IsFullyReplicatedStrategy(const DimMap& output_dim_map, + const DimMap& lhs_dim_map, + const DimMap& rhs_dim_map, + const Array& device_mesh) { + return IsFullyReplicatedSharding(output_dim_map, device_mesh) && + IsFullyReplicatedSharding(lhs_dim_map, device_mesh) && + IsFullyReplicatedSharding(rhs_dim_map, device_mesh); } -void DotHandler::SplitOneBatchDim() { - if (absl::c_count_if(device_mesh_.dimensions(), - [](int64_t size) { return size > 1; }) != 1) { - return; - } - auto func = [this](const Enumeration& e) { - const DimMap lhs_dim_map = {{lhs_batch_dims_[e.i], e.j}}; - const DimMap rhs_dim_map = {{rhs_batch_dims_[e.i], e.j}}; - std::string name = absl::StrFormat("Sb_%d = Sb x Sb @ {%d}", e.i, e.j); - std::optional out_dim_map = std::nullopt; - if (is_dot_) { - out_dim_map = DimMap{{e.i, e.j}}; - } - MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_); - }; - Enumerate(func, lhs_batch_dims_.size(), device_mesh_.num_dimensions()); +bool IsFullySharded(const DimMap& dim_map, int num_mesh_dims) { + return dim_map.size() >= num_mesh_dims; } -void DotHandler::SplitTwoBatchDims() { - if (lhs_batch_dims_.size() != 2) return; - auto func = [this](const Enumeration& e) { - if (device_mesh_.dim(e.mesh_dims[0]) <= 1 || - device_mesh_.dim(e.mesh_dims[1]) <= 1) - return; - const DimMap lhs_dim_map = {{lhs_batch_dims_[0], e.mesh_dims[0]}, - {lhs_batch_dims_[1], e.mesh_dims[1]}}; - const DimMap rhs_dim_map = {{rhs_batch_dims_[0], e.mesh_dims[0]}, - {rhs_batch_dims_[1], e.mesh_dims[1]}}; - std::string name = - absl::StrFormat("Sb = Sb x Sb @ {%s}", absl::StrJoin(e.mesh_dims, ",")); - std::optional out_dim_map = std::nullopt; - if (is_dot_) { - out_dim_map = DimMap{{0, e.mesh_dims[0]}, {1, e.mesh_dims[1]}}; +void DotHandler::GenerateDotShardingStrategiesFromOutputSharding( + const DimMap& output_dim_map) { + DimMap lhs_dim_map, rhs_dim_map; + absl::flat_hash_set used_mesh_dims; + + // Propagate shardings for batch dimensions + for (size_t i = 0; i < out_batch_dims_.size(); ++i) { + int output_batch_dim = out_batch_dims_[i]; + int lhs_batch_dim = lhs_batch_dims_[i]; + int rhs_batch_dim = rhs_batch_dims_[i]; + auto it = output_dim_map.find(output_batch_dim); + if (it != output_dim_map.end() && it->second >= 0) { + int mesh_dim = it->second; + used_mesh_dims.insert(mesh_dim); + lhs_dim_map[lhs_batch_dim] = mesh_dim; + rhs_dim_map[rhs_batch_dim] = mesh_dim; } - MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_); - }; - EnumerateHalf(func, lhs_batch_dims_.size(), lhs_batch_dims_.size()); -} + } -void DotHandler::SplitBatchDimLhsSpace() { - if (lhs_batch_dims_.empty()) return; - auto func = [this](const Enumeration& e) { - if (device_mesh_.dim(e.mesh_dims[0]) <= 1 || - device_mesh_.dim(e.mesh_dims[1]) <= 1) - return; - std::string name = absl::StrFormat("SbSi = SbSi x SbR @ {%s}", - absl::StrJoin(e.mesh_dims, ",")); - const DimMap lhs_dim_map = {{lhs_space_dims_[e.i], e.mesh_dims[1]}, - {lhs_batch_dims_[e.j], e.mesh_dims[0]}}; - const DimMap rhs_dim_map = {{rhs_batch_dims_[e.j], e.mesh_dims[0]}}; - std::optional out_dim_map = std::nullopt; - if (is_dot_) { - out_dim_map = DimMap{{e.j, e.mesh_dims[0]}, - {space_base_dim_ + e.i, e.mesh_dims[1]}}; + // Propagate shardings for spatial dimensions + // - LHS space dims + for (size_t i = 0; i < lhs_space_dims_.size(); ++i) { + int lhs_space_dim = lhs_space_dims_[i]; + int output_space_dim = out_lhs_space_dims_[i]; + auto it = output_dim_map.find(output_space_dim); + if (it != output_dim_map.end() && it->second >= 0) { + int mesh_dim = it->second; + used_mesh_dims.insert(mesh_dim); + lhs_dim_map[lhs_space_dim] = mesh_dim; } - MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_); - }; - Enumerate(func, lhs_space_dims_.size(), lhs_batch_dims_.size()); -} + } -void DotHandler::SplitBatchDimRhsSpace() { - if (lhs_batch_dims_.empty()) return; - auto func = [this](const Enumeration& e) { - if (device_mesh_.dim(e.mesh_dims[0]) <= 1 || - device_mesh_.dim(e.mesh_dims[1]) <= 1) - return; - std::string name = absl::StrFormat("SbSj = SbR x SbSj @ {%s}", - absl::StrJoin(e.mesh_dims, ",")); - const DimMap rhs_dim_map = {{rhs_space_dims_[e.i], e.mesh_dims[1]}, - {rhs_batch_dims_[e.j], e.mesh_dims[0]}}; - const DimMap lhs_dim_map = {{lhs_batch_dims_[e.j], e.mesh_dims[0]}}; - std::optional out_dim_map = std::nullopt; - if (is_dot_) { - out_dim_map = DimMap{ - {e.j, e.mesh_dims[0]}, - {space_base_dim_ + static_cast(lhs_space_dims_.size()) + e.i, - e.mesh_dims[1]}}; + // - RHS space dims + for (size_t i = 0; i < rhs_space_dims_.size(); ++i) { + int rhs_space_dim = rhs_space_dims_[i]; + int output_space_dim = out_rhs_space_dims_[i]; + auto it = output_dim_map.find(output_space_dim); + if (it != output_dim_map.end() && it->second >= 0) { + int mesh_dim = it->second; + used_mesh_dims.insert(mesh_dim); + rhs_dim_map[rhs_space_dim] = mesh_dim; } - MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_); - }; - Enumerate(func, rhs_space_dims_.size(), lhs_batch_dims_.size()); -} + } -void DotHandler::SplitBatchDimBothContract() { - if (lhs_batch_dims_.empty()) return; - auto func = [this](const Enumeration& e) { - if (device_mesh_.dim(e.mesh_dims[0]) <= 1 || - device_mesh_.dim(e.mesh_dims[1]) <= 1) - return; - std::string name = - absl::StrFormat("SbR = SbSk x SbSk @ {%s} (allreduce @ %d}", - absl::StrJoin(e.mesh_dims, ","), e.mesh_dims[1]); - const DimMap lhs_dim_map = {{lhs_con_dims_[e.i], e.mesh_dims[1]}, - {lhs_batch_dims_[e.j], e.mesh_dims[0]}}; - const DimMap rhs_dim_map = {{rhs_con_dims_[e.i], e.mesh_dims[1]}, - {rhs_batch_dims_[e.j], e.mesh_dims[0]}}; - std::optional out_dim_map = std::nullopt; - if (is_dot_) { - out_dim_map = DimMap{{e.j, e.mesh_dims[0]}}; + // Skip fully the replicated strategy here as we add that outside of + // HandleDot in auto_sharding_strategy. + // TODO(b/348372403): Consolidate the generation of all dot strategies + // (including replicated strategies) in one place. + if (!IsFullyReplicatedStrategy(output_dim_map, lhs_dim_map, rhs_dim_map, + device_mesh_) && + // This second condition is added to ensure parity with the older strategy + // generation code. Removing it will only increase the search space. + IsFullySharded(output_dim_map, device_mesh_.num_dimensions())) { + MaybeAppend(GenerateNameForDotSharding(output_dim_map, lhs_dim_map), + lhs_dim_map, rhs_dim_map, output_dim_map, device_mesh_); + } + + // Generate shardings for contraction dimensions + if (used_mesh_dims.size() == device_mesh_.num_dimensions()) { + return; + } + + absl::flat_hash_set unused_mesh_dims; + for (size_t i = 0; i < device_mesh_.num_dimensions(); ++i) { + if (!used_mesh_dims.contains(i) && device_mesh_.dim(i) > 1) { + unused_mesh_dims.insert(i); } - auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { - double memory_cost = - ByteSizeOfShapeWithSharding(ins_->shape(), output_spec); - return cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[1]); - }; - MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_, 0, - communication_cost_fn); - }; - Enumerate(func, lhs_con_dims_.size(), lhs_batch_dims_.size()); -} + } -void DotHandler::SplitBothContractTwoDims() { - if (lhs_con_dims_.size() < 2 || rhs_con_dims_.size() < 2) return; - auto func = [this](const Enumeration& e) { - // Applies when there are more than one contracting dimension. - if (device_mesh_.dim(e.mesh_dims[0]) <= 1 || - device_mesh_.dim(e.mesh_dims[1]) <= 1) + if (unused_mesh_dims.empty()) { + return; + } + + std::vector reduction_dims(lhs_con_dims_.size()); + std::iota(reduction_dims.begin(), reduction_dims.end(), 0); + + auto split_func = [&](const DimMap& reduction_dim_map) { + if (reduction_dim_map.empty()) { return; - std::string name = absl::StrFormat("RR = SS x SS @ {%s} (allreduce @ {%s}}", - absl::StrJoin(e.mesh_dims, ","), - absl::StrJoin(e.mesh_dims, ", ")); - const DimMap lhs_dim_map = {{lhs_con_dims_[e.i], e.mesh_dims[0]}, - {lhs_con_dims_[e.j], e.mesh_dims[1]}}; - const DimMap rhs_dim_map = {{rhs_con_dims_[e.i], e.mesh_dims[0]}, - {rhs_con_dims_[e.j], e.mesh_dims[1]}}; - std::optional out_dim_map = std::nullopt; - if (is_dot_) { - out_dim_map = DimMap{}; } - auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { - double memory_cost = - ByteSizeOfShapeWithSharding(ins_->shape(), output_spec); - return cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[0], - e.mesh_dims[1]); - }; - MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_, 0, - communication_cost_fn); - }; - EnumerateHalf(func, lhs_con_dims_.size(), lhs_con_dims_.size()); -} -void DotHandler::RecomputeSplitBothContract() { - auto func = [this](const Enumeration& e) { - if (device_mesh_.dim(e.mesh_dims[0]) <= 1 || - device_mesh_.dim(e.mesh_dims[1]) <= 1) - return; - if (!option_.allow_recompute_heavy_op) { - return; + DimMap lhs_dim_map_with_contractions = lhs_dim_map; + DimMap rhs_dim_map_with_contractions = rhs_dim_map; + for (const auto& [reducton_dim_index, mesh_dim] : reduction_dim_map) { + lhs_dim_map_with_contractions + [lhs_con_dims_[reduction_dims[reducton_dim_index]]] = mesh_dim; + rhs_dim_map_with_contractions + [rhs_con_dims_[reduction_dims[reducton_dim_index]]] = mesh_dim; } - std::string name = absl::StrFormat("RR = RS x SR @ {%d} (allreduce @ %d)", - e.mesh_dims[0], e.mesh_dims[0]); - const DimMap lhs_dim_map = {{lhs_con_dims_[e.i], e.mesh_dims[0]}}; - const DimMap rhs_dim_map = {{rhs_con_dims_[e.i], e.mesh_dims[0]}}; - std::optional out_dim_map = std::nullopt; - if (is_dot_) { - out_dim_map = DimMap{}; + // Skip fully the replicated strategy here as we add that outside of + // HandleDot in auto_sharding_strategy. + // TODO: Fix the above + if (IsFullyReplicatedStrategy(output_dim_map, lhs_dim_map_with_contractions, + rhs_dim_map_with_contractions, + device_mesh_)) { + return; } - double compute_cost = GetDotConvReplicationPenalty( - ins_, instruction_id_, /* window */ 10, - instruction_sequence_, hlo_cost_analysis_) / - device_mesh_.dim(e.mesh_dims[0]); - auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { - double memory_cost = - ByteSizeOfShapeWithSharding(ins_->shape(), output_spec); - return cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[0]); - }; - MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_, - compute_cost, communication_cost_fn); - }; - Enumerate(func, lhs_con_dims_.size(), 1); -} + CHECK(!lhs_dim_map_with_contractions.empty()); + CHECK(!rhs_dim_map_with_contractions.empty()); -void DotHandler::Add1DDataParallel() { - if (device_mesh_.dim(0) > 1 && - absl::c_count_if(device_mesh_.dimensions(), - [](int64_t size) { return size > 1; }) > 1) { - int mesh_dim = 0; - int64_t num_devices = device_mesh_1d_.dim(mesh_dim); - - // Si = Si x R @ 0 - for (int64_t i = 0; i < lhs_space_dims_.size(); ++i) { - const DimMap lhs_dim_map = {{lhs_space_dims_[i], mesh_dim}}; - if (lhs_->shape().dimensions(lhs_space_dims_[i]) < num_devices) { - continue; - } - if (option_.only_allow_divisible_intermediate && - !IsDivisible(lhs_->shape().dimensions(lhs_space_dims_[i]), - num_devices)) { - continue; - } - std::string name = absl::StrFormat("Si = Si x R @ %d", mesh_dim); - std::optional out_dim_map = std::nullopt; - if (is_dot_) { - out_dim_map = DimMap{{space_base_dim_ + i, mesh_dim}}; + auto communication_cost_fn = [&](const HloSharding& output_sharding) { + double memory_cost = + ByteSizeOfShapeWithSharding(ins_->shape(), output_sharding); + double total_cost = 0; + for (const auto& [_, mesh_dim] : reduction_dim_map) { + total_cost += cluster_env_.AllReduceCost(memory_cost, mesh_dim); } - MaybeAppend(name, lhs_dim_map, {}, out_dim_map, device_mesh_1d_); - } + return total_cost; + }; - // R = Sk x Sk @ (allreduce @ 0) - for (int64_t i = 0; i < lhs_con_dims_.size(); ++i) { - const DimMap lhs_dim_map = {{lhs_con_dims_[i], mesh_dim}}; - const DimMap rhs_dim_map = {{rhs_con_dims_[i], mesh_dim}}; - if (lhs_->shape().dimensions(lhs_con_dims_[i]) < num_devices) { - continue; - } - if (option_.only_allow_divisible_intermediate && - !IsDivisible(lhs_->shape().dimensions(lhs_con_dims_[i]), - num_devices)) { - continue; - } - std::string name = absl::StrFormat("R = Sk x Sk @ %d (allreduce @ %d)", - mesh_dim, mesh_dim); - std::optional out_dim_map = std::nullopt; - if (is_dot_) { - out_dim_map = DimMap{}; - } - auto communication_cost_fn = [this, - mesh_dim](const HloSharding& output_spec) { - double memory_cost = - ByteSizeOfShapeWithSharding(ins_->shape(), output_spec); - return cluster_env_.AllReduceCost(memory_cost, mesh_dim); - }; - MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_1d_, - 0, communication_cost_fn); - } - } -} + MaybeAppend(GenerateNameForDotSharding(output_dim_map, + lhs_dim_map_with_contractions), + lhs_dim_map_with_contractions, rhs_dim_map_with_contractions, + output_dim_map, device_mesh_, + /*compute_cost=*/0, communication_cost_fn); + }; -void DotHandler::Add1DBatchSplit() { - if (device_mesh_.dim(0) > 1 && - absl::c_count_if(device_mesh_.dimensions(), - [](int64_t size) { return size > 1; }) > 1) { - int mesh_dim = 0; - for (int64_t i = 0; i < lhs_batch_dims_.size(); ++i) { - const DimMap lhs_dim_map = {{lhs_batch_dims_[i], mesh_dim}}; - const DimMap rhs_dim_map = {{rhs_batch_dims_[i], mesh_dim}}; - std::string name = - absl::StrFormat("Sb_%d = Sb x Sb @ {%d} 1d", i, mesh_dim); - std::optional out_dim_map = std::nullopt; - if (is_dot_) { - out_dim_map = DimMap{{i, mesh_dim}}; - } - MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_1d_); - } - } + EnumerateGeneral(split_func, reduction_dims.size(), + /*current_tensor_dim=*/0, unused_mesh_dims, + /*current_dim_map=*/{}); } void DotHandler::AppendAllGatherWindowedEinsumStrategyForOperand( @@ -957,91 +837,18 @@ void DotHandler::AppendReduceScatterWindowedEinsumStrategy( } absl::Status DotHandler::RegisterStrategies() { - // SS = SR x RS - // Split lhs space dim and rhs space dim. - SplitLhsSpaceRhsSpace(); - - // SSR = SSR x RR - // Split lhs space dims only if it has more than 1 space dims. - if (lhs_space_dims_.size() > 1) { - SplitLhsSpaceOnly(); - } - // RSS = RR x RSS - // Split rhs space dims only if it has more than 1 space dims. - if (rhs_space_dims_.size() > 1) { - SplitRhsSpaceOnly(); - } - - // SR = SS x SR - // Split lhs space dim and both contracting dims. - SplitLhsSpaceBothContract(); - - // RS = RS x SS - // Split rhs space dim and both contracting dims. - SplitRhsSpaceBothContract(); - - // RR = SS x SS - // Split two contracting dims on lhs and rhs. - SplitBothContractTwoDims(); - - // RR = RS x SR - // This is a special case where we allow splitting only one dim in the - // multi-dimensional mesh case. This allows some recomputation - // (e.g., the dense layer in the LM_head of BERT). - RecomputeSplitBothContract(); - - // Add 1d data parallel in multi-dimensional mesh - if (option_.allow_mixed_mesh_shape) { - Add1DDataParallel(); - } - - if (option_.batch_matmul_always_split_batch && !lhs_batch_dims_.empty() && - cluster_env_.non_zero_mesh_dims_.size() > 1) { - // If there is a batch dim and the device mesh is multi-dimensional, - // always split on batch dim. Clear all old strategies. - strategy_group_->strategies.clear(); - } - - // Sb = Sb x Sb - // Split one batch dim. Only used for 1d mesh - SplitOneBatchDim(); - - // SbSi = SbSi x SbR - // Split batch dim and lhs space dim - SplitBatchDimLhsSpace(); - - // SbSj = SbR x SbSj - // Split batch dim and rhs space dim - SplitBatchDimRhsSpace(); - - // SbSj = SbR x SbSj - // Split batch dim and contracting dim - SplitBatchDimBothContract(); - - if (option_.batch_matmul_always_split_batch && lhs_batch_dims_.size() == 2 && - absl::c_count_if(device_mesh_.dimensions(), - [](int64_t size) { return size > 1; }) > 1) { - // If there are two batch dims, always split on these two dims. - // Clear all old strategies. - strategy_group_->strategies.clear(); + absl::flat_hash_set all_mesh_dims; + for (int i = 0; i < device_mesh_.num_dimensions(); ++i) { + all_mesh_dims.insert(i); } - - // Sb = Sb x Sb - // Split batch dims. - SplitTwoBatchDims(); - - if (option_.allow_mixed_mesh_shape) { - Add1DBatchSplit(); - } - - // If force_batch_dim_to_mesh_dim is set, filter out invalid strategies - // and only keep the data parallel strategies. - if (option_.force_batch_dim_to_mesh_dim >= 0 && - batch_map_.contains(GetBatchDimMapKey(ins_))) { - TF_RETURN_IF_ERROR(FilterStrategy(ins_, ins_->shape(), strategy_group_, - cluster_env_, batch_map_, option_)); - } - + EnumerateGeneral( + /*split_func=*/ + [&](const DimMap& output_dim_map) { + GenerateDotShardingStrategiesFromOutputSharding(output_dim_map); + }, + ins_->shape().rank(), /*current_tensor_dim=*/0, all_mesh_dims, + /*current_dim_map=*/{}); + SortStrategies(); return absl::OkStatus(); } @@ -1275,7 +1082,6 @@ absl::Status HandleConv(std::unique_ptr& strategy_group, batch_map, option, call_graph); TF_RETURN_IF_ERROR(handler.RegisterStrategies()); } - return absl::OkStatus(); } diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc index 64c42be14fdfe5..4f9d533a4edcec 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc @@ -1458,34 +1458,33 @@ HloModule module ENTRY %entry { %param0 = f32[1024,1024]{0,1} parameter(0) %param1 = s32[128,1024,1]{2,1,0} parameter(1) - %gather = f32[128,1024,1024]{2,1,0} gather(f32[1024,1024]{0,1} %param0, s32[128,1024,1]{2,1,0} %param1), - offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, - index_vector_dim=2, slice_sizes={1,1024} + %gather = f32[128,1024,1024]{2,1,0} gather(f32[1024,1024]{0,1} %param0, s32[128,1024,1]{2,1,0} %param1), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,1024} %param2 = f32[1024,1024]{1,0} parameter(2), sharding={replicated} %reshape = f32[1024,1024,1]{2,1,0} reshape(param2) - ROOT convolution = f32[128,1024,1024]{2,1,0} convolution(gather, reshape), - window={size=1}, dim_labels=b0f_io0->b0f + ROOT convolution = f32[128,1024,1024]{2,1,0} convolution(gather, reshape), window={size=1}, dim_labels=b0f_io0->b0f })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); AutoShardingOption option; option.enable = true; - option.device_mesh_shape = {4, 1, 1}; + option.preserve_shardings = + AutoShardingOption::PreserveShardingsType::kKeepInputOutputShardings; + option.device_mesh_shape = {4, 1}; option.device_mesh_ids = {0, 1, 2, 3}; - option.device_mesh_alpha = {1.0, 1.0, 1.0}; - option.device_mesh_beta = {0.01, 1.0, 1.0}; + option.device_mesh_alpha = {1.0, 1.0}; + option.device_mesh_beta = {0.01, 1.0}; TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get())); EXPECT_TRUE(changed); - auto* gather = FindInstruction(module.get(), "gather"); - auto* conv = FindInstruction(module.get(), "convolution"); + const HloInstruction* gather = FindInstruction(module.get(), "gather"); + const HloInstruction* conv = FindInstruction(module.get(), "convolution"); ASSERT_NE(gather, nullptr); ASSERT_NE(conv, nullptr); - EXPECT_THAT(gather, op::Sharding("{devices=[1,4,1]0,1,2,3}")); - EXPECT_THAT(conv, op::Sharding("{devices=[1,4,1]0,1,2,3}")); - auto gather_sharding = gather->sharding(); - TF_EXPECT_OK(gather_sharding.Validate(gather->shape(), 4)); - auto conv_sharding = conv->sharding(); - TF_EXPECT_OK(conv_sharding.Validate(conv->shape(), 4)); + const HloSharding& gather_sharding = gather->sharding(); + EXPECT_EQ(gather_sharding.NumTiles(), 4); + EXPECT_OK(gather_sharding.Validate(gather->shape(), 4)); + const HloSharding& conv_sharding = conv->sharding(); + EXPECT_EQ(conv_sharding.NumTiles(), 4); + EXPECT_OK(conv_sharding.Validate(conv->shape(), 4)); } TEST_F(AutoShardingTest, MatmulMeshShape1DMeshShape) { @@ -1713,7 +1712,8 @@ TEST_F(AutoShardingTest, LargeSize) { option.device_mesh_shape = {1, 2, 4, 7}; option.device_mesh_alpha = {1.0, 1.0, 1.0, 1.0}; option.device_mesh_beta = {1.0, 1.0, 1.0, 1.0}; - RunMatMulAutoShardingWithOptions(option, 8, 2); + option.memory_budget_per_device = (8192 + 8192 * 2 + 8192 * 4 / 8); + RunMatMulAutoShardingWithOptions(option, 7, 1); } TEST_F(AutoShardingTest, InvalidOptions) { From c06be0cfbba596155d94499f23962ee7a3829316 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 25 Jul 2024 10:15:50 -0700 Subject: [PATCH 160/376] [xla:cpu] Create intra-op thread pool with same number of threads as PjrtClient thread pool This prevents possible deadlocks. PiperOrigin-RevId: 655997001 --- xla/pjrt/cpu/cpu_client.cc | 4 ++-- xla/pjrt/cpu/cpu_client_test.cc | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/xla/pjrt/cpu/cpu_client.cc b/xla/pjrt/cpu/cpu_client.cc index d7c7575ec06cbc..02b6524018a852 100644 --- a/xla/pjrt/cpu/cpu_client.cc +++ b/xla/pjrt/cpu/cpu_client.cc @@ -420,8 +420,8 @@ TfrtCpuClient::TfrtCpuClient( "XLATfrtCpuClient", num_threads)), async_work_runner_(std::make_unique( pjrt_client_thread_pool_.get())), - eigen_intraop_pool_(new tsl::thread::ThreadPool( - tsl::Env::Default(), "XLAEigen", DefaultThreadPoolSize())), + eigen_intraop_pool_(new tsl::thread::ThreadPool(tsl::Env::Default(), + "XLAEigen", num_threads)), eigen_intraop_device_( new Eigen::ThreadPoolDevice(eigen_intraop_pool_->AsEigenThreadPool(), eigen_intraop_pool_->NumThreads())), diff --git a/xla/pjrt/cpu/cpu_client_test.cc b/xla/pjrt/cpu/cpu_client_test.cc index 2e71323b6ea68b..a66a2901d4cbab 100644 --- a/xla/pjrt/cpu/cpu_client_test.cc +++ b/xla/pjrt/cpu/cpu_client_test.cc @@ -99,7 +99,7 @@ TEST(TfrtCpuClientTest, MemorySpace) { } TEST(TfrtCpuClientTest, DonationWithExecutionError) { - constexpr char kProgram[] = + static constexpr char kProgram[] = R"( HloModule DonationWithExecutionError, input_output_alias={ {}: (0, {}, must-alias) } @@ -144,7 +144,7 @@ ENTRY DonationWithExecutionError() -> f32[2, 2] { } TEST(TfrtCpuClientTest, HloSnapshot) { - constexpr char kProgram[] = R"( + static constexpr char kProgram[] = R"( HloModule add ENTRY add { x = f32[3,2] parameter(0) From 9a0b4e61bd04b68919c56a3a8a43d6b906d002c4 Mon Sep 17 00:00:00 2001 From: Vladyslav Tsilytskyi Date: Thu, 25 Jul 2024 10:16:56 -0700 Subject: [PATCH 161/376] [xla:cpu] Add PadThunk implementation Reuse all the code from IrEmitter. PiperOrigin-RevId: 655997367 --- xla/service/cpu/benchmarks/BUILD | 16 +++++ .../cpu/benchmarks/pad_benchmark_test.cc | 66 +++++++++++++++++++ xla/service/cpu/ir_emitter.cc | 39 ++++++++--- xla/service/cpu/ir_emitter.h | 22 +++++-- xla/service/cpu/ir_emitter2.cc | 45 ++++++++++++- xla/service/cpu/ir_emitter2.h | 5 ++ xla/service/cpu/thunk_emitter.cc | 14 +++- xla/service/cpu/thunk_emitter.h | 3 + 8 files changed, 190 insertions(+), 20 deletions(-) create mode 100644 xla/service/cpu/benchmarks/pad_benchmark_test.cc diff --git a/xla/service/cpu/benchmarks/BUILD b/xla/service/cpu/benchmarks/BUILD index 3cec33edd0a749..440d7dcb30f2d6 100644 --- a/xla/service/cpu/benchmarks/BUILD +++ b/xla/service/cpu/benchmarks/BUILD @@ -249,3 +249,19 @@ xla_cc_test( "@tsl//tsl/platform:test_main", ], ) + +xla_cc_test( + name = "pad_benchmark_test", + srcs = ["pad_benchmark_test.cc"], + deps = [ + ":hlo_benchmark_runner", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:test_benchmark", + "@tsl//tsl/platform:test_main", + ], +) diff --git a/xla/service/cpu/benchmarks/pad_benchmark_test.cc b/xla/service/cpu/benchmarks/pad_benchmark_test.cc new file mode 100644 index 00000000000000..e82d75ab38195d --- /dev/null +++ b/xla/service/cpu/benchmarks/pad_benchmark_test.cc @@ -0,0 +1,66 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/service/cpu/benchmarks/hlo_benchmark_runner.h" +#include "xla/shape_util.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/test_benchmark.h" + +namespace xla::cpu { + +static void BM_PadF32(benchmark::State& state) { + int64_t d0 = state.range(0); + + std::string_view hlo = R"( + HloModule pad_f32_$d0 + + ENTRY e { + input = f32[1,4,$d0,$d0,4] parameter(0) + value = f32[] parameter(1) + ROOT pad = pad(input, value), padding=0_0_0x0_-1_0x0_-1_0x-2_-2_0x-1_-1_3 + } + )"; + + std::minstd_rand0 engine; + + auto input_shape = ShapeUtil::MakeShape(F32, {1, 4, d0, d0, 4}); + auto value_shape = ShapeUtil::MakeShape(F32, {}); + auto p0 = + *LiteralUtil::CreateRandomLiteral(input_shape, &engine, 1.0f, 0.1f); + auto p1 = + *LiteralUtil::CreateRandomLiteral(value_shape, &engine, 1.0f, 0.1f); + + std::vector args = {&p0, &p1}; + CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}})); +} + +BENCHMARK(BM_PadF32) + ->MeasureProcessCPUTime() + ->Arg(128) + ->Arg(256) + ->Arg(512) + ->Arg(1024) + ->Arg(4096); + +} // namespace xla::cpu diff --git a/xla/service/cpu/ir_emitter.cc b/xla/service/cpu/ir_emitter.cc index 75b1f3ccce2254..8a6b3619e3ab29 100644 --- a/xla/service/cpu/ir_emitter.cc +++ b/xla/service/cpu/ir_emitter.cc @@ -2300,6 +2300,22 @@ absl::Status IrEmitter::HandleRecvDone(HloInstruction* recv_done) { } absl::Status IrEmitter::HandlePad(HloInstruction* pad) { + CHECK_EQ(pad->operand_count(), 2); + const auto operand = pad->operand(0); + const auto padding_value = pad->operand(1); + + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(pad)); + + return HandlePad(pad, GetIrArrayFor(operand), GetIrArrayFor(padding_value), + GetIrArrayFor(pad)); +} + +absl::Status IrEmitter::HandlePad(HloInstruction* pad, + const llvm_ir::IrArray& operand_array, + const llvm_ir::IrArray& padding_value_array, + const llvm_ir::IrArray& output_array) { + CHECK_EQ(pad->operand_count(), 2); + // CPU backend does not properly handle negative padding but this is ok // because negative padding should be removed by the algebraic simplifier. for (auto& padding_dimension : pad->padding_config().dimensions()) { @@ -2312,15 +2328,22 @@ absl::Status IrEmitter::HandlePad(HloInstruction* pad) { } } + const HloInstruction* padding_value = pad->operand(1); + const auto index_type = b()->getInt64Ty(); + const auto index = llvm_ir::IrArray::Index(index_type); + llvm::Value* padding_value_addr = padding_value_array.EmitArrayElementAddress( + index, b(), "padding_value_addr", true, nullptr); + const llvm_ir::ElementGenerator element_generator = + [this, padding_value, + padding_value_addr](const llvm_ir::IrArray::Index& target_index) { + return b()->CreateLoad(IrShapeType(padding_value->shape()), + padding_value_addr); + }; + // First, fill in the padding value to all output elements. TF_RETURN_IF_ERROR(EmitTargetElementLoop( - pad, "initialize", - [this, pad](const llvm_ir::IrArray::Index& target_index) { - const HloInstruction* padding_value = pad->operand(1); - llvm::Value* padding_value_addr = GetEmittedValueFor(padding_value); - return Load(IrShapeType(padding_value->shape()), padding_value_addr); - }, - std::nullopt)); + pad, "initialize", element_generator, + std::optional(output_array))); // Create a loop to iterate over the operand elements and update the output // locations where the operand elements should be stored. @@ -2332,7 +2355,6 @@ absl::Status IrEmitter::HandlePad(HloInstruction* pad) { SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b()); // Load an element from the operand. - llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); llvm::Value* operand_data = operand_array.EmitReadArrayElement(operand_index, b()); @@ -2350,7 +2372,6 @@ absl::Status IrEmitter::HandlePad(HloInstruction* pad) { } // Store the operand element to the computed output location. - llvm_ir::IrArray output_array(GetIrArrayFor(pad)); llvm_ir::IrArray::Index output_index( output_multi_index, output_array.GetShape(), operand_index.GetType()); output_array.EmitWriteArrayElement(output_index, operand_data, b()); diff --git a/xla/service/cpu/ir_emitter.h b/xla/service/cpu/ir_emitter.h index 45a4c5e22af1f3..4ed7854f48a610 100644 --- a/xla/service/cpu/ir_emitter.h +++ b/xla/service/cpu/ir_emitter.h @@ -154,19 +154,22 @@ class IrEmitter : public DfsHloVisitorWithDefault, } // Used by IrEmitter2 - void PushComputeFunction(std::shared_ptr> b, - llvm::Module* llvm_module, + void PushComputeFunction(llvm::IRBuilder<>* b, llvm::Module* llvm_module, int64_t num_dynamic_loop_bounds, llvm::Function* function, llvm::Value* dynamic_loop_bounds_arg, llvm::BasicBlock* return_block) { - b->SetInsertPoint(llvm::BasicBlock::Create(llvm_module->getContext(), - "insertion_point", function)); - compute_function_.emplace(b.get(), llvm_module, num_dynamic_loop_bounds, - function, dynamic_loop_bounds_arg, return_block); + function->getEntryBlock().getTerminator()->eraseFromParent(); + b->SetInsertPoint(&function->getEntryBlock()); + compute_function_.emplace(b, llvm_module, num_dynamic_loop_bounds, function, + dynamic_loop_bounds_arg, return_block); } - void PopComputeFunction() { compute_function_.pop(); } + void PopComputeFunction() { + // At this point, the compute function destructor adds a branch to the + // return block. + compute_function_.pop(); + } // Emit an LLVM global variable for every constant buffer allocation. absl::Status EmitConstantGlobals(); @@ -294,6 +297,11 @@ class IrEmitter : public DfsHloVisitorWithDefault, absl::Status Preprocess(HloInstruction* hlo) override; absl::Status Postprocess(HloInstruction* hlo) override; + absl::Status HandlePad(HloInstruction* pad, + const llvm_ir::IrArray& operand_array, + const llvm_ir::IrArray& padding_value_array, + const llvm_ir::IrArray& output_array); + absl::Status HandleSelectAndScatter(HloInstruction* select_and_scatter, const llvm_ir::IrArray& operand_array, const llvm_ir::IrArray& source_array, diff --git a/xla/service/cpu/ir_emitter2.cc b/xla/service/cpu/ir_emitter2.cc index be02951ddcfd7e..70644ed173a15e 100644 --- a/xla/service/cpu/ir_emitter2.cc +++ b/xla/service/cpu/ir_emitter2.cc @@ -253,6 +253,37 @@ absl::StatusOr IrEmitter2::EmitElementalHostKernel( kernel_prototype.function->getName().str(), se::BlockDim(), thread_dims}); } +absl::StatusOr IrEmitter2::EmitPadHostKernel( + const HloInstruction* pad) { + VLOG(2) << "Emit Pad host kernel."; + + TF_ASSIGN_OR_RETURN(KernelPrototype kernel_prototype, + EmitKernelPrototype(pad)); + + llvm_ir::IrArray operand_array = kernel_prototype.arguments[0]; + llvm_ir::IrArray padvalue_array = kernel_prototype.arguments[1]; + llvm_ir::IrArray output_array = kernel_prototype.results[0]; + + llvm::LLVMContext& ctx = module_->getContext(); + llvm::IRBuilder<> b(ctx); + auto builder_overwrite = nested_ir_emitter_->WithBuilder(b); + + nested_ir_emitter_->PushComputeFunction( + &b, module_, + /*num_dynamic_loop_bounds=*/0, kernel_prototype.function, + /*dynamic_loop_bounds_arg=*/nullptr, kernel_prototype.return_block); + + TF_RETURN_IF_ERROR(nested_ir_emitter_->HandlePad( + const_cast(pad), operand_array, padvalue_array, + output_array)); + + nested_ir_emitter_->PopComputeFunction(); + + return kernels_.emplace_back( + KernelInfo{kernel_prototype.function->getName().str(), se::BlockDim(), + se::ThreadDim()}); +} + absl::StatusOr IrEmitter2::EmitFusionHostKernel( const HloFusionInstruction* fusion) { VLOG(2) << "Emit fusion host kernel: " << fusion->name(); @@ -809,11 +840,21 @@ absl::StatusOr IrEmitter2::EmitKernelPrototype( // Return null pointer to signal success as we do not support error handling // in the compiled host kernel. + llvm::BasicBlock* return_block = + llvm::BasicBlock::Create(ctx, "return", function); + + b.CreateBr(return_block); + + b.SetInsertPoint(return_block); b.CreateRet( llvm::ConstantPointerNull::get(llvm::PointerType::getUnqual(ctx))); - return KernelPrototype{function, kernel_thread_dims, kernel_thread, - std::move(ir_arguments), std::move(ir_results)}; + return KernelPrototype{function, + return_block, + kernel_thread_dims, + kernel_thread, + std::move(ir_arguments), + std::move(ir_results)}; } absl::StatusOr IrEmitter2::EmitKernelPrototype( diff --git a/xla/service/cpu/ir_emitter2.h b/xla/service/cpu/ir_emitter2.h index 65f07836e04ca0..c998840a24b330 100644 --- a/xla/service/cpu/ir_emitter2.h +++ b/xla/service/cpu/ir_emitter2.h @@ -26,6 +26,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" @@ -90,6 +91,7 @@ class IrEmitter2 { // to emit the actual kernel body. struct KernelPrototype { llvm::Function* function; + llvm::BasicBlock* return_block; // LLVM values identifying kernel invocation thread coordinates. KernelThreadDims thread_dims; @@ -123,6 +125,9 @@ class IrEmitter2 { absl::StatusOr EmitElementalHostKernel( const HloInstruction* instr); + // Emits a host kernel for the pad instruction. + absl::StatusOr EmitPadHostKernel(const HloInstruction* pad); + // Emits a host kernel for the given fusion instruction. absl::StatusOr EmitFusionHostKernel( const HloFusionInstruction* fusion); diff --git a/xla/service/cpu/thunk_emitter.cc b/xla/service/cpu/thunk_emitter.cc index 4f3d45e4b47307..7d3c9c558021d4 100644 --- a/xla/service/cpu/thunk_emitter.cc +++ b/xla/service/cpu/thunk_emitter.cc @@ -265,9 +265,8 @@ absl::StatusOr ThunkEmitter::EmitHloInstruction( case HloOpcode::kCollectivePermute: return EmitCollectivePermuteThunk(instruction); - // TODO(ezhulenev): Port pad optimizations from IrEmitter. case HloOpcode::kPad: - return EmitElementalKernelThunk(instruction); + return EmitPadKernelThunk(instruction); case HloOpcode::kSlice: case HloOpcode::kDynamicSlice: @@ -615,6 +614,17 @@ absl::StatusOr ThunkEmitter::EmitElementalKernelThunk( kernel.thread_dims, /*min_alignment=*/cpu_function_runtime::MinAlign()); } +absl::StatusOr ThunkEmitter::EmitPadKernelThunk( + const HloInstruction* instruction) { + const HloPadInstruction* padInstr = Cast(instruction); + TF_ASSIGN_OR_RETURN(auto kernel, ir_emitter_.EmitPadHostKernel(padInstr)); + TF_ASSIGN_OR_RETURN(auto buffers, GetHostKernelAllocationSlices(padInstr)); + + return ThunkSequence::Of( + ThunkInfo(padInstr), buffers.arguments, buffers.results, kernel.name, + kernel.thread_dims, /*min_alignment=*/cpu_function_runtime::MinAlign()); +} + absl::StatusOr ThunkEmitter::EmitFusionKernelThunk( const HloInstruction* instruction) { auto* fusion = Cast(instruction); diff --git a/xla/service/cpu/thunk_emitter.h b/xla/service/cpu/thunk_emitter.h index 605b87578b6fad..6921f76e75179b 100644 --- a/xla/service/cpu/thunk_emitter.h +++ b/xla/service/cpu/thunk_emitter.h @@ -102,6 +102,9 @@ class ThunkEmitter { absl::StatusOr EmitElementalKernelThunk( const HloInstruction* instruction); + absl::StatusOr EmitPadKernelThunk( + const HloInstruction* instruction); + absl::StatusOr EmitFftThunk(const HloInstruction* instruction); absl::StatusOr EmitFusionKernelThunk( From 5ffa9554fa12b7aabba9413d87be33de5557308b Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Thu, 25 Jul 2024 10:34:43 -0700 Subject: [PATCH 162/376] [XLA:GPU] Add convenience test function for `ExecuteReplicated` Add convenience function to pass different inputs to different replicas of the same program. PiperOrigin-RevId: 656004089 --- xla/tests/BUILD | 4 +- xla/tests/collective_ops_test.cc | 140 ++++++++++++------ .../collective_pipeline_parallelism_test.cc | 40 +++-- xla/tests/hlo_test_base.cc | 23 +++ xla/tests/hlo_test_base.h | 7 + 5 files changed, 141 insertions(+), 73 deletions(-) diff --git a/xla/tests/BUILD b/xla/tests/BUILD index 643d44201a8b30..81f107a3c172d6 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -2318,7 +2318,7 @@ xla_test( args = ["--xla_force_host_platform_device_count=4"], backend_tags = { # This test is tagged "manual" because it requires multiple GPUs, and Forge only supports - # single-GPU tests. Guitar skips "manual" tests unless they're also tagged "guitar". + # single-GPU tests. Guitar skips "manual" tests unless they're also tagged "guitar". "gpu": [ "guitar", "manual", @@ -2349,9 +2349,9 @@ xla_test( "//xla/service:executable", "//xla/service:hlo_module_config", "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", - "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/tests/collective_ops_test.cc b/xla/tests/collective_ops_test.cc index 924a4b8c1775cc..0d8c3062bf2238 100644 --- a/xla/tests/collective_ops_test.cc +++ b/xla/tests/collective_ops_test.cc @@ -315,7 +315,7 @@ XLA_TEST_F(CollectiveOpsTest, AllReduceAnd_Pred) { auto module = ParseAndReturnVerifiedModule(hlo_module, config).value(); TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, + ExecuteReplicated(std::move(module), absl::Span{}, /*num_replicas=*/2, /*use_threads=*/true, /*run_hlo_passes=*/true)); for (int replica_idx = 0; replica_idx < 2; replica_idx++) { @@ -356,7 +356,7 @@ XLA_TEST_F(CollectiveOpsTest, AllReduceOr_Pred) { auto module = ParseAndReturnVerifiedModule(hlo_module, config).value(); TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, + ExecuteReplicated(std::move(module), absl::Span{}, /*num_replicas=*/2, /*use_threads=*/true, /*run_hlo_passes=*/true)); for (int replica_idx = 0; replica_idx < 2; replica_idx++) { @@ -549,7 +549,8 @@ XLA_TEST_F(CollectiveOpsTest, AllReduce_Degenerate) { ParseAndReturnVerifiedModule(kModuleStr, config)); TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, /*num_replicas=*/kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + /*num_replicas=*/kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/true)); ASSERT_EQ(results.size(), kNumReplicas); @@ -581,7 +582,8 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AsyncAllReduce)) { ParseAndReturnVerifiedModule(kModuleStr, config)); TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, num_devices_, + ExecuteReplicated(std::move(module), absl::Span{}, + num_devices_, /*use_threads=*/true, /*run_hlo_passes=*/false)); ASSERT_EQ(results.size(), num_devices_); @@ -616,7 +618,8 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AsyncAllReduceTwoOperands)) { ParseAndReturnVerifiedModule(kModuleStr, config)); TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, num_devices_, + ExecuteReplicated(std::move(module), absl::Span{}, + num_devices_, /*use_threads=*/true, /*run_hlo_passes=*/false)); ASSERT_EQ(results.size(), num_devices_); @@ -648,7 +651,8 @@ XLA_TEST_F(CollectiveOpsTest, ReplicaId) { TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, num_devices_, + ExecuteReplicated(std::move(module), absl::Span{}, + num_devices_, /*use_threads=*/true, /*run_hlo_passes=*/true)); ASSERT_EQ(results.size(), num_devices_); @@ -683,9 +687,11 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectiveBroadcast_Simple)) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModuleStr, config)); - TF_ASSERT_OK_AND_ASSIGN(std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, - /*use_threads=*/true)); + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, + /*use_threads=*/true)); ASSERT_EQ(results.size(), kNumReplicas); EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({11, 11}), results[0])); @@ -719,7 +725,8 @@ XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Simple) { TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/true)); ASSERT_EQ(results.size(), kNumReplicas); EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({11, 11}), @@ -755,7 +762,8 @@ XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Degenerate) { TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/true)); ASSERT_EQ(results.size(), kNumReplicas); EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({10, 10}), @@ -790,7 +798,8 @@ XLA_TEST_F(CollectiveOpsTest, CollectivePermute_NotDegenerate) { TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/true)); ASSERT_EQ(results.size(), kNumReplicas); EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({10, 10}), @@ -826,7 +835,8 @@ XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Rotate) { TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/true)); ASSERT_EQ(results.size(), kNumReplicas); EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({13, 13}), @@ -863,7 +873,8 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AsyncCollectivePermute)) { TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/false)); ASSERT_EQ(results.size(), kNumReplicas); EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({11, 11}), @@ -904,7 +915,8 @@ XLA_TEST_F(CollectiveOpsTest, AllToAll_EmptyReplicaGroups) { TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/true)); ASSERT_EQ(results.size(), kNumReplicas); LiteralTestUtil::ExpectR1Equal({10, 15, 11, 16, 12, 17, 13, 18}, @@ -949,7 +961,8 @@ XLA_TEST_F(CollectiveOpsTest, AllToAll_OrderedReplicaGroups) { TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/true)); ASSERT_EQ(results.size(), kNumReplicas); LiteralTestUtil::ExpectR1Equal({43, 48, 42, 47, 41, 46, 40, 45}, @@ -988,7 +1001,8 @@ XLA_TEST_F(CollectiveOpsTest, AllToAll_TwoReplicaGroups) { TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/true)); ASSERT_EQ(results.size(), kNumReplicas); LiteralTestUtil::ExpectR1Equal({23, 28, 20, 25}, results[0]); @@ -1019,7 +1033,8 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_SplitDimension)) { TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/true)); ASSERT_EQ(results.size(), kNumReplicas); LiteralTestUtil::ExpectR1Equal({10, 15, 11, 16, 12, 17, 13, 18}, @@ -1052,7 +1067,8 @@ XLA_TEST_F(CollectiveOpsTest, AllGather_Dim0) { TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/true)); ASSERT_EQ(results.size(), kNumReplicas); for (const Literal& result : results) { @@ -1080,7 +1096,8 @@ XLA_TEST_F(CollectiveOpsTest, AllGather_Dim0_UseGlobalDevices) { TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/true)); ASSERT_EQ(results.size(), kNumReplicas); for (const Literal& result : results) { @@ -1108,7 +1125,8 @@ XLA_TEST_F(CollectiveOpsTest, AllGather_Dim1) { TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/true)); ASSERT_EQ(results.size(), kNumReplicas); for (const Literal& result : results) { @@ -1192,7 +1210,8 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllGatherMixedTypes)) { TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/true)); for (int replica_idx = 0; replica_idx < kNumReplicas; replica_idx++) { auto rs = results[replica_idx].DecomposeTuple(); @@ -1233,7 +1252,8 @@ XLA_TEST_F(CollectiveOpsTest, ReduceScatter) { TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/true)); LiteralTestUtil::ExpectR1Equal({11, 13, 15, 17}, results[0]); LiteralTestUtil::ExpectR1Equal({19, 21, 23, 25}, results[1]); @@ -1306,7 +1326,8 @@ XLA_TEST_F(CollectiveOpsTest, ReduceScatter_Dim1) { TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/true)); LiteralTestUtil::ExpectR1Equal({11, 13, 19, 21}, results[0]); LiteralTestUtil::ExpectR1Equal({15, 17, 23, 25}, results[1]); @@ -1348,7 +1369,8 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(ReduceScatterReassociate)) { TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/true)); const ErrorSpec es{1e-5, 1e-5}; @@ -1398,7 +1420,8 @@ XLA_TEST_F(CollectiveOpsTest, TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/true)); const ErrorSpec es{1e-5, 1e-5}; @@ -1442,7 +1465,8 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllReduceReassociate)) { TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/true)); const ErrorSpec es{1e-5, 1e-5}; @@ -1478,7 +1502,8 @@ XLA_TEST_F(CollectiveOpsTest, TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/true)); EXPECT_TRUE(LiteralTestUtil::Equal(results[0], results[1])); @@ -1520,7 +1545,8 @@ XLA_TEST_F(CollectiveOpsTest, TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/true)); EXPECT_TRUE(LiteralTestUtil::Equal(results[0], results[1])); LiteralTestUtil::ExpectR3Equal({{{1, 2, 3}, @@ -1563,7 +1589,8 @@ XLA_TEST_F(CollectiveOpsTest, AllGather_16BitInt) { TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/true)); ASSERT_EQ(results.size(), kNumReplicas); for (const Literal& result : results) { @@ -1591,7 +1618,8 @@ XLA_TEST_F(CollectiveOpsTest, AllToAll_16BitInt) { TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/true)); ASSERT_EQ(results.size(), kNumReplicas); LiteralTestUtil::ExpectR1Equal({10, 11}, results[0]); @@ -1618,7 +1646,8 @@ XLA_TEST_F(CollectiveOpsTest, CollectivePermute_16BitInt) { TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/true)); ASSERT_EQ(results.size(), kNumReplicas); LiteralTestUtil::ExpectR1Equal({11, 16}, results[0]); @@ -1652,7 +1681,8 @@ XLA_TEST_F(CollectiveOpsTest, AllReduce_16BitInt) { TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/true)); ASSERT_EQ(results.size(), kNumReplicas); for (const Literal& result : results) { @@ -1687,7 +1717,8 @@ XLA_TEST_F(CollectiveOpsTest, ReduceScatter_16BitInt) { TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/true)); ASSERT_EQ(results.size(), kNumReplicas); LiteralTestUtil::ExpectR1Equal({21}, results[0]); @@ -1721,7 +1752,8 @@ XLA_TEST_F(CollectiveOpsTest, AllReduceBFloat16Min) { TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/true)); ASSERT_EQ(results.size(), kNumReplicas); const bfloat16 one = static_cast(1.0f); @@ -1747,7 +1779,8 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllGather_8BitFloat)) { ParseAndReturnVerifiedModule(kModuleStr, config)); TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/true)); ASSERT_EQ(results.size(), kNumReplicas); for (const Literal& result : results) { @@ -1771,7 +1804,8 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_8BitFloat)) { ParseAndReturnVerifiedModule(kModuleStr, config)); TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/true)); ASSERT_EQ(results.size(), kNumReplicas); LiteralTestUtil::ExpectR1Equal({1, 1}, results[0]); @@ -1794,7 +1828,8 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectivePermute_8BitFloat)) { ParseAndReturnVerifiedModule(kModuleStr, config)); TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/true)); ASSERT_EQ(results.size(), kNumReplicas); LiteralTestUtil::ExpectR1Equal({1, 2}, results[0]); @@ -1822,7 +1857,8 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AsyncAllGather)) { TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/false)); ASSERT_EQ(results.size(), kNumReplicas); for (const Literal& result : results) { @@ -1869,7 +1905,8 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AsyncReduceScatter)) { TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/false)); LiteralTestUtil::ExpectR1Equal({11, 13, 15, 17}, results[0]); LiteralTestUtil::ExpectR1Equal({19, 21, 23, 25}, results[1]); @@ -1901,7 +1938,8 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AsyncAllToAll)) { TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/false)); ASSERT_EQ(results.size(), kNumReplicas); LiteralTestUtil::ExpectR1Equal({10, 11}, results[0]); @@ -1931,7 +1969,8 @@ XLA_TEST_F(CollectiveOpsTest, AllGather_Dim1UnitDimensions) { TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/true)); ASSERT_EQ(results.size(), kNumReplicas); for (const Literal& result : results) { @@ -1971,9 +2010,11 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(SendRecv_Simple)) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModuleStr, config)); - TF_ASSERT_OK_AND_ASSIGN(std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, - /*use_threads=*/true)); + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, + /*use_threads=*/true)); ASSERT_EQ(results.size(), kNumReplicas); EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({11, 11}), results[0])); @@ -2051,7 +2092,8 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(SendRecv_TwoConcurrentChains)) { TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/false)); ASSERT_EQ(results.size(), kNumReplicas); EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({3, 3}), @@ -2129,7 +2171,8 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(SendRecv_ValidationAttr1)) { TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/false)); ASSERT_EQ(results.size(), kNumReplicas); // Skip checking the result for device 0 as it has garabage value as the @@ -2229,7 +2272,8 @@ body { TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated(std::move(module), {}, kNumReplicas, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, /*run_hlo_passes=*/false)); ASSERT_EQ(results.size(), kNumReplicas); EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({2, 2}), diff --git a/xla/tests/collective_pipeline_parallelism_test.cc b/xla/tests/collective_pipeline_parallelism_test.cc index c5fe51401e8945..3d828edefba347 100644 --- a/xla/tests/collective_pipeline_parallelism_test.cc +++ b/xla/tests/collective_pipeline_parallelism_test.cc @@ -20,17 +20,15 @@ limitations under the License. #include #include "absl/log/log.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_module.h" #include "xla/literal.h" #include "xla/literal_util.h" -#include "xla/service/executable.h" #include "xla/service/hlo_module_config.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/tests/verified_hlo_module.h" -#include "tsl/platform/statusor.h" // Tests cross-GPU operations. // @@ -98,29 +96,25 @@ XLA_TEST_F(CollectivePipelineParallelismTest, TF_ASSERT_OK_AND_ASSIGN(module, ParseAndReturnVerifiedModule(kModuleStr, config)); - // Input for replica i is - // {{i, i}, - // {i, i}}. - std::vector replica_inputs; - for (float i = 1; i < kNumReplicas + 1; ++i) { - replica_inputs.push_back({LiteralUtil::CreateR2({{i, i}, {i, i}})}); - replica_inputs.push_back(LiteralUtil::CreateR2({{0, 0}, {0, 1}})); + // Inputs for replica i are + // A = {{i+1, i+1}, + // {i+1, i+1}}, and + // B = {{0, 0}, + // {0, 1}}. + std::vector inputs_a; + for (int64_t i = 0; i < kNumReplicas; ++i) { + float val = i + 1; + inputs_a.push_back(LiteralUtil::CreateR2({{val, val}, {val, val}})); + } + Literal input_b_replicated = LiteralUtil::CreateR2({{0, 0}, {0, 1}}); + std::vector> inputs; + for (int64_t i = 0; i < kNumReplicas; ++i) { + inputs.push_back({&inputs_a[i], &input_b_replicated}); } - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr executable, - test_runner_.CreateExecutable( - std::unique_ptr(std::move(module)), - /*run_hlo_passes=*/true)); TF_ASSERT_OK_AND_ASSIGN( std::vector results, - ExecuteReplicated( - /*executable_provider=*/[&](int64_t) { return executable.get(); }, - /*argument_count_provider=*/[](int64_t) { return 2; }, - /*argument_provider=*/ - [&](int64_t replica, int64_t index) -> const Literal* { - return &replica_inputs[replica * 2 + index]; - }, - kNumReplicas, /*run_hlo_passes=*/true, - /*device_assignment=*/nullptr)); + ExecuteReplicated(std::move(module), inputs, kNumReplicas, + /*run_hlo_passes=*/true)); LiteralTestUtil::ExpectR2Equal({{0, 0}, {2, 2}}, results[0]); LiteralTestUtil::ExpectR2Equal({{0, 0}, {3, 3}}, results[1]); LiteralTestUtil::ExpectR2Equal({{0, 0}, {4, 4}}, results[2]); diff --git a/xla/tests/hlo_test_base.cc b/xla/tests/hlo_test_base.cc index 560a486a378987..e7367e75a760b9 100644 --- a/xla/tests/hlo_test_base.cc +++ b/xla/tests/hlo_test_base.cc @@ -411,6 +411,29 @@ absl::StatusOr> HloTestBase::ExecuteReplicated( options, device_assignment); } +absl::StatusOr> HloTestBase::ExecuteReplicated( + std::unique_ptr module, + std::vector> arguments, int64_t num_replicas, + bool run_hlo_passes) { + CHECK(num_replicas > 0 && "expect at least one replica"); + CHECK(num_replicas == arguments.size() && + "expect arguments for each replica"); + int64_t argument_count = arguments.front().size(); + TF_ASSIGN_OR_RETURN( + std::unique_ptr executable, + runner_->CreateExecutable(std::unique_ptr(std::move(module)), + run_hlo_passes)); + return ExecuteReplicated( + /*executable_provider=*/[&](int64_t) { return executable.get(); }, + /*argument_count_provider=*/[&](int64_t) { return argument_count; }, + /*argument_provider=*/ + [&](int64_t replica_idx, int64_t argument_idx) -> const Literal* { + return arguments[replica_idx][argument_idx]; + }, + num_replicas, /*run_hlo_passes=*/run_hlo_passes, + /*device_assignment=*/nullptr); +} + absl::StatusOr> HloTestBase::MakeReferenceModule( const HloModule& test_module, const std::function& reference_preprocessor) { diff --git a/xla/tests/hlo_test_base.h b/xla/tests/hlo_test_base.h index 9e90eac54cb576..9858ed6f53997d 100644 --- a/xla/tests/hlo_test_base.h +++ b/xla/tests/hlo_test_base.h @@ -246,6 +246,13 @@ class HloTestBase : public ManifestCheckingTest { int64_t num_replicas, bool run_hlo_passes, DeviceAssignment* device_assignment = nullptr); + // Convenience function for above. Allows passing different inputs to + // different replicas of the same program. + absl::StatusOr> ExecuteReplicated( + std::unique_ptr module, + std::vector> arguments, int64_t num_replicas, + bool run_hlo_passes); + // Executes the given hlo module on two backends and compares results. // // 'arguments': the input of the hlo module. From 45dca1a0a1d87f3d3c93fa4175e1df971acddb10 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 25 Jul 2024 10:43:44 -0700 Subject: [PATCH 163/376] [xla:cpu] Correctly resolve device ordinal from parent stream For consistency with current XLA:CPU, always use parent stream to resolve device ordinal. PiperOrigin-RevId: 656007373 --- xla/service/cpu/BUILD | 1 + xla/service/cpu/cpu_executable.cc | 3 ++- xla/service/cpu/cpu_instruction_fusion_test.cc | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/xla/service/cpu/BUILD b/xla/service/cpu/BUILD index 4cbc0e9e5439fe..eef125f44cfa5f 100644 --- a/xla/service/cpu/BUILD +++ b/xla/service/cpu/BUILD @@ -1377,6 +1377,7 @@ xla_cc_test( deps = [ ":cpu_instruction_fusion", "//xla:shape_util", + "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", "//xla/service:transpose_folding", "//xla/tests:hlo_test_base", diff --git a/xla/service/cpu/cpu_executable.cc b/xla/service/cpu/cpu_executable.cc index b1a8685a84d72a..ee843dc147eb0b 100644 --- a/xla/service/cpu/cpu_executable.cc +++ b/xla/service/cpu/cpu_executable.cc @@ -387,7 +387,8 @@ absl::Status CpuExecutable::ExecuteThunks( Thunk::ExecuteParams execute_params = { &*function_registry_, &allocations, - runtime::GetXfeedManager(run_options->device_ordinal()), + runtime::GetXfeedManager( + run_options->stream()->parent()->device_ordinal()), run_options->intra_op_thread_pool(), &task_runner, &collective_execute_params, diff --git a/xla/service/cpu/cpu_instruction_fusion_test.cc b/xla/service/cpu/cpu_instruction_fusion_test.cc index de8edeb1096759..5db0bebaaa9a2e 100644 --- a/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/service/transpose_folding.h" #include "xla/shape.h" From a5f6630e8f701352872d73db8264e0540047c910 Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Thu, 25 Jul 2024 10:50:16 -0700 Subject: [PATCH 164/376] [JAX] Do not skip array copies whenever source or destination memory is default If `jax.device_put()` copies arrays from/to non-default memory, we should not skip the array copy just because destination/source memory (respectively) is a default memory. By canonicalizing memory kinds when doing this filtering, we can skip array copies only when it is fine. PiperOrigin-RevId: 656009820 --- xla/python/py_array.cc | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/xla/python/py_array.cc b/xla/python/py_array.cc index 4beebf05f59304..8d00206e38ee30 100644 --- a/xla/python/py_array.cc +++ b/xla/python/py_array.cc @@ -1063,14 +1063,12 @@ absl::StatusOr> PyArray::BatchedCopyToDeviceWithSharding( const ifrt::DeviceList& src_devices = ifrt_array_ptr->sharding().devices(); const ifrt::DeviceList& dst_devices = dst_device_lists[i]; - ifrt::MemoryKind src_memory_kind = ifrt_array_ptr->sharding().memory_kind(); - ifrt::MemoryKind dst_memory_kind = - CreateIfRtMemoryKindFromSharding(dst_sharding); - - if (src_devices == dst_devices && - (!dst_memory_kind.memory_kind().has_value() || - !src_memory_kind.memory_kind().has_value() || - src_memory_kind == dst_memory_kind)) { + ifrt::MemoryKind src_memory_kind = ifrt::CanonicalizeMemoryKind( + ifrt_array_ptr->sharding().memory_kind(), src_devices.front()); + ifrt::MemoryKind dst_memory_kind = ifrt::CanonicalizeMemoryKind( + CreateIfRtMemoryKindFromSharding(dst_sharding), dst_devices.front()); + + if (src_devices == dst_devices && src_memory_kind == dst_memory_kind) { results[i] = py_arrays[i]; continue; } From c59875f4ff6ea10a74c58284714cc2f85815d0b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Bana=C5=9B?= Date: Thu, 25 Jul 2024 12:38:40 -0700 Subject: [PATCH 165/376] [XLA:CPU] Fix LLVM compiler test for thunks runtime. PiperOrigin-RevId: 656047728 --- xla/tests/BUILD | 1 + xla/tests/llvm_compiler_test.cc | 12 +++++++++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/xla/tests/BUILD b/xla/tests/BUILD index 81f107a3c172d6..5124ca126d286b 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -2647,6 +2647,7 @@ xla_test( "cpu", "gpu", ], + tags = ["test_xla_cpu_thunks"], deps = [ ":hlo_test_base", "//xla:literal_util", diff --git a/xla/tests/llvm_compiler_test.cc b/xla/tests/llvm_compiler_test.cc index bf0e52d59f55f3..94e37c64664948 100644 --- a/xla/tests/llvm_compiler_test.cc +++ b/xla/tests/llvm_compiler_test.cc @@ -41,10 +41,12 @@ namespace { using LLVMCompilerTest = HloTestBase; const char* const kHloText = R"( -HloModule Constant +HloModule Add ENTRY main { - ROOT constant = f32[] constant(42.0) + constant.0 = f32[] constant(42.0) + constant.1 = f32[] constant(43.0) + ROOT add.0 = f32[] add(constant.0, constant.1) } )"; @@ -61,8 +63,12 @@ TEST_F(LLVMCompilerTest, HooksTest) { return absl::OkStatus(); }; - // Create HLO module, and run the compiler. + // Create HLO module. Note this module needs to consist of at least one + // instruction that is compiled using LLVM (e.g. for CPU thunks runtime it is + // 'add' instruction), otherwise the hooks are never called. auto hlo_module = ParseAndReturnVerifiedModule(kHloText).value(); + + // Create and run the compiler. LLVMCompiler* compiler = tensorflow::down_cast(backend().compiler()); compiler->SetPreOptimizationHook(pre_opt_hook); From 29c2dbe5f4d237d5ae9c18f61d9701d80e1f4946 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Thu, 25 Jul 2024 12:47:29 -0700 Subject: [PATCH 166/376] Reenable clang_format.yml This became possible after https://github.com/openxla/xla/commit/8ed81337532dec2f4ee3e63dc44d327980763449 removed comments which previously confused clang-format. FORCE_TEST_ACTIONS PiperOrigin-RevId: 656050329 --- .github/workflows/clang_format.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/clang_format.yml b/.github/workflows/clang_format.yml index 0bb143c3c9fc1f..2701311d047371 100644 --- a/.github/workflows/clang_format.yml +++ b/.github/workflows/clang_format.yml @@ -26,6 +26,7 @@ jobs: shell: bash timeout-minutes: 1 if: | + github.event.sender.type == 'User' || contains(github.event.pull_request.body, 'FORCE_TEST_ACTIONS') steps: - name: "Checking out repository" From 6b0495fc43a1007b24590d6f0609d4440acd7e7f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 25 Jul 2024 13:19:38 -0700 Subject: [PATCH 167/376] [XLA:MSA] Implement a simulator to predict the HLO module execution time by taking asynchronous copies into account. The off-the-shelf runtime predictor does not include the overhead of asynchronous copies. To provide a more accurate estimation, we implement another simulator which includes the overhead of asynchronous copies for runtime prediction. This function simulates the default memory system to measure the overhead of asynchronous copies. Here is the overview of the new simulator: ```c func SimulateElapsedTime(): read_queue = [] write_queue = [] total_elapsed = 0 for each instruction: if instruction is copy-start: if instruction is default to alternate memory: read_queue.push(instruction) else if instruction is alternate to default memory: write_queue.push(instruction) else if instruction is copy-done: # pop instruction from the read/write queue and calculate the # execution time of the async copy total_elapsed += SimulateAsyncCopyDone() else if instruction is compute: # Same as the cost analysis without async-overhead. # Except we also calculate the default idle window, # and process outstanding async copy instructions in the window. total_elapsed += SimulateComputeInstruction() end if end for return total_elapsed ``` PiperOrigin-RevId: 656064401 --- xla/service/memory_space_assignment/BUILD | 1 + .../memory_space_assignment.cc | 19 ++- .../memory_space_assignment_test.cc | 4 +- .../memory_space_assignment/simulator.cc | 130 +++++++++++++--- .../memory_space_assignment/simulator.h | 39 ++++- .../memory_space_assignment/simulator_test.cc | 146 +++++++++++++++--- 6 files changed, 287 insertions(+), 52 deletions(-) diff --git a/xla/service/memory_space_assignment/BUILD b/xla/service/memory_space_assignment/BUILD index cd7a0c0f1163af..56bf642315feca 100644 --- a/xla/service/memory_space_assignment/BUILD +++ b/xla/service/memory_space_assignment/BUILD @@ -331,6 +331,7 @@ cc_library( "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_live_range", + "//xla/service:hlo_alias_analysis", "//xla/service:hlo_value", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", diff --git a/xla/service/memory_space_assignment/memory_space_assignment.cc b/xla/service/memory_space_assignment/memory_space_assignment.cc index 62d434983bd65a..f2781c6c0f9a41 100644 --- a/xla/service/memory_space_assignment/memory_space_assignment.cc +++ b/xla/service/memory_space_assignment/memory_space_assignment.cc @@ -353,12 +353,15 @@ MemorySpaceAssignment::RunMemorySpaceAssignment( const HloAliasAnalysis& alias_analysis) { TF_RETURN_IF_ERROR(FindAllocationSequence(hlo_live_range, alias_analysis)); + std::optional runtime_simulator = std::nullopt; if (options_.cost_analysis) { - RuntimeSimulator runtime_simulator(options_.cost_analysis, - options_.alternate_memory_space); - float estimated_time = runtime_simulator.ComputeEstimatedElapsedTime( - hlo_live_range, allocations_); - VLOG(1) << "Estimated elapsed time (sec): " << estimated_time; + runtime_simulator.emplace(options_.cost_analysis, + options_.alternate_memory_space); + float estimated_time = + runtime_simulator->SimulateElapsedTimeWithoutAsyncCopies(hlo_live_range, + allocations_); + VLOG(1) << "Estimated elapsed time without async copies (sec): " + << estimated_time; } TF_RETURN_IF_ERROR(Process(hlo_live_range)); @@ -366,6 +369,12 @@ MemorySpaceAssignment::RunMemorySpaceAssignment( TF_RETURN_IF_ERROR(SimplifyGraph()); TF_RETURN_IF_ERROR(FixSchedule()); TF_RETURN_IF_ERROR(ExportAndColorBuffers()); + if (runtime_simulator.has_value()) { + float estimated_time = + runtime_simulator->SimulateElapsedTime(module_, allocations_); + VLOG(1) << "Estimated elapsed time with async copies (sec): " + << estimated_time; + } if (VLOG_IS_ON(3)) { LOG(INFO) << "Module after memory space assignment: "; diff --git a/xla/service/memory_space_assignment/memory_space_assignment_test.cc b/xla/service/memory_space_assignment/memory_space_assignment_test.cc index 3888dfb47fbe0c..7547901d1aa4a8 100644 --- a/xla/service/memory_space_assignment/memory_space_assignment_test.cc +++ b/xla/service/memory_space_assignment/memory_space_assignment_test.cc @@ -9925,7 +9925,9 @@ ENTRY main { options.max_size_in_bytes = 300; // Setup cost analysis so it takes 2 instructions to prefetch anything. - HloCostAnalysis hlo_cost_analysis(ShapeSize); + HloCostAnalysis::Properties properties; + properties[HloCostAnalysis::kBytesAccessedKey] = kBytesPerSecond; + HloCostAnalysis hlo_cost_analysis(ShapeSize, properties); CostAnalysisOptions cost_analysis_options; HloCostAnalysisCosts hlo_cost_analysis_costs(hlo_cost_analysis); TF_ASSERT_OK_AND_ASSIGN( diff --git a/xla/service/memory_space_assignment/simulator.cc b/xla/service/memory_space_assignment/simulator.cc index 7cd50834e16d29..b618061ec1c6ef 100644 --- a/xla/service/memory_space_assignment/simulator.cc +++ b/xla/service/memory_space_assignment/simulator.cc @@ -21,17 +21,18 @@ limitations under the License. #include #include #include -#include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/types/span.h" -#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/layout.h" +#include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_value.h" #include "xla/service/memory_space_assignment/allocation.h" #include "xla/shape_util.h" @@ -40,30 +41,30 @@ limitations under the License. namespace xla { namespace memory_space_assignment { -float RuntimeSimulator::ComputeEstimatedElapsedTime( - const HloLiveRange& hlo_live_range, const AllocationSequence& allocations) { - absl::flat_hash_map> - outputs_in_alternate_memory_map; - absl::flat_hash_map>> - operands_in_alternate_memory_map; - +void RuntimeSimulator::InitializeAlternateMemoryMap( + const AllocationSequence& allocations) { + outputs_in_alternate_memory_map_.clear(); + operands_in_alternate_memory_map_.clear(); for (auto& allocation : allocations) { if (!allocation->is_copy_allocation()) { if (allocation->memory_space() == MemorySpace::kAlternate) { const HloInstruction* defining_instruction = allocation->defining_position().instruction; - outputs_in_alternate_memory_map[defining_instruction].push_back( + outputs_in_alternate_memory_map_[defining_instruction].push_back( allocation->defining_position().index); } } for (auto& hlo_use : allocation->uses()) { const HloInstruction* use_instruction = hlo_use.instruction; - operands_in_alternate_memory_map[use_instruction].push_back( + operands_in_alternate_memory_map_[use_instruction].push_back( std::make_pair(hlo_use.operand_number, hlo_use.operand_index)); } } +} +float RuntimeSimulator::SimulateElapsedTimeWithoutAsyncCopies( + const HloLiveRange& hlo_live_range, const AllocationSequence& allocations) { + InitializeAlternateMemoryMap(allocations); const auto& instruction_sequence = hlo_live_range.flattened_instruction_sequence().instructions(); float total_elapsed = 0.0; @@ -71,15 +72,18 @@ float RuntimeSimulator::ComputeEstimatedElapsedTime( if (instruction->opcode() == HloOpcode::kWhile) { continue; } - std::vector outputs_in_alternate_memory; - auto output_it = outputs_in_alternate_memory_map.find(instruction); - if (output_it != outputs_in_alternate_memory_map.end()) { - outputs_in_alternate_memory = output_it->second; + + absl::Span outputs_in_alternate_memory; + auto output_it = outputs_in_alternate_memory_map_.find(instruction); + if (output_it != outputs_in_alternate_memory_map_.end()) { + outputs_in_alternate_memory = absl::MakeSpan(output_it->second); } - std::vector> operands_in_alternate_memory; - auto operand_it = operands_in_alternate_memory_map.find(instruction); - if (operand_it != operands_in_alternate_memory_map.end()) { - operands_in_alternate_memory = operand_it->second; + + absl::Span> + operands_in_alternate_memory; + auto operand_it = operands_in_alternate_memory_map_.find(instruction); + if (operand_it != operands_in_alternate_memory_map_.end()) { + operands_in_alternate_memory = absl::MakeSpan(operand_it->second); } float instruction_elapsed_per_invoke = @@ -269,5 +273,91 @@ void RuntimeSimulator::ProcessAsyncCopiesInIdleTime(float time) { } } +float RuntimeSimulator::SimulateElapsedTime( + const HloModule* hlo_module, const AllocationSequence& allocations) { + InitializeAlternateMemoryMap(allocations); + + std::unique_ptr alias_analysis = + HloAliasAnalysis::Run(hlo_module).value(); + std::unique_ptr hlo_live_range = + HloLiveRange::Run(hlo_module->schedule(), *alias_analysis, + hlo_module->entry_computation()) + .value(); + + // Cannot provide a valid result if the bandwidth is invalid. + CHECK_GT(cost_analysis_->base_costs().BytesPerSecond(), 0.0); + + float total_elapsed = 0.0; + + const auto& instruction_sequence = + hlo_live_range->flattened_instruction_sequence().instructions(); + for (const HloInstruction* instruction : instruction_sequence) { + float inst_elapsed = 0.0; + if (instruction->opcode() == HloOpcode::kWhile) { + // Since the instructions in the while body are calculated + // separately, we can skip the while instruction. + continue; + } + if (instruction->parent()->IsAsyncComputation()) { + // We assume the overhead of async computations can be hidden perfectly. + // We plan to integrate the async copy overhead analysis later + // (b/351913186). + continue; + } + if (instruction->opcode() == HloOpcode::kCopyStart) { + // Try to categorize the async copy instruction into + // read-from-default and write-to-default queues. + MemoryTransferDirection direction = + GetAsyncCopyDirection(instruction, alternate_memory_space_); + const Shape& transfer_shape = instruction->operand(0)->shape(); + float transfer_bytes = static_cast( + cost_analysis_->base_costs().GetShapeSize(transfer_shape)); + if (direction == MemoryTransferDirection::kDefaultToAlternate) { + outstanding_read_default_queue_.push_back( + OutstandingAsyncCopy{instruction, transfer_bytes}); + } else if (direction == MemoryTransferDirection::kAlternateToDefault) { + outstanding_write_default_queue_.push_back( + OutstandingAsyncCopy{instruction, transfer_bytes}); + } else { + // The copy does not involve default memory. + } + } else if (instruction->opcode() == HloOpcode::kCopyDone) { + inst_elapsed = SimulateAsyncCopyDone(instruction); + } else { + // This branch is for the compute instructions. + // TODO(b/351913186): Plan to add another branch to handle async + // copy instructions caused by slicing. + + absl::Span outputs_in_alternate_memory; + auto output_it = outputs_in_alternate_memory_map_.find(instruction); + if (output_it != outputs_in_alternate_memory_map_.end()) { + outputs_in_alternate_memory = absl::MakeSpan(output_it->second); + } + + absl::Span> + operands_in_alternate_memory; + auto operand_it = operands_in_alternate_memory_map_.find(instruction); + if (operand_it != operands_in_alternate_memory_map_.end()) + operands_in_alternate_memory = absl::MakeSpan(operand_it->second); + + inst_elapsed = + SimulateComputeInstruction(instruction, operands_in_alternate_memory, + outputs_in_alternate_memory); + } + if (inst_elapsed > 0.0) { + // The calculation assumes all instructions are executed independently. + // Thus, the execution time is the same for each invocation. This property + // is not hold for all cases. For example, if an async copies are + // outstanding before the loop, and there are other async copies inside + // the loop body. In this case, the first async copy in the first + // iteration will be slower than other iterations, since it needs to wait + // for the async copies issued before the loop. + float total_trip_count = cost_analysis_->CalculateNestTripCount( + instruction, &cost_analysis_cache_); + total_elapsed += inst_elapsed * total_trip_count; + } + } + return total_elapsed; +} } // namespace memory_space_assignment } // namespace xla diff --git a/xla/service/memory_space_assignment/simulator.h b/xla/service/memory_space_assignment/simulator.h index 900a2e0593c741..729b220760f85e 100644 --- a/xla/service/memory_space_assignment/simulator.h +++ b/xla/service/memory_space_assignment/simulator.h @@ -19,7 +19,9 @@ limitations under the License. #include #include #include +#include +#include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/utils/hlo_live_range.h" @@ -73,12 +75,23 @@ class RuntimeSimulator { ~RuntimeSimulator() = default; - // This function is used to predict the effectiveness of the memory space - // assignment solution. Specifically, it returns the estimated execution time - // (in seconds) of the HLO module for the given memory space assignment (i.e., - // ```allocations```). - float ComputeEstimatedElapsedTime(const HloLiveRange& hlo_live_range, - const AllocationSequence& allocations); + // This function provides a basic estimate without considering the overhead of + // async copies. + float SimulateElapsedTimeWithoutAsyncCopies( + const HloLiveRange& hlo_live_range, + const AllocationSequence& allocations); + + // Returns the time to simulate the hlo_live_range, when we account for the + // waiting time for async copies to finish. + // + // To simulate the overhead of async copies, we need to maintain two queues to + // track the outstanding memory access requests that read/write the default + // memory. When we simulate compute, we use any time there is spare bandwidth + // to simulate async memory accesses to default memory. If we get to an async + // copy done, we must wait until it finishes (potentially waiting for copies + // issued before it to finish. + float SimulateElapsedTime(const HloModule* hlo_module, + const AllocationSequence& allocations); // This is an auxiliary function for simulating the execution // time for executing a copy-done instruction. It returns the @@ -114,9 +127,14 @@ class RuntimeSimulator { absl::Span outputs_in_alternate_memory); private: + // This function parses the memory space assignment solution and initializes + // the maps that record, for each instruction, which outputs and operands are + // stored in alternate memory. These maps are used to estimate the runtime of + // the HLO module. + void InitializeAlternateMemoryMap(const AllocationSequence& allocations); const CostAnalysis* cost_analysis_; CostAnalysis::Cache cost_analysis_cache_; - + // Members used for memory model simulation // This function updates the queue by updating the front request with the // processed bytes. If the request is completed (no remaining bytes to // process), the function returns the instruction and pop it from the queue. @@ -129,10 +147,15 @@ class RuntimeSimulator { // outstanding_*_default_queues are non-empty, they share bandwidth. If one of // the queues is empty and the other is not, it gets the full bandwdith. void ProcessAsyncCopiesInIdleTime(float time); - // Members used for memory model simulation + int64_t alternate_memory_space_; std::list outstanding_read_default_queue_; std::list outstanding_write_default_queue_; + absl::flat_hash_map> + outputs_in_alternate_memory_map_; + absl::flat_hash_map>> + operands_in_alternate_memory_map_; }; } // namespace memory_space_assignment diff --git a/xla/service/memory_space_assignment/simulator_test.cc b/xla/service/memory_space_assignment/simulator_test.cc index ad588538d7bf27..db6025f05ebf1f 100644 --- a/xla/service/memory_space_assignment/simulator_test.cc +++ b/xla/service/memory_space_assignment/simulator_test.cc @@ -79,6 +79,11 @@ class MemorySpaceAssignmentSimulatorTest : public HloTestBase { TF_ASSIGN_OR_RETURN( cost_analysis_, CostAnalysis::Create(*hlo_cost_analysis_costs_, _options, *module_)); + + TF_ASSIGN_OR_RETURN(alias_analysis_, HloAliasAnalysis::Run(module_.get())); + TF_ASSIGN_OR_RETURN(hlo_live_range_, + HloLiveRange::Run(module_->schedule(), *alias_analysis_, + module_->entry_computation())); runtime_simulator_ = std::make_unique( cost_analysis_.get(), kAlternateMemorySpace); return absl::OkStatus(); @@ -87,11 +92,14 @@ class MemorySpaceAssignmentSimulatorTest : public HloTestBase { std::unique_ptr hlo_cost_analysis_costs_; std::unique_ptr cost_analysis_; + std::unique_ptr alias_analysis_; + std::unique_ptr hlo_live_range_; + memory_space_assignment::AllocationSequence allocations_; std::unique_ptr runtime_simulator_; std::unique_ptr module_; }; -TEST_F(MemorySpaceAssignmentSimulatorTest, SingleLayerNestedLoop) { +TEST_F(MemorySpaceAssignmentSimulatorTest, SingleLayerLoop) { absl::string_view hlo_string = R"(HloModule module, is_scheduled=true @@ -118,25 +126,125 @@ TEST_F(MemorySpaceAssignmentSimulatorTest, SingleLayerNestedLoop) { )"; TF_ASSERT_OK(Initialize(hlo_string)); - TF_ASSERT_OK_AND_ASSIGN(auto alias_analysis, - HloAliasAnalysis::Run(module_.get())); - TF_ASSERT_OK_AND_ASSIGN( - auto hlo_live_range, - HloLiveRange::Run(module_->schedule(), *alias_analysis, - module_->entry_computation())); - - // Since the HLO does not contain memory access, pass an empty allocation - // sequence for test. - memory_space_assignment::AllocationSequence allocations; // The total elapsed time is the summation of the elapsed time of each // instruction. Here are the overhead of each instruction (secs): // %increment: 12 * 42 // tuple(%constant.0): 8 * 1 // %greater: 9 * 42 // %loop_result: 8 * 42 - EXPECT_EQ(runtime_simulator_->ComputeEstimatedElapsedTime(*hlo_live_range, - allocations), + EXPECT_EQ(runtime_simulator_->SimulateElapsedTimeWithoutAsyncCopies( + *hlo_live_range_, allocations_), 1226); + EXPECT_EQ( + runtime_simulator_->SimulateElapsedTime(module_.get(), allocations_), + 1226); +} + +TEST_F(MemorySpaceAssignmentSimulatorTest, NestedLayerLoop) { + absl::string_view hlo_string = + R"(HloModule module, is_scheduled=true + %inner.body { + %constant.1 = s32[] constant(1) + %param = (s32[]) parameter(0) + %count = s32[] get-tuple-element(%param), index=0 + %increment = s32[] add(s32[] %count, s32[] %constant.1) + ROOT %loop_result = (s32[]) tuple(%increment) + } + %inner.condition { + %param = (s32[]) parameter(0) + %constant.42 = s32[] constant(42) + %condition_input = s32[] get-tuple-element(%param), index=0 + ROOT %greater = pred[] compare(s32[] %constant.42, s32[] %condition_input), direction=GT + } + %outer.body { + %constant.0 = s32[] constant(0) + %constant.1 = s32[] constant(1) + %param = (s32[]) parameter(0) + %inner_while = (s32[]) while(tuple(%constant.0)), condition=%inner.condition, body=%inner.body + %count = s32[] get-tuple-element(%param), index=0 + %increment = s32[] add(s32[] %count, s32[] %constant.1) + ROOT %loop_result = (s32[]) tuple(%increment) + } + %outer.condition { + %param = (s32[]) parameter(0) + %constant.27 = s32[] constant(27) + %condition_input = s32[] get-tuple-element(%param), index=0 + ROOT %greater = pred[] compare(s32[] %constant.27, s32[] %condition_input), direction=GT + } + ENTRY Entry { + %constant.0 = s32[] constant(0) + ROOT %while_outer = (s32[]) while(tuple(%constant.0)), condition=%outer.condition, body=%outer.body + } + )"; + TF_ASSERT_OK(Initialize(hlo_string)); + // The inner loop is derived from the SingleLayerLoop test, whose overhead is + // 1226 seconds. + + // For the outer loop, the overhead of each instruction is: + // %increment: 12 * 27 + // tuple(%constant.0): 8 * 1 + // %greater: 9 * 27 + // %loop_result: 8 * 27 + // Thus, the total overhead of the while_outer is 1226 * 27 + 12 * 27 + 8 * 1 + // + 9 * 27 + 8 * 27 = 33893 + + EXPECT_EQ(runtime_simulator_->SimulateElapsedTimeWithoutAsyncCopies( + *hlo_live_range_, allocations_), + 33893); + EXPECT_EQ( + runtime_simulator_->SimulateElapsedTime(module_.get(), allocations_), + 33893); +} + +TEST_F(MemorySpaceAssignmentSimulatorTest, SingleAsyncCopyOverhead) { + absl::string_view hlo_string = + R"(HloModule module, is_scheduled=true + ENTRY Entry { + param_0 = f32[1,1,1024,2048] parameter(0) + copy-start.1 = (f32[1,1,1024,2048]{0,1,2,3:S(1)}, f32[1,1,1024,2048], u32[]) copy-start(param_0) + ROOT copy-done.1 = f32[1,1,1024,2048]{0,1,2,3:S(1)} copy-done(copy-start.1) + } + + )"; + TF_ASSERT_OK(Initialize(hlo_string)); + + // Since the HLO does not contain memory access, pass an empty allocation + // sequence for test. + memory_space_assignment::AllocationSequence allocations; + // The SimulateElapsedTimeWithoutAsyncCopies should not include the overhead + // of async copies. + EXPECT_EQ(runtime_simulator_->SimulateElapsedTimeWithoutAsyncCopies( + *hlo_live_range_, allocations_), + 0); + // The expected elapsed time is 1024 * 2048 * 4 / 1 = 8388608. + EXPECT_EQ( + runtime_simulator_->SimulateElapsedTime(module_.get(), allocations_), + 8388608); +} + +TEST_F(MemorySpaceAssignmentSimulatorTest, AsyncCopyWithComputationOverhead) { + absl::string_view hlo_string = + R"(HloModule module, is_scheduled=true + ENTRY Entry { + param_0 = f32[8] parameter(0) + param_1 = f32[2] parameter(1) + copy-start.1 = (f32[8]{0:S(1)}, f32[8], u32[]) copy-start(param_0) + neg_compute = f32[2] negate(param_1) + ROOT copy-done.1 = f32[8]{0:S(1)} copy-done(copy-start.1) + } + + )"; + TF_ASSERT_OK(Initialize(hlo_string)); + // The neg_compute read/write 16 bytes in total, thus, it requires 16 seconds + // for default memory access. Since it only requires 2 FLOPs computation which + // requires 2 seconds, it is a memory-bound instruction which does not have + // idle time to process async copies. + // Workflow: + // neg_compute: | 16 sec (memory-bound) | + // copy-done.1: | | read 32 bytes | + // time: | 16 sec | 32 sec | + EXPECT_EQ( + runtime_simulator_->SimulateElapsedTime(module_.get(), allocations_), 48); } class SimulateAsyncCopyDoneTest : public MemorySpaceAssignmentSimulatorTest { @@ -298,7 +406,8 @@ TEST_F(SimulateAsyncCopyDoneTest, AsyncCopyTransferPartialProcess) { EXPECT_THAT(runtime_simulator_->GetOutstandingWriteDefaultQueue(), IsEmpty()); } -TEST_F(SimulateAsyncCopyDoneTest, ProcessAsyncCopiesWithComputeInstruction) { +TEST_F(SimulateAsyncCopyDoneTest, + SimulateComputeInstructionWithSingleAsyncCopy) { absl::string_view hlo_string = R"(HloModule module, is_scheduled=true ENTRY Entry { @@ -329,7 +438,8 @@ TEST_F(SimulateAsyncCopyDoneTest, ProcessAsyncCopiesWithComputeInstruction) { EXPECT_THAT(runtime_simulator_->GetOutstandingWriteDefaultQueue(), IsEmpty()); } -TEST_F(SimulateAsyncCopyDoneTest, ProcessAsyncCopiesInTimeWithSharedBandwidth) { +TEST_F(SimulateAsyncCopyDoneTest, + SimulateComputeInstructionWithSharedBandwidth) { absl::string_view hlo_string = R"(HloModule module, is_scheduled=true ENTRY Entry { @@ -368,7 +478,7 @@ TEST_F(SimulateAsyncCopyDoneTest, ProcessAsyncCopiesInTimeWithSharedBandwidth) { copy_start_2_inst, 96}})); } -TEST_F(SimulateAsyncCopyDoneTest, ProcessAsyncCopiesInTimeWithFullBandwidth) { +TEST_F(SimulateAsyncCopyDoneTest, SimulateComputeInstructionWithFullBandwidth) { absl::string_view hlo_string = R"(HloModule module, is_scheduled=true ENTRY Entry { @@ -384,7 +494,7 @@ TEST_F(SimulateAsyncCopyDoneTest, ProcessAsyncCopiesInTimeWithFullBandwidth) { const HloInstruction* copy_start_1_inst = instruction_map_["copy-start.1"]; - // Same as the 'ProcessAsyncCopiesInTimeWithSharedBandwidth' test, there are + // Same as the 'SimulateComputeInstructionWithSharedBandwidth' test, there are // 64 secs idle time to process async copies. Since only the read queue is not // empty, we can use the full bandwidth and process 64 sec * 1 bytes/sec = 64 // bytes. @@ -400,7 +510,7 @@ TEST_F(SimulateAsyncCopyDoneTest, ProcessAsyncCopiesInTimeWithFullBandwidth) { EXPECT_THAT(runtime_simulator_->GetOutstandingWriteDefaultQueue(), IsEmpty()); } -TEST_F(SimulateAsyncCopyDoneTest, ProcessAsyncCopyInTimeWithEmptyQueues) { +TEST_F(SimulateAsyncCopyDoneTest, SimulateComputeInstructionWithEmptyQueues) { absl::string_view hlo_string = R"(HloModule module, is_scheduled=true ENTRY Entry { From fffade67c0730487a9836d220373bfc27ae9c2f5 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 25 Jul 2024 14:07:51 -0700 Subject: [PATCH 168/376] Delete mesh.Loop now that xmap has been deleted PiperOrigin-RevId: 656084608 --- xla/python/sharding.cc | 12 ++---------- xla/python/xla_client.py | 2 +- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/xla/python/sharding.cc b/xla/python/sharding.cc index b1910d7aad7b61..acbe324ac75f47 100644 --- a/xla/python/sharding.cc +++ b/xla/python/sharding.cc @@ -191,16 +191,8 @@ NamedSharding::NamedSharding(nb::object mesh, nb::object spec, CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_); nb::module_ si = nb::module_::import_("jax._src.sharding_impls"); - // TODO(parkers): Once jax always has preprocess_with_manual, we can - // remove the fallback. - nb::object preprocess_fn; - try { - preprocess_fn = si.attr("preprocess_with_manual"); - } catch (nb::python_error& e) { - parsed_pspec_ = si.attr("preprocess")(mesh_, spec_, parsed_pspec_); - return; - } - parsed_pspec_ = preprocess_fn(mesh_, spec_, parsed_pspec_, manual_axes_); + parsed_pspec_ = + si.attr("preprocess")(mesh_, spec_, parsed_pspec_, manual_axes_); } SingleDeviceSharding::SingleDeviceSharding(nb::object device, diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py index b0f0264162eb15..97f01bec9bb0d5 100644 --- a/xla/python/xla_client.py +++ b/xla/python/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 278 +_version = 279 # Version number for MLIR:Python components. mlir_api_version = 57 From 5656c30bb127d7304d78688ab839c57926d9ea7e Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Thu, 25 Jul 2024 14:14:46 -0700 Subject: [PATCH 169/376] [XLA:GPU] Add HLO-based test for naive implementation of pipeline parallelism - 4 devices - 4 microbatches - no circular repeat - no disabled collectives - no collective pipelining PiperOrigin-RevId: 656087514 --- xla/tests/BUILD | 4 + .../collective_pipeline_parallelism_test.cc | 172 +++++++++++++++++- 2 files changed, 173 insertions(+), 3 deletions(-) diff --git a/xla/tests/BUILD b/xla/tests/BUILD index 5124ca126d286b..2c225af4b8dffe 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -2342,16 +2342,20 @@ xla_test( ":test_utils", ":verified_hlo_module", ":xla_internal_test_main", + "//xla:error_spec", "//xla:literal", "//xla:literal_util", "//xla:shape_util", + "//xla:statusor", "//xla/hlo/ir:hlo", "//xla/service:executable", "//xla/service:hlo_module_config", "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", + "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/tests/collective_pipeline_parallelism_test.cc b/xla/tests/collective_pipeline_parallelism_test.cc index 3d828edefba347..abf0e4739fa3ca 100644 --- a/xla/tests/collective_pipeline_parallelism_test.cc +++ b/xla/tests/collective_pipeline_parallelism_test.cc @@ -22,13 +22,17 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/error_spec.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/hlo_module_config.h" +#include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/tests/verified_hlo_module.h" +#include "tsl/platform/statusor.h" // Tests cross-GPU operations. // @@ -71,8 +75,10 @@ XLA_TEST_F(CollectivePipelineParallelismTest, iter = u32[] get-tuple-element(param), index=0 data = f32[2,2] get-tuple-element(param), index=1 weights = f32[2,2] get-tuple-element(param), index=2 - matmul = f32[2,2] dot(weights, data), lhs_contracting_dims={1}, rhs_contracting_dims={0} - cp = f32[2,2] collective-permute(matmul), source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}} + matmul = f32[2,2] dot(weights, data), lhs_contracting_dims={1}, + rhs_contracting_dims={0} + cp = f32[2,2] collective-permute(matmul), + source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}} iter_increment = u32[] constant(1) next_iter = u32[] add(iter, iter_increment) ROOT result = (u32[], f32[2,2], f32[2,2]) tuple(next_iter, cp, weights) @@ -83,7 +89,8 @@ XLA_TEST_F(CollectivePipelineParallelismTest, data = f32[2,2] parameter(0) weights = f32[2,2] parameter(1) input = (u32[], f32[2,2], f32[2,2]) tuple(iter, data, weights) - while_res = (u32[], f32[2,2], f32[2,2]) while(input), condition=while_cond, body=while_body + while_res = (u32[], f32[2,2], f32[2,2]) while(input), condition=while_cond, + body=while_body ROOT data_out = f32[2,2] get-tuple-element(while_res), index=1 } )"; @@ -121,5 +128,164 @@ XLA_TEST_F(CollectivePipelineParallelismTest, LiteralTestUtil::ExpectR2Equal({{0, 0}, {1, 1}}, results[3]); } +// Naive implementation of pipeline parallelism: +// - 4 devices +// - 4 microbatches +// - no circular repeat +// - no disabled collectives +// - no collective pipelining +// +// Every stage of the pipeline is a single linear layer. +XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch4Replica4) { + const absl::string_view kModuleStr = R"( + HloModule test + + get_circ_buffer_index { + offset = u32[] parameter(0) + index = u32[] parameter(1) + size = u32[] parameter(2) + t0 = u32[] add(offset, index) + t1 = u32[] divide(t0, size) + t2 = u32[] multiply(t1, size) + ROOT t4 = u32[] subtract(t0, t2) + } + + is_input_replica { + replica_id = u32[] replica-id() + c0 = u32[] constant(0) + ROOT predicate = pred[] compare(replica_id, c0), direction=EQ + } + + is_output_replica { + replica_id = u32[] replica-id() + c1 = u32[] constant(1) + ROOT predicate = pred[] compare(replica_id, c1), direction=EQ + } + + while_condition { + tuple = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) parameter(0) + i = u32[] get-tuple-element(tuple), index=4 + n = u32[] constant(7) + ROOT predicate = pred[] compare(i, n), direction=LT + } + + while_body { + tuple = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) parameter(0) + weights = f32[16,16] get-tuple-element(tuple), index=0 + input = f32[4,16] get-tuple-element(tuple), index=1 + output = f32[4,16] get-tuple-element(tuple), index=2 + tmp = f32[16] get-tuple-element(tuple), index=3 + i = u32[] get-tuple-element(tuple), index=4 + + c1 = u32[] constant(1) + c0 = u32[] constant(0) + c4 = u32[] constant(4) + + input_idx = u32[] call(c0, i, c4), to_apply=get_circ_buffer_index + input_slice = f32[1,16] dynamic-slice(input, input_idx, c0), + dynamic_slice_sizes={1,16} + input_slice_ = f32[16] reshape(input_slice) + + prev_stage_slice = f32[16] collective-permute(tmp), + source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}} + + read_input = pred[] call(), to_apply=is_input_replica + compute_in = f32[16] select(read_input, input_slice_, prev_stage_slice) + + compute_out = f32[16] dot(weights, compute_in), lhs_contracting_dims={1}, + rhs_contracting_dims={0} + + output_index = u32[] call(c1, i, c4), to_apply=get_circ_buffer_index + output_slice = f32[1,16] reshape(compute_out) + output_ = f32[4,16] dynamic-update-slice(output, output_slice, output_index, + c0) + + i_ = add(i, c1) + + ROOT tuple1 = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) tuple( + weights, input, output_, compute_out, i_) + } + + ENTRY main { + weights = f32[16,16] parameter(0) + input = f32[4,16] parameter(1) + + cf0 = f32[] constant(0) + output = f32[4,16] broadcast(cf0), dimensions={} + tmp = f32[16] broadcast(cf0), dimensions={} + c0 = u32[] constant(0) + + tuple = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) tuple(weights, + input, output, tmp, c0) + tuple_ = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) while(tuple), + condition=while_condition, body=while_body + + ROOT output_ = f32[4,16] get-tuple-element(tuple_), index=2 + } + )"; + + const int64_t kNumReplicas = 4; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + + // This pipeline consists of 4 layers, each of which is a single linear layer. + // We assign the weights to the replicas such that the layers scale the input + // data by 1.0, 2.0, 3.0 and 4.0. The combined effect is to scale the input + // data by 24.0. + auto generate_scale_weights = [&](float factor) -> absl::StatusOr { + return LiteralUtil::CreateLiteralWithGenerator( + ShapeUtil::MakeShape(F32, {16, 16}), + [&](absl::Span idx) -> float { + return idx[0] == idx[1] ? factor : 0.0; + }); + }; + TF_ASSERT_OK_AND_ASSIGN(Literal weights_r0, generate_scale_weights(1.0)); + TF_ASSERT_OK_AND_ASSIGN(Literal weights_r1, generate_scale_weights(2.0)); + TF_ASSERT_OK_AND_ASSIGN(Literal weights_r2, generate_scale_weights(3.0)); + TF_ASSERT_OK_AND_ASSIGN(Literal weights_r3, generate_scale_weights(4.0)); + + // Only the first replica holds the input to the pipeline in this naive + // implementation. The remaining replicas get zero/dummy input. + auto generate_zero_input = [&]() -> absl::StatusOr { + return LiteralUtil::CreateLiteralWithGenerator( + ShapeUtil::MakeShape(F32, {4, 16}), + [&](absl::Span idx) -> float { return 0.0; }); + }; + auto generate_fingerprint_input = [&]() -> absl::StatusOr { + return LiteralUtil::CreateLiteralWithGenerator( + ShapeUtil::MakeShape(F32, {4, 16}), + [&](absl::Span idx) -> float { + return 1.0 * idx[0] + 0.0001 * idx[1]; + }); + }; + TF_ASSERT_OK_AND_ASSIGN(Literal real_input, generate_fingerprint_input()); + TF_ASSERT_OK_AND_ASSIGN(Literal fake_input, generate_zero_input()); + + std::vector> args = {{&weights_r0, &real_input}, + {&weights_r1, &fake_input}, + {&weights_r2, &fake_input}, + {&weights_r3, &fake_input}}; + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), args, kNumReplicas, + /*run_hlo_passes=*/true)); + + // Check pipeline output for last replica. + // The combined effect of the pipeline is to scale the input data by 24.0. + TF_ASSERT_OK_AND_ASSIGN( + Literal expected_output, + (LiteralUtil::CreateLiteralWithGenerator( + ShapeUtil::MakeShape(F32, {4, 16}), + [&](absl::Span multi_index) -> float { + return real_input.Get(multi_index) * 1.0 * 2.0 * 3.0 * 4.0; + }))); + EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected_output, results[3], + ErrorSpec{1e-5, 1e-5})); +} + } // namespace } // namespace xla From cbb8764b6dfbedc00e8d663afa2689748d8b8b2d Mon Sep 17 00:00:00 2001 From: Greg Olechwierowicz Date: Thu, 25 Jul 2024 14:22:44 -0700 Subject: [PATCH 170/376] [XLA:GPU] Classify PartitionId as a noop. PiperOrigin-RevId: 656090585 --- xla/service/gpu/gpu_latency_hiding_scheduler.cc | 3 ++- xla/service/gpu/gpu_latency_hiding_scheduler_test.cc | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/xla/service/gpu/gpu_latency_hiding_scheduler.cc b/xla/service/gpu/gpu_latency_hiding_scheduler.cc index fcd92f6799c9bd..51d52114f640b5 100644 --- a/xla/service/gpu/gpu_latency_hiding_scheduler.cc +++ b/xla/service/gpu/gpu_latency_hiding_scheduler.cc @@ -48,7 +48,8 @@ bool IsNopInstruction(const HloInstruction& hlo) { HloOpcode op = hlo.opcode(); return op == HloOpcode::kGetTupleElement || op == HloOpcode::kBitcast || op == HloOpcode::kConstant || op == HloOpcode::kParameter || - op == HloOpcode::kTuple || hlo.IsEffectiveBitcast(); + op == HloOpcode::kTuple || op == HloOpcode::kPartitionId || + op == HloOpcode::kReplicaId || hlo.IsEffectiveBitcast(); } bool IsAsyncComputeOp(const HloInstruction& hlo) { diff --git a/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc b/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc index 5e7f8a754e1aed..590adffffbd077 100644 --- a/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc +++ b/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc @@ -83,7 +83,9 @@ TEST_F(GpuLatencyHidingSchedulerBaseTest, parameter1 = f32[32] parameter(1) const0 = f32[] constant(42) bitcast0 = f32[2,16] bitcast(parameter1) - tuple0 = (f32[], f32[2,16]) tuple(parameter0, bitcast0) + partition-id0 = u32[] partition-id() + replica-id0 = u32[] replica-id() + tuple0 = (f32[], f32[2,16], u32[], u32[]) tuple(parameter0, bitcast0, partition-id0, replica-id0) ROOT _ = get-tuple-element(tuple0), index=0 } )"; From ec4af9081fe389561b159872dfb14a34f44a4dc3 Mon Sep 17 00:00:00 2001 From: Jorge Gorbe Moya Date: Thu, 25 Jul 2024 14:39:06 -0700 Subject: [PATCH 171/376] Integrate LLVM at llvm/llvm-project@58fb51492d96 Updates LLVM usage to match [58fb51492d96](https://github.com/llvm/llvm-project/commit/58fb51492d96) PiperOrigin-RevId: 656097446 --- third_party/llvm/generated.patch | 657 +----------------- third_party/llvm/workspace.bzl | 4 +- third_party/shardy/workspace.bzl | 4 +- .../triton/llvm_integration/series.bzl | 1 + .../tsl/third_party/llvm/generated.patch | 657 +----------------- .../tsl/third_party/llvm/workspace.bzl | 4 +- 6 files changed, 47 insertions(+), 1280 deletions(-) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index c7f7475c35588c..4eda7b241d21bc 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,638 +1,21 @@ Auto generated patch. Do not edit or delete it, even if empty. -diff -ruN --strip-trailing-cr a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst ---- a/clang/docs/ReleaseNotes.rst -+++ b/clang/docs/ReleaseNotes.rst -@@ -750,9 +750,6 @@ - - Clang now specifies that using ``auto`` in a lambda parameter is a C++14 extension when - appropriate. (`#46059: `_). - --- Clang now adds source file infomation for template instantiations as ``event["args"]["filename"]``. This -- added behind an option ``-ftime-trace-verbose``. This is expected to increase the size of trace by 2-3 times. -- - Improvements to Coverage Mapping - -------------------------------- - -diff -ruN --strip-trailing-cr a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td ---- a/clang/include/clang/Driver/Options.td -+++ b/clang/include/clang/Driver/Options.td -@@ -3998,10 +3998,6 @@ - HelpText<"Minimum time granularity (in microseconds) traced by time profiler">, - Visibility<[ClangOption, CC1Option, CLOption, DXCOption]>, - MarshallingInfoInt, "500u">; --def ftime_trace_verbose : Joined<["-"], "ftime-trace-verbose">, Group, -- HelpText<"Make time trace capture verbose event details (e.g. source filenames). This can increase the size of the output by 2-3 times">, -- Visibility<[ClangOption, CC1Option, CLOption, DXCOption]>, -- MarshallingInfoFlag>; - def ftime_trace_EQ : Joined<["-"], "ftime-trace=">, Group, - HelpText<"Similar to -ftime-trace. Specify the JSON file or a directory which will contain the JSON file">, - Visibility<[ClangOption, CC1Option, CLOption, DXCOption]>, -diff -ruN --strip-trailing-cr a/clang/include/clang/Frontend/FrontendOptions.h b/clang/include/clang/Frontend/FrontendOptions.h ---- a/clang/include/clang/Frontend/FrontendOptions.h -+++ b/clang/include/clang/Frontend/FrontendOptions.h -@@ -580,11 +580,6 @@ - /// Minimum time granularity (in microseconds) traced by time profiler. - unsigned TimeTraceGranularity; - -- /// Make time trace capture verbose event details (e.g. source filenames). -- /// This can increase the size of the output by 2-3 times. -- LLVM_PREFERRED_TYPE(bool) -- unsigned TimeTraceVerbose : 1; -- - /// Path which stores the output files for -ftime-trace - std::string TimeTracePath; - -@@ -606,8 +601,7 @@ - EmitSymbolGraph(false), EmitExtensionSymbolGraphs(false), - EmitSymbolGraphSymbolLabelsForTesting(false), - EmitPrettySymbolGraphs(false), GenReducedBMI(false), -- UseClangIRPipeline(false), TimeTraceGranularity(500), -- TimeTraceVerbose(false) {} -+ UseClangIRPipeline(false), TimeTraceGranularity(500) {} - - /// getInputKindForExtension - Return the appropriate input kind for a file - /// extension. For example, "c" would return Language::C. -diff -ruN --strip-trailing-cr a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp ---- a/clang/lib/Driver/ToolChains/Clang.cpp -+++ b/clang/lib/Driver/ToolChains/Clang.cpp -@@ -6757,7 +6757,6 @@ - if (const char *Name = C.getTimeTraceFile(&JA)) { - CmdArgs.push_back(Args.MakeArgString("-ftime-trace=" + Twine(Name))); - Args.AddLastArg(CmdArgs, options::OPT_ftime_trace_granularity_EQ); -- Args.AddLastArg(CmdArgs, options::OPT_ftime_trace_verbose); - } - - if (Arg *A = Args.getLastArg(options::OPT_ftrapv_handler_EQ)) { -diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaTemplateInstantiate.cpp b/clang/lib/Sema/SemaTemplateInstantiate.cpp ---- a/clang/lib/Sema/SemaTemplateInstantiate.cpp -+++ b/clang/lib/Sema/SemaTemplateInstantiate.cpp -@@ -3426,16 +3426,11 @@ - return true; - - llvm::TimeTraceScope TimeScope("InstantiateClass", [&]() { -- llvm::TimeTraceMetadata M; -- llvm::raw_string_ostream OS(M.Detail); -+ std::string Name; -+ llvm::raw_string_ostream OS(Name); - Instantiation->getNameForDiagnostic(OS, getPrintingPolicy(), - /*Qualified=*/true); -- if (llvm::isTimeTraceVerbose()) { -- auto Loc = SourceMgr.getExpansionLoc(Instantiation->getLocation()); -- M.File = SourceMgr.getFilename(Loc); -- M.Line = SourceMgr.getExpansionLineNumber(Loc); -- } -- return M; -+ return Name; - }); - - Pattern = PatternDef; -diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp ---- a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp -+++ b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp -@@ -4966,16 +4966,11 @@ - } - - llvm::TimeTraceScope TimeScope("InstantiateFunction", [&]() { -- llvm::TimeTraceMetadata M; -- llvm::raw_string_ostream OS(M.Detail); -+ std::string Name; -+ llvm::raw_string_ostream OS(Name); - Function->getNameForDiagnostic(OS, getPrintingPolicy(), - /*Qualified=*/true); -- if (llvm::isTimeTraceVerbose()) { -- auto Loc = SourceMgr.getExpansionLoc(Function->getLocation()); -- M.File = SourceMgr.getFilename(Loc); -- M.Line = SourceMgr.getExpansionLineNumber(Loc); -- } -- return M; -+ return Name; - }); - - // If we're performing recursive template instantiation, create our own -diff -ruN --strip-trailing-cr a/clang/test/Driver/ftime-trace-sections.cpp b/clang/test/Driver/ftime-trace-sections.cpp ---- a/clang/test/Driver/ftime-trace-sections.cpp -+++ b/clang/test/Driver/ftime-trace-sections.cpp -@@ -1,5 +1,5 @@ - // RUN: rm -rf %t && mkdir %t && cd %t --// RUN: %clangxx -S -ftime-trace -ftime-trace-granularity=0 -ftime-trace-verbose -o out %s -+// RUN: %clangxx -S -ftime-trace -ftime-trace-granularity=0 -o out %s - // RUN: %python %S/ftime-trace-sections.py < out.json - - template -diff -ruN --strip-trailing-cr a/clang/test/Driver/ftime-trace.cpp b/clang/test/Driver/ftime-trace.cpp ---- a/clang/test/Driver/ftime-trace.cpp -+++ b/clang/test/Driver/ftime-trace.cpp -@@ -1,18 +1,18 @@ - // RUN: rm -rf %t && mkdir -p %t && cd %t --// RUN: %clangxx -S -no-canonical-prefixes -ftime-trace -ftime-trace-granularity=0 -ftime-trace-verbose -o out %s -+// RUN: %clangxx -S -no-canonical-prefixes -ftime-trace -ftime-trace-granularity=0 -o out %s - // RUN: cat out.json \ - // RUN: | %python -c 'import json, sys; json.dump(json.loads(sys.stdin.read()), sys.stdout, sort_keys=True, indent=2)' \ - // RUN: | FileCheck %s --// RUN: %clangxx -S -no-canonical-prefixes -ftime-trace=new-name.json -ftime-trace-granularity=0 -ftime-trace-verbose -o out %s -+// RUN: %clangxx -S -no-canonical-prefixes -ftime-trace=new-name.json -ftime-trace-granularity=0 -o out %s - // RUN: cat new-name.json \ - // RUN: | %python -c 'import json, sys; json.dump(json.loads(sys.stdin.read()), sys.stdout, sort_keys=True, indent=2)' \ - // RUN: | FileCheck %s - // RUN: mkdir dir1 dir2 --// RUN: %clangxx -S -no-canonical-prefixes -ftime-trace=dir1 -ftime-trace-granularity=0 -ftime-trace-verbose -o out %s -+// RUN: %clangxx -S -no-canonical-prefixes -ftime-trace=dir1 -ftime-trace-granularity=0 -o out %s - // RUN: cat dir1/out.json \ - // RUN: | %python -c 'import json, sys; json.dump(json.loads(sys.stdin.read()), sys.stdout, sort_keys=True, indent=2)' \ - // RUN: | FileCheck %s --// RUN: %clangxx -S -no-canonical-prefixes -ftime-trace=dir2/ -ftime-trace-granularity=0 -ftime-trace-verbose -o out %s -+// RUN: %clangxx -S -no-canonical-prefixes -ftime-trace=dir2/ -ftime-trace-granularity=0 -o out %s - // RUN: cat dir2/out.json \ - // RUN: | %python -c 'import json, sys; json.dump(json.loads(sys.stdin.read()), sys.stdout, sort_keys=True, indent=2)' \ - // RUN: | FileCheck %s -@@ -34,33 +34,32 @@ - // RUN: mkdir d e f && cp %s d/a.cpp && touch d/b.c - - /// TODO: Support -fno-integrated-as. --// RUN: %clang -### -c -ftime-trace -ftime-trace-granularity=0 -ftime-trace-verbose -fintegrated-as d/a.cpp -o e/a.o 2>&1 | FileCheck %s --check-prefix=COMPILE1 --// COMPILE1: -cc1{{.*}} "-ftime-trace=e/a.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" -+// RUN: %clang -### -c -ftime-trace -ftime-trace-granularity=0 -fintegrated-as d/a.cpp -o e/a.o 2>&1 | FileCheck %s --check-prefix=COMPILE1 -+// COMPILE1: -cc1{{.*}} "-ftime-trace=e/a.json" "-ftime-trace-granularity=0" - --// RUN: %clang -### -c -ftime-trace -ftime-trace-granularity=0 -ftime-trace-verbose d/a.cpp d/b.c -dumpdir f/ 2>&1 | FileCheck %s --check-prefix=COMPILE2 --// COMPILE2: -cc1{{.*}} "-ftime-trace=f/a.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" --// COMPILE2: -cc1{{.*}} "-ftime-trace=f/b.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" -+// RUN: %clang -### -c -ftime-trace -ftime-trace-granularity=0 d/a.cpp d/b.c -dumpdir f/ 2>&1 | FileCheck %s --check-prefix=COMPILE2 -+// COMPILE2: -cc1{{.*}} "-ftime-trace=f/a.json" "-ftime-trace-granularity=0" -+// COMPILE2: -cc1{{.*}} "-ftime-trace=f/b.json" "-ftime-trace-granularity=0" - - /// -o specifies the link output. Create ${output}-${basename}.json. --// RUN: %clang -### -ftime-trace -ftime-trace-granularity=0 -ftime-trace-verbose d/a.cpp d/b.c -o e/x 2>&1 | FileCheck %s --check-prefix=LINK1 --// LINK1: -cc1{{.*}} "-ftime-trace=e/x-a.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" --// LINK1: -cc1{{.*}} "-ftime-trace=e/x-b.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" -+// RUN: %clang -### -ftime-trace -ftime-trace-granularity=0 d/a.cpp d/b.c -o e/x 2>&1 | FileCheck %s --check-prefix=LINK1 -+// LINK1: -cc1{{.*}} "-ftime-trace=e/x-a.json" "-ftime-trace-granularity=0" -+// LINK1: -cc1{{.*}} "-ftime-trace=e/x-b.json" "-ftime-trace-granularity=0" - - /// -dumpdir is f/g, not ending with a path separator. We create f/g${basename}.json. --// RUN: %clang -### -ftime-trace -ftime-trace-granularity=0 -ftime-trace-verbose d/a.cpp d/b.c -o e/x -dumpdir f/g 2>&1 | FileCheck %s --check-prefix=LINK2 --// LINK2: -cc1{{.*}} "-ftime-trace=f/ga.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" --// LINK2: -cc1{{.*}} "-ftime-trace=f/gb.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" -- --// RUN: %clang -### -ftime-trace=e -ftime-trace-granularity=0 -ftime-trace-verbose d/a.cpp d/b.c -o f/x -dumpdir f/ 2>&1 | FileCheck %s --check-prefix=LINK3 --// LINK3: -cc1{{.*}} "-ftime-trace=e{{/|\\\\}}a-{{[^.]*}}.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" --// LINK3: -cc1{{.*}} "-ftime-trace=e{{/|\\\\}}b-{{[^.]*}}.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" -+// RUN: %clang -### -ftime-trace -ftime-trace-granularity=0 d/a.cpp d/b.c -o e/x -dumpdir f/g 2>&1 | FileCheck %s --check-prefix=LINK2 -+// LINK2: -cc1{{.*}} "-ftime-trace=f/ga.json" "-ftime-trace-granularity=0" -+// LINK2: -cc1{{.*}} "-ftime-trace=f/gb.json" "-ftime-trace-granularity=0" -+ -+// RUN: %clang -### -ftime-trace=e -ftime-trace-granularity=0 d/a.cpp d/b.c -o f/x -dumpdir f/ 2>&1 | FileCheck %s --check-prefix=LINK3 -+// LINK3: -cc1{{.*}} "-ftime-trace=e{{/|\\\\}}a-{{[^.]*}}.json" "-ftime-trace-granularity=0" -+// LINK3: -cc1{{.*}} "-ftime-trace=e{{/|\\\\}}b-{{[^.]*}}.json" "-ftime-trace-granularity=0" - --// RUN: %clang -### -ftime-trace -ftime-trace=e -ftime-trace-granularity=1 -ftime-trace-verbose -xassembler d/a.cpp 2>&1 | \ -+// RUN: %clang -### -ftime-trace -ftime-trace=e -ftime-trace-granularity=1 -xassembler d/a.cpp 2>&1 | \ - // RUN: FileCheck %s --check-prefix=UNUSED - // UNUSED: warning: argument unused during compilation: '-ftime-trace' - // UNUSED-NEXT: warning: argument unused during compilation: '-ftime-trace=e' - // UNUSED-NEXT: warning: argument unused during compilation: '-ftime-trace-granularity=1' --// UNUSED-NEXT: warning: argument unused during compilation: '-ftime-trace-verbose' - // UNUSED-NOT: warning: - - template -diff -ruN --strip-trailing-cr a/clang/tools/driver/cc1_main.cpp b/clang/tools/driver/cc1_main.cpp ---- a/clang/tools/driver/cc1_main.cpp -+++ b/clang/tools/driver/cc1_main.cpp -@@ -241,8 +241,7 @@ - - if (!Clang->getFrontendOpts().TimeTracePath.empty()) { - llvm::timeTraceProfilerInitialize( -- Clang->getFrontendOpts().TimeTraceGranularity, Argv0, -- Clang->getFrontendOpts().TimeTraceVerbose); -+ Clang->getFrontendOpts().TimeTraceGranularity, Argv0); - } - // --print-supported-cpus takes priority over the actual compilation. - if (Clang->getFrontendOpts().PrintSupportedCPUs) -diff -ruN --strip-trailing-cr a/clang/unittests/Support/TimeProfilerTest.cpp b/clang/unittests/Support/TimeProfilerTest.cpp ---- a/clang/unittests/Support/TimeProfilerTest.cpp -+++ b/clang/unittests/Support/TimeProfilerTest.cpp -@@ -10,15 +10,11 @@ - #include "clang/Frontend/FrontendActions.h" - #include "clang/Lex/PreprocessorOptions.h" - --#include "llvm/ADT/StringMap.h" - #include "llvm/Support/JSON.h" --#include "llvm/Support/Path.h" - #include "llvm/Support/TimeProfiler.h" --#include "llvm/Support/VirtualFileSystem.h" - #include - - #include "gtest/gtest.h" --#include - - using namespace clang; - using namespace llvm; -@@ -27,8 +23,7 @@ - - // Should be called before testing. - void setupProfiler() { -- timeTraceProfilerInitialize(/*TimeTraceGranularity=*/0, "test", -- /*TimeTraceVerbose=*/true); -+ timeTraceProfilerInitialize(/*TimeTraceGranularity=*/0, "test"); - } - - // Should be called after `compileFromString()`. -@@ -43,24 +38,14 @@ - - // Returns true if code compiles successfully. - // We only parse AST here. This is enough for constexpr evaluation. --bool compileFromString(StringRef Code, StringRef Standard, StringRef File, -- llvm::StringMap Headers = {}) { -+bool compileFromString(StringRef Code, StringRef Standard, StringRef FileName) { - CompilerInstance Compiler; - Compiler.createDiagnostics(); - -- llvm::IntrusiveRefCntPtr FS( -- new llvm::vfs::InMemoryFileSystem()); -- FS->addFile(File, 0, MemoryBuffer::getMemBuffer(Code)); -- for (const auto &Header : Headers) { -- FS->addFile(Header.getKey(), 0, -- MemoryBuffer::getMemBuffer(Header.getValue())); -- } -- llvm::IntrusiveRefCntPtr Files( -- new FileManager(FileSystemOptions(), FS)); -- Compiler.setFileManager(Files.get()); -- - auto Invocation = std::make_shared(); -- std::vector Args = {Standard.data(), File.data()}; -+ Invocation->getPreprocessorOpts().addRemappedFile( -+ FileName, MemoryBuffer::getMemBuffer(Code).release()); -+ const char *Args[] = {Standard.data(), FileName.data()}; - CompilerInvocation::CreateFromArgs(*Invocation, Args, - Compiler.getDiagnostics()); - Compiler.setInvocation(std::move(Invocation)); -@@ -75,28 +60,13 @@ - return Compiler.ExecuteAction(Action); - } - --std::string GetMetadata(json::Object *Event) { -- std::string Metadata; -- llvm::raw_string_ostream OS(Metadata); -- if (json::Object *Args = Event->getObject("args")) { -- if (auto Detail = Args->getString("detail")) -- OS << Detail; -- // Use only filename to not include os-specific path separators. -- if (auto File = Args->getString("file")) -- OS << ", " << llvm::sys::path::filename(*File); -- if (auto Line = Args->getInteger("line")) -- OS << ":" << *Line; -- } -- return Metadata; --} -- - // Returns pretty-printed trace graph. - std::string buildTraceGraph(StringRef Json) { - struct EventRecord { - int64_t TimestampBegin; - int64_t TimestampEnd; -- std::string Name; -- std::string Metadata; -+ StringRef Name; -+ StringRef Detail; - }; - std::vector Events; - -@@ -111,13 +81,10 @@ - int64_t TimestampBegin = TraceEventObj->getInteger("ts").value_or(0); - int64_t TimestampEnd = - TimestampBegin + TraceEventObj->getInteger("dur").value_or(0); -- std::string Name = TraceEventObj->getString("name").value_or("").str(); -- std::string Metadata = GetMetadata(TraceEventObj); -- -- // Source events are asynchronous events and may not perfectly nest the -- // synchronous events. Skip testing them. -- if (Name == "Source") -- continue; -+ StringRef Name = TraceEventObj->getString("name").value_or(""); -+ StringRef Detail = ""; -+ if (json::Object *Args = TraceEventObj->getObject("args")) -+ Detail = Args->getString("detail").value_or(""); - - // This is a "summary" event, like "Total PerformPendingInstantiations", - // skip it -@@ -125,7 +92,7 @@ - continue; - - Events.emplace_back( -- EventRecord{TimestampBegin, TimestampEnd, Name, Metadata}); -+ EventRecord{TimestampBegin, TimestampEnd, Name, Detail}); - } - - // There can be nested events that are very fast, for example: -@@ -165,9 +132,9 @@ - Stream << "| "; - } - Stream.write(Event.Name.data(), Event.Name.size()); -- if (!Event.Metadata.empty()) { -+ if (!Event.Detail.empty()) { - Stream << " ("; -- Stream.write(Event.Metadata.data(), Event.Metadata.size()); -+ Stream.write(Event.Detail.data(), Event.Detail.size()); - Stream << ")"; - } - Stream << "\n"; -@@ -178,7 +145,7 @@ - } // namespace - - TEST(TimeProfilerTest, ConstantEvaluationCxx20) { -- std::string Code = R"( -+ constexpr StringRef Code = R"( - void print(double value); - - namespace slow_namespace { -@@ -208,7 +175,8 @@ - setupProfiler(); - ASSERT_TRUE(compileFromString(Code, "-std=c++20", "test.cc")); - std::string Json = teardownProfiler(); -- ASSERT_EQ(R"( -+ std::string TraceGraph = buildTraceGraph(Json); -+ ASSERT_TRUE(TraceGraph == R"( - Frontend - | ParseDeclarationOrFunctionDefinition (test.cc:2:1) - | ParseDeclarationOrFunctionDefinition (test.cc:6:1) -@@ -234,54 +202,14 @@ - | ParseDeclarationOrFunctionDefinition (test.cc:25:1) - | | EvaluateAsInitializer (slow_init_list) - | PerformPendingInstantiations --)", -- buildTraceGraph(Json)); --} -- --TEST(TimeProfilerTest, TemplateInstantiations) { -- std::string B_H = R"( -- template -- T fooB(T t) { -- return T(); -- } -+)"); - -- #define MacroTemp(x) template void foo##x(T) { T(); } -- )"; -- -- std::string A_H = R"( -- #include "b.h" -- -- MacroTemp(MTA) -- -- template -- void fooA(T t) { fooB(t); fooMTA(t); } -- )"; -- std::string Code = R"( -- #include "a.h" -- void user() { fooA(0); } -- )"; -- -- setupProfiler(); -- ASSERT_TRUE(compileFromString(Code, "-std=c++20", "test.cc", -- /*Headers=*/{{"a.h", A_H}, {"b.h", B_H}})); -- std::string Json = teardownProfiler(); -- ASSERT_EQ(R"( --Frontend --| ParseFunctionDefinition (fooB) --| ParseFunctionDefinition (fooMTA) --| ParseFunctionDefinition (fooA) --| ParseDeclarationOrFunctionDefinition (test.cc:3:5) --| | ParseFunctionDefinition (user) --| PerformPendingInstantiations --| | InstantiateFunction (fooA, a.h:7) --| | | InstantiateFunction (fooB, b.h:3) --| | | InstantiateFunction (fooMTA, a.h:4) --)", -- buildTraceGraph(Json)); -+ // NOTE: If this test is failing, run this test with -+ // `llvm::errs() << TraceGraph;` and change the assert above. - } - - TEST(TimeProfilerTest, ConstantEvaluationC99) { -- std::string Code = R"( -+ constexpr StringRef Code = R"( - struct { - short quantval[4]; // 3rd line - } value; -@@ -290,12 +218,15 @@ - setupProfiler(); - ASSERT_TRUE(compileFromString(Code, "-std=c99", "test.c")); - std::string Json = teardownProfiler(); -- ASSERT_EQ(R"( -+ std::string TraceGraph = buildTraceGraph(Json); -+ ASSERT_TRUE(TraceGraph == R"( - Frontend - | ParseDeclarationOrFunctionDefinition (test.c:2:1) - | | isIntegerConstantExpr () - | | EvaluateKnownConstIntCheckOverflow () - | PerformPendingInstantiations --)", -- buildTraceGraph(Json)); -+)"); -+ -+ // NOTE: If this test is failing, run this test with -+ // `llvm::errs() << TraceGraph;` and change the assert above. - } -diff -ruN --strip-trailing-cr a/lld/test/MachO/reproduce-thin-archive-objc.s b/lld/test/MachO/reproduce-thin-archive-objc.s ---- a/lld/test/MachO/reproduce-thin-archive-objc.s -+++ b/lld/test/MachO/reproduce-thin-archive-objc.s -@@ -4,20 +4,19 @@ - ## during linking. However, we need to iterate over all members for -ObjC, check that we don't - ## crash when we encounter a missing member. - --# RUN: rm -rf %t; mkdir %t --# RUN: sed s/SYM/_main/ %s | llvm-mc -filetype=obj -triple=x86_64-apple-macos -o %t/main.o --# RUN: sed s/SYM/_unused/ %s | llvm-mc -filetype=obj -triple=x86_64-apple-macos -o %t/unused.o -+# RUN: rm -rf %t && mkdir %t && cd %t -+# RUN: sed s/SYM/_main/ %s | llvm-mc -filetype=obj -triple=x86_64-apple-macos -o main.o -+# RUN: sed s/SYM/_unused/ %s | llvm-mc -filetype=obj -triple=x86_64-apple-macos -o unused.o - --# RUN: cd %t; llvm-ar rcsT unused.a unused.o; rm unused.o -+# RUN: llvm-ar rcsT unused.a unused.o; rm unused.o - ## FIXME: Absolute paths don't end up relativized in the repro file. - - # RUN: %no-fatal-warnings-lld %t/main.o %t/unused.a -ObjC -o /dev/null 2>&1 \ - # RUN: | FileCheck %s --check-prefix=WARN - --# RUN: %lld %t/main.o %t/unused.a -ObjC --no-warn-thin-archive-missing-members -o /dev/null \ --# RUN: | FileCheck %s --implicit-check-not 'warning' --allow-empty -+# RUN: %lld main.o unused.a -ObjC --no-warn-thin-archive-missing-members 2>&1 | count 0 - --# WARN: ld64.lld: warning: {{.*}}unused.a: -ObjC failed to open archive member: 'unused.o' -+# WARN: warning: {{.*}}unused.a: -ObjC failed to open archive member: 'unused.o' - - .text - .globl SYM -diff -ruN --strip-trailing-cr a/llvm/include/llvm/Support/TimeProfiler.h b/llvm/include/llvm/Support/TimeProfiler.h ---- a/llvm/include/llvm/Support/TimeProfiler.h -+++ b/llvm/include/llvm/Support/TimeProfiler.h -@@ -83,28 +83,16 @@ - - class raw_pwrite_stream; - --struct TimeTraceMetadata { -- std::string Detail; -- // Source file and line number information for the event. -- std::string File; -- int Line; -- -- bool isEmpty() const { return Detail.empty() && File.empty(); } --}; -- - struct TimeTraceProfiler; - TimeTraceProfiler *getTimeTraceProfilerInstance(); - --bool isTimeTraceVerbose(); -- - struct TimeTraceProfilerEntry; - - /// Initialize the time trace profiler. - /// This sets up the global \p TimeTraceProfilerInstance - /// variable to be the profiler instance. - void timeTraceProfilerInitialize(unsigned TimeTraceGranularity, -- StringRef ProcName, -- bool TimeTraceVerbose = false); -+ StringRef ProcName); - - /// Cleanup the time trace profiler, if it was initialized. - void timeTraceProfilerCleanup(); -@@ -140,10 +128,6 @@ - timeTraceProfilerBegin(StringRef Name, - llvm::function_ref Detail); - --TimeTraceProfilerEntry * --timeTraceProfilerBegin(StringRef Name, -- llvm::function_ref MetaData); -- - /// Manually begin a time section, with the given \p Name and \p Detail. - /// This starts Async Events having \p Name as a category which is shown - /// separately from other traces. See -@@ -180,11 +164,6 @@ - if (getTimeTraceProfilerInstance() != nullptr) - Entry = timeTraceProfilerBegin(Name, Detail); - } -- TimeTraceScope(StringRef Name, -- llvm::function_ref Metadata) { -- if (getTimeTraceProfilerInstance() != nullptr) -- Entry = timeTraceProfilerBegin(Name, Metadata); -- } - ~TimeTraceScope() { - if (getTimeTraceProfilerInstance() != nullptr) - timeTraceProfilerEnd(Entry); -diff -ruN --strip-trailing-cr a/llvm/lib/Support/TimeProfiler.cpp b/llvm/lib/Support/TimeProfiler.cpp ---- a/llvm/lib/Support/TimeProfiler.cpp -+++ b/llvm/lib/Support/TimeProfiler.cpp -@@ -73,20 +73,12 @@ - const TimePointType Start; - TimePointType End; - const std::string Name; -- TimeTraceMetadata Metadata; -- -+ const std::string Detail; - const bool AsyncEvent = false; - TimeTraceProfilerEntry(TimePointType &&S, TimePointType &&E, std::string &&N, - std::string &&Dt, bool Ae) -- : Start(std::move(S)), End(std::move(E)), Name(std::move(N)), Metadata(), -- AsyncEvent(Ae) { -- Metadata.Detail = std::move(Dt); -- } -- -- TimeTraceProfilerEntry(TimePointType &&S, TimePointType &&E, std::string &&N, -- TimeTraceMetadata &&Mt, bool Ae) - : Start(std::move(S)), End(std::move(E)), Name(std::move(N)), -- Metadata(std::move(Mt)), AsyncEvent(Ae) {} -+ Detail(std::move(Dt)), AsyncEvent(Ae) {} - - // Calculate timings for FlameGraph. Cast time points to microsecond precision - // rather than casting duration. This avoids truncation issues causing inner -@@ -105,12 +97,10 @@ - }; - - struct llvm::TimeTraceProfiler { -- TimeTraceProfiler(unsigned TimeTraceGranularity = 0, StringRef ProcName = "", -- bool TimeTraceVerbose = false) -+ TimeTraceProfiler(unsigned TimeTraceGranularity = 0, StringRef ProcName = "") - : BeginningOfTime(system_clock::now()), StartTime(ClockType::now()), - ProcName(ProcName), Pid(sys::Process::getProcessId()), -- Tid(llvm::get_threadid()), TimeTraceGranularity(TimeTraceGranularity), -- TimeTraceVerbose(TimeTraceVerbose) { -+ Tid(llvm::get_threadid()), TimeTraceGranularity(TimeTraceGranularity) { - llvm::get_thread_name(ThreadName); - } - -@@ -123,15 +113,6 @@ - return Stack.back().get(); - } - -- TimeTraceProfilerEntry * -- begin(std::string Name, llvm::function_ref Metadata, -- bool AsyncEvent = false) { -- Stack.emplace_back(std::make_unique( -- ClockType::now(), TimePointType(), std::move(Name), Metadata(), -- AsyncEvent)); -- return Stack.back().get(); -- } -- - void end() { - assert(!Stack.empty() && "Must call begin() first"); - end(*Stack.back()); -@@ -203,15 +184,8 @@ - J.attribute("dur", DurUs); - } - J.attribute("name", E.Name); -- if (!E.Metadata.isEmpty()) { -- J.attributeObject("args", [&] { -- if (!E.Metadata.Detail.empty()) -- J.attribute("detail", E.Metadata.Detail); -- if (!E.Metadata.File.empty()) -- J.attribute("file", E.Metadata.File); -- if (E.Metadata.Line > 0) -- J.attribute("line", E.Metadata.Line); -- }); -+ if (!E.Detail.empty()) { -+ J.attributeObject("args", [&] { J.attribute("detail", E.Detail); }); - } - }); - -@@ -333,25 +307,14 @@ - - // Minimum time granularity (in microseconds) - const unsigned TimeTraceGranularity; -- -- // Make time trace capture verbose event details (e.g. source filenames). This -- // can increase the size of the output by 2-3 times. -- const bool TimeTraceVerbose; - }; - --bool llvm::isTimeTraceVerbose() { -- return getTimeTraceProfilerInstance() && -- getTimeTraceProfilerInstance()->TimeTraceVerbose; --} -- - void llvm::timeTraceProfilerInitialize(unsigned TimeTraceGranularity, -- StringRef ProcName, -- bool TimeTraceVerbose) { -+ StringRef ProcName) { - assert(TimeTraceProfilerInstance == nullptr && - "Profiler should not be initialized"); - TimeTraceProfilerInstance = new TimeTraceProfiler( -- TimeTraceGranularity, llvm::sys::path::filename(ProcName), -- TimeTraceVerbose); -+ TimeTraceGranularity, llvm::sys::path::filename(ProcName)); - } - - // Removes all TimeTraceProfilerInstances. -@@ -418,14 +381,6 @@ - return nullptr; - } - --TimeTraceProfilerEntry * --llvm::timeTraceProfilerBegin(StringRef Name, -- llvm::function_ref Metadata) { -- if (TimeTraceProfilerInstance != nullptr) -- return TimeTraceProfilerInstance->begin(std::string(Name), Metadata, false); -- return nullptr; --} -- - TimeTraceProfilerEntry *llvm::timeTraceAsyncProfilerBegin(StringRef Name, - StringRef Detail) { - if (TimeTraceProfilerInstance != nullptr) +diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel +--- a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel ++++ b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel +@@ -675,6 +675,7 @@ + deps = [ + ":__support_common", + ":__support_cpp_type_traits", ++ ":__support_fputil_dyadic_float", + ":__support_fputil_fenv_impl", + ":__support_fputil_fp_bits", + ":__support_macros_optimization", +@@ -1089,7 +1090,7 @@ + ":__support_macros_optimization", + ":__support_osutil_syscall", + ":types_pid_t", +- ] ++ ], + ) + + libc_support_library( diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 2949b73a155af1..a6b1b06abe37c5 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "84658fb82b67fc22ecba1560d0cddd09f9104178" - LLVM_SHA256 = "b4a50d36a8ab0284f7022f61bbf07a2fb3ea25c6bb2cc422d2418c23b61366da" + LLVM_COMMIT = "58fb51492d9669525662fa269295d85537968569" + LLVM_SHA256 = "f6cac3f3f562a7bd3a36a828df2960a1ebc2cd6237f4cb95a66f1bd16e918ef9" tf_http_archive( name = name, diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index ff323785844790..01e4a8bf6970a8 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "d889df1c54b8cd02d90a44aff7bd485340b4774d" - SHARDY_SHA256 = "5a6a83cbae22dfe0940825da944d48ef0968ff7f74ee38ea2d32b19443a10d8c" + SHARDY_COMMIT = "f7ba97a90be022a20dc0f970998bc0855f152314" + SHARDY_SHA256 = "6dcf7672c93ed22fa676ab8d33e4d5b64eff6cee4668098f0937fb57cb8f1320" tf_http_archive( name = "shardy", diff --git a/third_party/triton/llvm_integration/series.bzl b/third_party/triton/llvm_integration/series.bzl index 656b9c894904d8..9d0e1204ba527f 100644 --- a/third_party/triton/llvm_integration/series.bzl +++ b/third_party/triton/llvm_integration/series.bzl @@ -8,5 +8,6 @@ LLVM nor MLIR integrator, please do not add any patches to this list. """ llvm_patch_list = [ + "//third_party/triton/llvm_integration:cl656020169.patch", # Add new patches just above this line ] diff --git a/third_party/tsl/third_party/llvm/generated.patch b/third_party/tsl/third_party/llvm/generated.patch index c7f7475c35588c..4eda7b241d21bc 100644 --- a/third_party/tsl/third_party/llvm/generated.patch +++ b/third_party/tsl/third_party/llvm/generated.patch @@ -1,638 +1,21 @@ Auto generated patch. Do not edit or delete it, even if empty. -diff -ruN --strip-trailing-cr a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst ---- a/clang/docs/ReleaseNotes.rst -+++ b/clang/docs/ReleaseNotes.rst -@@ -750,9 +750,6 @@ - - Clang now specifies that using ``auto`` in a lambda parameter is a C++14 extension when - appropriate. (`#46059: `_). - --- Clang now adds source file infomation for template instantiations as ``event["args"]["filename"]``. This -- added behind an option ``-ftime-trace-verbose``. This is expected to increase the size of trace by 2-3 times. -- - Improvements to Coverage Mapping - -------------------------------- - -diff -ruN --strip-trailing-cr a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td ---- a/clang/include/clang/Driver/Options.td -+++ b/clang/include/clang/Driver/Options.td -@@ -3998,10 +3998,6 @@ - HelpText<"Minimum time granularity (in microseconds) traced by time profiler">, - Visibility<[ClangOption, CC1Option, CLOption, DXCOption]>, - MarshallingInfoInt, "500u">; --def ftime_trace_verbose : Joined<["-"], "ftime-trace-verbose">, Group, -- HelpText<"Make time trace capture verbose event details (e.g. source filenames). This can increase the size of the output by 2-3 times">, -- Visibility<[ClangOption, CC1Option, CLOption, DXCOption]>, -- MarshallingInfoFlag>; - def ftime_trace_EQ : Joined<["-"], "ftime-trace=">, Group, - HelpText<"Similar to -ftime-trace. Specify the JSON file or a directory which will contain the JSON file">, - Visibility<[ClangOption, CC1Option, CLOption, DXCOption]>, -diff -ruN --strip-trailing-cr a/clang/include/clang/Frontend/FrontendOptions.h b/clang/include/clang/Frontend/FrontendOptions.h ---- a/clang/include/clang/Frontend/FrontendOptions.h -+++ b/clang/include/clang/Frontend/FrontendOptions.h -@@ -580,11 +580,6 @@ - /// Minimum time granularity (in microseconds) traced by time profiler. - unsigned TimeTraceGranularity; - -- /// Make time trace capture verbose event details (e.g. source filenames). -- /// This can increase the size of the output by 2-3 times. -- LLVM_PREFERRED_TYPE(bool) -- unsigned TimeTraceVerbose : 1; -- - /// Path which stores the output files for -ftime-trace - std::string TimeTracePath; - -@@ -606,8 +601,7 @@ - EmitSymbolGraph(false), EmitExtensionSymbolGraphs(false), - EmitSymbolGraphSymbolLabelsForTesting(false), - EmitPrettySymbolGraphs(false), GenReducedBMI(false), -- UseClangIRPipeline(false), TimeTraceGranularity(500), -- TimeTraceVerbose(false) {} -+ UseClangIRPipeline(false), TimeTraceGranularity(500) {} - - /// getInputKindForExtension - Return the appropriate input kind for a file - /// extension. For example, "c" would return Language::C. -diff -ruN --strip-trailing-cr a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp ---- a/clang/lib/Driver/ToolChains/Clang.cpp -+++ b/clang/lib/Driver/ToolChains/Clang.cpp -@@ -6757,7 +6757,6 @@ - if (const char *Name = C.getTimeTraceFile(&JA)) { - CmdArgs.push_back(Args.MakeArgString("-ftime-trace=" + Twine(Name))); - Args.AddLastArg(CmdArgs, options::OPT_ftime_trace_granularity_EQ); -- Args.AddLastArg(CmdArgs, options::OPT_ftime_trace_verbose); - } - - if (Arg *A = Args.getLastArg(options::OPT_ftrapv_handler_EQ)) { -diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaTemplateInstantiate.cpp b/clang/lib/Sema/SemaTemplateInstantiate.cpp ---- a/clang/lib/Sema/SemaTemplateInstantiate.cpp -+++ b/clang/lib/Sema/SemaTemplateInstantiate.cpp -@@ -3426,16 +3426,11 @@ - return true; - - llvm::TimeTraceScope TimeScope("InstantiateClass", [&]() { -- llvm::TimeTraceMetadata M; -- llvm::raw_string_ostream OS(M.Detail); -+ std::string Name; -+ llvm::raw_string_ostream OS(Name); - Instantiation->getNameForDiagnostic(OS, getPrintingPolicy(), - /*Qualified=*/true); -- if (llvm::isTimeTraceVerbose()) { -- auto Loc = SourceMgr.getExpansionLoc(Instantiation->getLocation()); -- M.File = SourceMgr.getFilename(Loc); -- M.Line = SourceMgr.getExpansionLineNumber(Loc); -- } -- return M; -+ return Name; - }); - - Pattern = PatternDef; -diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp ---- a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp -+++ b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp -@@ -4966,16 +4966,11 @@ - } - - llvm::TimeTraceScope TimeScope("InstantiateFunction", [&]() { -- llvm::TimeTraceMetadata M; -- llvm::raw_string_ostream OS(M.Detail); -+ std::string Name; -+ llvm::raw_string_ostream OS(Name); - Function->getNameForDiagnostic(OS, getPrintingPolicy(), - /*Qualified=*/true); -- if (llvm::isTimeTraceVerbose()) { -- auto Loc = SourceMgr.getExpansionLoc(Function->getLocation()); -- M.File = SourceMgr.getFilename(Loc); -- M.Line = SourceMgr.getExpansionLineNumber(Loc); -- } -- return M; -+ return Name; - }); - - // If we're performing recursive template instantiation, create our own -diff -ruN --strip-trailing-cr a/clang/test/Driver/ftime-trace-sections.cpp b/clang/test/Driver/ftime-trace-sections.cpp ---- a/clang/test/Driver/ftime-trace-sections.cpp -+++ b/clang/test/Driver/ftime-trace-sections.cpp -@@ -1,5 +1,5 @@ - // RUN: rm -rf %t && mkdir %t && cd %t --// RUN: %clangxx -S -ftime-trace -ftime-trace-granularity=0 -ftime-trace-verbose -o out %s -+// RUN: %clangxx -S -ftime-trace -ftime-trace-granularity=0 -o out %s - // RUN: %python %S/ftime-trace-sections.py < out.json - - template -diff -ruN --strip-trailing-cr a/clang/test/Driver/ftime-trace.cpp b/clang/test/Driver/ftime-trace.cpp ---- a/clang/test/Driver/ftime-trace.cpp -+++ b/clang/test/Driver/ftime-trace.cpp -@@ -1,18 +1,18 @@ - // RUN: rm -rf %t && mkdir -p %t && cd %t --// RUN: %clangxx -S -no-canonical-prefixes -ftime-trace -ftime-trace-granularity=0 -ftime-trace-verbose -o out %s -+// RUN: %clangxx -S -no-canonical-prefixes -ftime-trace -ftime-trace-granularity=0 -o out %s - // RUN: cat out.json \ - // RUN: | %python -c 'import json, sys; json.dump(json.loads(sys.stdin.read()), sys.stdout, sort_keys=True, indent=2)' \ - // RUN: | FileCheck %s --// RUN: %clangxx -S -no-canonical-prefixes -ftime-trace=new-name.json -ftime-trace-granularity=0 -ftime-trace-verbose -o out %s -+// RUN: %clangxx -S -no-canonical-prefixes -ftime-trace=new-name.json -ftime-trace-granularity=0 -o out %s - // RUN: cat new-name.json \ - // RUN: | %python -c 'import json, sys; json.dump(json.loads(sys.stdin.read()), sys.stdout, sort_keys=True, indent=2)' \ - // RUN: | FileCheck %s - // RUN: mkdir dir1 dir2 --// RUN: %clangxx -S -no-canonical-prefixes -ftime-trace=dir1 -ftime-trace-granularity=0 -ftime-trace-verbose -o out %s -+// RUN: %clangxx -S -no-canonical-prefixes -ftime-trace=dir1 -ftime-trace-granularity=0 -o out %s - // RUN: cat dir1/out.json \ - // RUN: | %python -c 'import json, sys; json.dump(json.loads(sys.stdin.read()), sys.stdout, sort_keys=True, indent=2)' \ - // RUN: | FileCheck %s --// RUN: %clangxx -S -no-canonical-prefixes -ftime-trace=dir2/ -ftime-trace-granularity=0 -ftime-trace-verbose -o out %s -+// RUN: %clangxx -S -no-canonical-prefixes -ftime-trace=dir2/ -ftime-trace-granularity=0 -o out %s - // RUN: cat dir2/out.json \ - // RUN: | %python -c 'import json, sys; json.dump(json.loads(sys.stdin.read()), sys.stdout, sort_keys=True, indent=2)' \ - // RUN: | FileCheck %s -@@ -34,33 +34,32 @@ - // RUN: mkdir d e f && cp %s d/a.cpp && touch d/b.c - - /// TODO: Support -fno-integrated-as. --// RUN: %clang -### -c -ftime-trace -ftime-trace-granularity=0 -ftime-trace-verbose -fintegrated-as d/a.cpp -o e/a.o 2>&1 | FileCheck %s --check-prefix=COMPILE1 --// COMPILE1: -cc1{{.*}} "-ftime-trace=e/a.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" -+// RUN: %clang -### -c -ftime-trace -ftime-trace-granularity=0 -fintegrated-as d/a.cpp -o e/a.o 2>&1 | FileCheck %s --check-prefix=COMPILE1 -+// COMPILE1: -cc1{{.*}} "-ftime-trace=e/a.json" "-ftime-trace-granularity=0" - --// RUN: %clang -### -c -ftime-trace -ftime-trace-granularity=0 -ftime-trace-verbose d/a.cpp d/b.c -dumpdir f/ 2>&1 | FileCheck %s --check-prefix=COMPILE2 --// COMPILE2: -cc1{{.*}} "-ftime-trace=f/a.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" --// COMPILE2: -cc1{{.*}} "-ftime-trace=f/b.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" -+// RUN: %clang -### -c -ftime-trace -ftime-trace-granularity=0 d/a.cpp d/b.c -dumpdir f/ 2>&1 | FileCheck %s --check-prefix=COMPILE2 -+// COMPILE2: -cc1{{.*}} "-ftime-trace=f/a.json" "-ftime-trace-granularity=0" -+// COMPILE2: -cc1{{.*}} "-ftime-trace=f/b.json" "-ftime-trace-granularity=0" - - /// -o specifies the link output. Create ${output}-${basename}.json. --// RUN: %clang -### -ftime-trace -ftime-trace-granularity=0 -ftime-trace-verbose d/a.cpp d/b.c -o e/x 2>&1 | FileCheck %s --check-prefix=LINK1 --// LINK1: -cc1{{.*}} "-ftime-trace=e/x-a.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" --// LINK1: -cc1{{.*}} "-ftime-trace=e/x-b.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" -+// RUN: %clang -### -ftime-trace -ftime-trace-granularity=0 d/a.cpp d/b.c -o e/x 2>&1 | FileCheck %s --check-prefix=LINK1 -+// LINK1: -cc1{{.*}} "-ftime-trace=e/x-a.json" "-ftime-trace-granularity=0" -+// LINK1: -cc1{{.*}} "-ftime-trace=e/x-b.json" "-ftime-trace-granularity=0" - - /// -dumpdir is f/g, not ending with a path separator. We create f/g${basename}.json. --// RUN: %clang -### -ftime-trace -ftime-trace-granularity=0 -ftime-trace-verbose d/a.cpp d/b.c -o e/x -dumpdir f/g 2>&1 | FileCheck %s --check-prefix=LINK2 --// LINK2: -cc1{{.*}} "-ftime-trace=f/ga.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" --// LINK2: -cc1{{.*}} "-ftime-trace=f/gb.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" -- --// RUN: %clang -### -ftime-trace=e -ftime-trace-granularity=0 -ftime-trace-verbose d/a.cpp d/b.c -o f/x -dumpdir f/ 2>&1 | FileCheck %s --check-prefix=LINK3 --// LINK3: -cc1{{.*}} "-ftime-trace=e{{/|\\\\}}a-{{[^.]*}}.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" --// LINK3: -cc1{{.*}} "-ftime-trace=e{{/|\\\\}}b-{{[^.]*}}.json" "-ftime-trace-granularity=0" "-ftime-trace-verbose" -+// RUN: %clang -### -ftime-trace -ftime-trace-granularity=0 d/a.cpp d/b.c -o e/x -dumpdir f/g 2>&1 | FileCheck %s --check-prefix=LINK2 -+// LINK2: -cc1{{.*}} "-ftime-trace=f/ga.json" "-ftime-trace-granularity=0" -+// LINK2: -cc1{{.*}} "-ftime-trace=f/gb.json" "-ftime-trace-granularity=0" -+ -+// RUN: %clang -### -ftime-trace=e -ftime-trace-granularity=0 d/a.cpp d/b.c -o f/x -dumpdir f/ 2>&1 | FileCheck %s --check-prefix=LINK3 -+// LINK3: -cc1{{.*}} "-ftime-trace=e{{/|\\\\}}a-{{[^.]*}}.json" "-ftime-trace-granularity=0" -+// LINK3: -cc1{{.*}} "-ftime-trace=e{{/|\\\\}}b-{{[^.]*}}.json" "-ftime-trace-granularity=0" - --// RUN: %clang -### -ftime-trace -ftime-trace=e -ftime-trace-granularity=1 -ftime-trace-verbose -xassembler d/a.cpp 2>&1 | \ -+// RUN: %clang -### -ftime-trace -ftime-trace=e -ftime-trace-granularity=1 -xassembler d/a.cpp 2>&1 | \ - // RUN: FileCheck %s --check-prefix=UNUSED - // UNUSED: warning: argument unused during compilation: '-ftime-trace' - // UNUSED-NEXT: warning: argument unused during compilation: '-ftime-trace=e' - // UNUSED-NEXT: warning: argument unused during compilation: '-ftime-trace-granularity=1' --// UNUSED-NEXT: warning: argument unused during compilation: '-ftime-trace-verbose' - // UNUSED-NOT: warning: - - template -diff -ruN --strip-trailing-cr a/clang/tools/driver/cc1_main.cpp b/clang/tools/driver/cc1_main.cpp ---- a/clang/tools/driver/cc1_main.cpp -+++ b/clang/tools/driver/cc1_main.cpp -@@ -241,8 +241,7 @@ - - if (!Clang->getFrontendOpts().TimeTracePath.empty()) { - llvm::timeTraceProfilerInitialize( -- Clang->getFrontendOpts().TimeTraceGranularity, Argv0, -- Clang->getFrontendOpts().TimeTraceVerbose); -+ Clang->getFrontendOpts().TimeTraceGranularity, Argv0); - } - // --print-supported-cpus takes priority over the actual compilation. - if (Clang->getFrontendOpts().PrintSupportedCPUs) -diff -ruN --strip-trailing-cr a/clang/unittests/Support/TimeProfilerTest.cpp b/clang/unittests/Support/TimeProfilerTest.cpp ---- a/clang/unittests/Support/TimeProfilerTest.cpp -+++ b/clang/unittests/Support/TimeProfilerTest.cpp -@@ -10,15 +10,11 @@ - #include "clang/Frontend/FrontendActions.h" - #include "clang/Lex/PreprocessorOptions.h" - --#include "llvm/ADT/StringMap.h" - #include "llvm/Support/JSON.h" --#include "llvm/Support/Path.h" - #include "llvm/Support/TimeProfiler.h" --#include "llvm/Support/VirtualFileSystem.h" - #include - - #include "gtest/gtest.h" --#include - - using namespace clang; - using namespace llvm; -@@ -27,8 +23,7 @@ - - // Should be called before testing. - void setupProfiler() { -- timeTraceProfilerInitialize(/*TimeTraceGranularity=*/0, "test", -- /*TimeTraceVerbose=*/true); -+ timeTraceProfilerInitialize(/*TimeTraceGranularity=*/0, "test"); - } - - // Should be called after `compileFromString()`. -@@ -43,24 +38,14 @@ - - // Returns true if code compiles successfully. - // We only parse AST here. This is enough for constexpr evaluation. --bool compileFromString(StringRef Code, StringRef Standard, StringRef File, -- llvm::StringMap Headers = {}) { -+bool compileFromString(StringRef Code, StringRef Standard, StringRef FileName) { - CompilerInstance Compiler; - Compiler.createDiagnostics(); - -- llvm::IntrusiveRefCntPtr FS( -- new llvm::vfs::InMemoryFileSystem()); -- FS->addFile(File, 0, MemoryBuffer::getMemBuffer(Code)); -- for (const auto &Header : Headers) { -- FS->addFile(Header.getKey(), 0, -- MemoryBuffer::getMemBuffer(Header.getValue())); -- } -- llvm::IntrusiveRefCntPtr Files( -- new FileManager(FileSystemOptions(), FS)); -- Compiler.setFileManager(Files.get()); -- - auto Invocation = std::make_shared(); -- std::vector Args = {Standard.data(), File.data()}; -+ Invocation->getPreprocessorOpts().addRemappedFile( -+ FileName, MemoryBuffer::getMemBuffer(Code).release()); -+ const char *Args[] = {Standard.data(), FileName.data()}; - CompilerInvocation::CreateFromArgs(*Invocation, Args, - Compiler.getDiagnostics()); - Compiler.setInvocation(std::move(Invocation)); -@@ -75,28 +60,13 @@ - return Compiler.ExecuteAction(Action); - } - --std::string GetMetadata(json::Object *Event) { -- std::string Metadata; -- llvm::raw_string_ostream OS(Metadata); -- if (json::Object *Args = Event->getObject("args")) { -- if (auto Detail = Args->getString("detail")) -- OS << Detail; -- // Use only filename to not include os-specific path separators. -- if (auto File = Args->getString("file")) -- OS << ", " << llvm::sys::path::filename(*File); -- if (auto Line = Args->getInteger("line")) -- OS << ":" << *Line; -- } -- return Metadata; --} -- - // Returns pretty-printed trace graph. - std::string buildTraceGraph(StringRef Json) { - struct EventRecord { - int64_t TimestampBegin; - int64_t TimestampEnd; -- std::string Name; -- std::string Metadata; -+ StringRef Name; -+ StringRef Detail; - }; - std::vector Events; - -@@ -111,13 +81,10 @@ - int64_t TimestampBegin = TraceEventObj->getInteger("ts").value_or(0); - int64_t TimestampEnd = - TimestampBegin + TraceEventObj->getInteger("dur").value_or(0); -- std::string Name = TraceEventObj->getString("name").value_or("").str(); -- std::string Metadata = GetMetadata(TraceEventObj); -- -- // Source events are asynchronous events and may not perfectly nest the -- // synchronous events. Skip testing them. -- if (Name == "Source") -- continue; -+ StringRef Name = TraceEventObj->getString("name").value_or(""); -+ StringRef Detail = ""; -+ if (json::Object *Args = TraceEventObj->getObject("args")) -+ Detail = Args->getString("detail").value_or(""); - - // This is a "summary" event, like "Total PerformPendingInstantiations", - // skip it -@@ -125,7 +92,7 @@ - continue; - - Events.emplace_back( -- EventRecord{TimestampBegin, TimestampEnd, Name, Metadata}); -+ EventRecord{TimestampBegin, TimestampEnd, Name, Detail}); - } - - // There can be nested events that are very fast, for example: -@@ -165,9 +132,9 @@ - Stream << "| "; - } - Stream.write(Event.Name.data(), Event.Name.size()); -- if (!Event.Metadata.empty()) { -+ if (!Event.Detail.empty()) { - Stream << " ("; -- Stream.write(Event.Metadata.data(), Event.Metadata.size()); -+ Stream.write(Event.Detail.data(), Event.Detail.size()); - Stream << ")"; - } - Stream << "\n"; -@@ -178,7 +145,7 @@ - } // namespace - - TEST(TimeProfilerTest, ConstantEvaluationCxx20) { -- std::string Code = R"( -+ constexpr StringRef Code = R"( - void print(double value); - - namespace slow_namespace { -@@ -208,7 +175,8 @@ - setupProfiler(); - ASSERT_TRUE(compileFromString(Code, "-std=c++20", "test.cc")); - std::string Json = teardownProfiler(); -- ASSERT_EQ(R"( -+ std::string TraceGraph = buildTraceGraph(Json); -+ ASSERT_TRUE(TraceGraph == R"( - Frontend - | ParseDeclarationOrFunctionDefinition (test.cc:2:1) - | ParseDeclarationOrFunctionDefinition (test.cc:6:1) -@@ -234,54 +202,14 @@ - | ParseDeclarationOrFunctionDefinition (test.cc:25:1) - | | EvaluateAsInitializer (slow_init_list) - | PerformPendingInstantiations --)", -- buildTraceGraph(Json)); --} -- --TEST(TimeProfilerTest, TemplateInstantiations) { -- std::string B_H = R"( -- template -- T fooB(T t) { -- return T(); -- } -+)"); - -- #define MacroTemp(x) template void foo##x(T) { T(); } -- )"; -- -- std::string A_H = R"( -- #include "b.h" -- -- MacroTemp(MTA) -- -- template -- void fooA(T t) { fooB(t); fooMTA(t); } -- )"; -- std::string Code = R"( -- #include "a.h" -- void user() { fooA(0); } -- )"; -- -- setupProfiler(); -- ASSERT_TRUE(compileFromString(Code, "-std=c++20", "test.cc", -- /*Headers=*/{{"a.h", A_H}, {"b.h", B_H}})); -- std::string Json = teardownProfiler(); -- ASSERT_EQ(R"( --Frontend --| ParseFunctionDefinition (fooB) --| ParseFunctionDefinition (fooMTA) --| ParseFunctionDefinition (fooA) --| ParseDeclarationOrFunctionDefinition (test.cc:3:5) --| | ParseFunctionDefinition (user) --| PerformPendingInstantiations --| | InstantiateFunction (fooA, a.h:7) --| | | InstantiateFunction (fooB, b.h:3) --| | | InstantiateFunction (fooMTA, a.h:4) --)", -- buildTraceGraph(Json)); -+ // NOTE: If this test is failing, run this test with -+ // `llvm::errs() << TraceGraph;` and change the assert above. - } - - TEST(TimeProfilerTest, ConstantEvaluationC99) { -- std::string Code = R"( -+ constexpr StringRef Code = R"( - struct { - short quantval[4]; // 3rd line - } value; -@@ -290,12 +218,15 @@ - setupProfiler(); - ASSERT_TRUE(compileFromString(Code, "-std=c99", "test.c")); - std::string Json = teardownProfiler(); -- ASSERT_EQ(R"( -+ std::string TraceGraph = buildTraceGraph(Json); -+ ASSERT_TRUE(TraceGraph == R"( - Frontend - | ParseDeclarationOrFunctionDefinition (test.c:2:1) - | | isIntegerConstantExpr () - | | EvaluateKnownConstIntCheckOverflow () - | PerformPendingInstantiations --)", -- buildTraceGraph(Json)); -+)"); -+ -+ // NOTE: If this test is failing, run this test with -+ // `llvm::errs() << TraceGraph;` and change the assert above. - } -diff -ruN --strip-trailing-cr a/lld/test/MachO/reproduce-thin-archive-objc.s b/lld/test/MachO/reproduce-thin-archive-objc.s ---- a/lld/test/MachO/reproduce-thin-archive-objc.s -+++ b/lld/test/MachO/reproduce-thin-archive-objc.s -@@ -4,20 +4,19 @@ - ## during linking. However, we need to iterate over all members for -ObjC, check that we don't - ## crash when we encounter a missing member. - --# RUN: rm -rf %t; mkdir %t --# RUN: sed s/SYM/_main/ %s | llvm-mc -filetype=obj -triple=x86_64-apple-macos -o %t/main.o --# RUN: sed s/SYM/_unused/ %s | llvm-mc -filetype=obj -triple=x86_64-apple-macos -o %t/unused.o -+# RUN: rm -rf %t && mkdir %t && cd %t -+# RUN: sed s/SYM/_main/ %s | llvm-mc -filetype=obj -triple=x86_64-apple-macos -o main.o -+# RUN: sed s/SYM/_unused/ %s | llvm-mc -filetype=obj -triple=x86_64-apple-macos -o unused.o - --# RUN: cd %t; llvm-ar rcsT unused.a unused.o; rm unused.o -+# RUN: llvm-ar rcsT unused.a unused.o; rm unused.o - ## FIXME: Absolute paths don't end up relativized in the repro file. - - # RUN: %no-fatal-warnings-lld %t/main.o %t/unused.a -ObjC -o /dev/null 2>&1 \ - # RUN: | FileCheck %s --check-prefix=WARN - --# RUN: %lld %t/main.o %t/unused.a -ObjC --no-warn-thin-archive-missing-members -o /dev/null \ --# RUN: | FileCheck %s --implicit-check-not 'warning' --allow-empty -+# RUN: %lld main.o unused.a -ObjC --no-warn-thin-archive-missing-members 2>&1 | count 0 - --# WARN: ld64.lld: warning: {{.*}}unused.a: -ObjC failed to open archive member: 'unused.o' -+# WARN: warning: {{.*}}unused.a: -ObjC failed to open archive member: 'unused.o' - - .text - .globl SYM -diff -ruN --strip-trailing-cr a/llvm/include/llvm/Support/TimeProfiler.h b/llvm/include/llvm/Support/TimeProfiler.h ---- a/llvm/include/llvm/Support/TimeProfiler.h -+++ b/llvm/include/llvm/Support/TimeProfiler.h -@@ -83,28 +83,16 @@ - - class raw_pwrite_stream; - --struct TimeTraceMetadata { -- std::string Detail; -- // Source file and line number information for the event. -- std::string File; -- int Line; -- -- bool isEmpty() const { return Detail.empty() && File.empty(); } --}; -- - struct TimeTraceProfiler; - TimeTraceProfiler *getTimeTraceProfilerInstance(); - --bool isTimeTraceVerbose(); -- - struct TimeTraceProfilerEntry; - - /// Initialize the time trace profiler. - /// This sets up the global \p TimeTraceProfilerInstance - /// variable to be the profiler instance. - void timeTraceProfilerInitialize(unsigned TimeTraceGranularity, -- StringRef ProcName, -- bool TimeTraceVerbose = false); -+ StringRef ProcName); - - /// Cleanup the time trace profiler, if it was initialized. - void timeTraceProfilerCleanup(); -@@ -140,10 +128,6 @@ - timeTraceProfilerBegin(StringRef Name, - llvm::function_ref Detail); - --TimeTraceProfilerEntry * --timeTraceProfilerBegin(StringRef Name, -- llvm::function_ref MetaData); -- - /// Manually begin a time section, with the given \p Name and \p Detail. - /// This starts Async Events having \p Name as a category which is shown - /// separately from other traces. See -@@ -180,11 +164,6 @@ - if (getTimeTraceProfilerInstance() != nullptr) - Entry = timeTraceProfilerBegin(Name, Detail); - } -- TimeTraceScope(StringRef Name, -- llvm::function_ref Metadata) { -- if (getTimeTraceProfilerInstance() != nullptr) -- Entry = timeTraceProfilerBegin(Name, Metadata); -- } - ~TimeTraceScope() { - if (getTimeTraceProfilerInstance() != nullptr) - timeTraceProfilerEnd(Entry); -diff -ruN --strip-trailing-cr a/llvm/lib/Support/TimeProfiler.cpp b/llvm/lib/Support/TimeProfiler.cpp ---- a/llvm/lib/Support/TimeProfiler.cpp -+++ b/llvm/lib/Support/TimeProfiler.cpp -@@ -73,20 +73,12 @@ - const TimePointType Start; - TimePointType End; - const std::string Name; -- TimeTraceMetadata Metadata; -- -+ const std::string Detail; - const bool AsyncEvent = false; - TimeTraceProfilerEntry(TimePointType &&S, TimePointType &&E, std::string &&N, - std::string &&Dt, bool Ae) -- : Start(std::move(S)), End(std::move(E)), Name(std::move(N)), Metadata(), -- AsyncEvent(Ae) { -- Metadata.Detail = std::move(Dt); -- } -- -- TimeTraceProfilerEntry(TimePointType &&S, TimePointType &&E, std::string &&N, -- TimeTraceMetadata &&Mt, bool Ae) - : Start(std::move(S)), End(std::move(E)), Name(std::move(N)), -- Metadata(std::move(Mt)), AsyncEvent(Ae) {} -+ Detail(std::move(Dt)), AsyncEvent(Ae) {} - - // Calculate timings for FlameGraph. Cast time points to microsecond precision - // rather than casting duration. This avoids truncation issues causing inner -@@ -105,12 +97,10 @@ - }; - - struct llvm::TimeTraceProfiler { -- TimeTraceProfiler(unsigned TimeTraceGranularity = 0, StringRef ProcName = "", -- bool TimeTraceVerbose = false) -+ TimeTraceProfiler(unsigned TimeTraceGranularity = 0, StringRef ProcName = "") - : BeginningOfTime(system_clock::now()), StartTime(ClockType::now()), - ProcName(ProcName), Pid(sys::Process::getProcessId()), -- Tid(llvm::get_threadid()), TimeTraceGranularity(TimeTraceGranularity), -- TimeTraceVerbose(TimeTraceVerbose) { -+ Tid(llvm::get_threadid()), TimeTraceGranularity(TimeTraceGranularity) { - llvm::get_thread_name(ThreadName); - } - -@@ -123,15 +113,6 @@ - return Stack.back().get(); - } - -- TimeTraceProfilerEntry * -- begin(std::string Name, llvm::function_ref Metadata, -- bool AsyncEvent = false) { -- Stack.emplace_back(std::make_unique( -- ClockType::now(), TimePointType(), std::move(Name), Metadata(), -- AsyncEvent)); -- return Stack.back().get(); -- } -- - void end() { - assert(!Stack.empty() && "Must call begin() first"); - end(*Stack.back()); -@@ -203,15 +184,8 @@ - J.attribute("dur", DurUs); - } - J.attribute("name", E.Name); -- if (!E.Metadata.isEmpty()) { -- J.attributeObject("args", [&] { -- if (!E.Metadata.Detail.empty()) -- J.attribute("detail", E.Metadata.Detail); -- if (!E.Metadata.File.empty()) -- J.attribute("file", E.Metadata.File); -- if (E.Metadata.Line > 0) -- J.attribute("line", E.Metadata.Line); -- }); -+ if (!E.Detail.empty()) { -+ J.attributeObject("args", [&] { J.attribute("detail", E.Detail); }); - } - }); - -@@ -333,25 +307,14 @@ - - // Minimum time granularity (in microseconds) - const unsigned TimeTraceGranularity; -- -- // Make time trace capture verbose event details (e.g. source filenames). This -- // can increase the size of the output by 2-3 times. -- const bool TimeTraceVerbose; - }; - --bool llvm::isTimeTraceVerbose() { -- return getTimeTraceProfilerInstance() && -- getTimeTraceProfilerInstance()->TimeTraceVerbose; --} -- - void llvm::timeTraceProfilerInitialize(unsigned TimeTraceGranularity, -- StringRef ProcName, -- bool TimeTraceVerbose) { -+ StringRef ProcName) { - assert(TimeTraceProfilerInstance == nullptr && - "Profiler should not be initialized"); - TimeTraceProfilerInstance = new TimeTraceProfiler( -- TimeTraceGranularity, llvm::sys::path::filename(ProcName), -- TimeTraceVerbose); -+ TimeTraceGranularity, llvm::sys::path::filename(ProcName)); - } - - // Removes all TimeTraceProfilerInstances. -@@ -418,14 +381,6 @@ - return nullptr; - } - --TimeTraceProfilerEntry * --llvm::timeTraceProfilerBegin(StringRef Name, -- llvm::function_ref Metadata) { -- if (TimeTraceProfilerInstance != nullptr) -- return TimeTraceProfilerInstance->begin(std::string(Name), Metadata, false); -- return nullptr; --} -- - TimeTraceProfilerEntry *llvm::timeTraceAsyncProfilerBegin(StringRef Name, - StringRef Detail) { - if (TimeTraceProfilerInstance != nullptr) +diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel +--- a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel ++++ b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel +@@ -675,6 +675,7 @@ + deps = [ + ":__support_common", + ":__support_cpp_type_traits", ++ ":__support_fputil_dyadic_float", + ":__support_fputil_fenv_impl", + ":__support_fputil_fp_bits", + ":__support_macros_optimization", +@@ -1089,7 +1090,7 @@ + ":__support_macros_optimization", + ":__support_osutil_syscall", + ":types_pid_t", +- ] ++ ], + ) + + libc_support_library( diff --git a/third_party/tsl/third_party/llvm/workspace.bzl b/third_party/tsl/third_party/llvm/workspace.bzl index 2949b73a155af1..a6b1b06abe37c5 100644 --- a/third_party/tsl/third_party/llvm/workspace.bzl +++ b/third_party/tsl/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "84658fb82b67fc22ecba1560d0cddd09f9104178" - LLVM_SHA256 = "b4a50d36a8ab0284f7022f61bbf07a2fb3ea25c6bb2cc422d2418c23b61366da" + LLVM_COMMIT = "58fb51492d9669525662fa269295d85537968569" + LLVM_SHA256 = "f6cac3f3f562a7bd3a36a828df2960a1ebc2cd6237f4cb95a66f1bd16e918ef9" tf_http_archive( name = name, From f9d67fccd97ee1a1349bb92d48f0c16d7149d0c9 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 25 Jul 2024 14:47:15 -0700 Subject: [PATCH 172/376] [tsl] Remove platform:types from mutex dependencies PiperOrigin-RevId: 656100657 --- third_party/tsl/tsl/platform/default/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/third_party/tsl/tsl/platform/default/BUILD b/third_party/tsl/tsl/platform/default/BUILD index 5e646dcbd852d7..01cf593888c077 100644 --- a/third_party/tsl/tsl/platform/default/BUILD +++ b/third_party/tsl/tsl/platform/default/BUILD @@ -303,7 +303,6 @@ cc_library( "//tsl/platform", "//tsl/platform:macros", "//tsl/platform:thread_annotations", - "//tsl/platform:types", "@nsync//:nsync_cpp", ], ) From bb9fa7d66e965280a8638f235e214b009f9aa4c5 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Thu, 25 Jul 2024 15:01:00 -0700 Subject: [PATCH 173/376] Integrate StableHLO at openxla/stablehlo@8555db77 PiperOrigin-RevId: 656105462 --- third_party/stablehlo/temporary.patch | 57 --------------------------- third_party/stablehlo/workspace.bzl | 4 +- 2 files changed, 2 insertions(+), 59 deletions(-) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 3cae7dc292dc85..8b137891791fe9 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -1,58 +1 @@ -diff --ruN a/stablehlo/stablehlo/conversions/linalg/transforms/TypeConversion.cpp b/stablehlo/stablehlo/conversions/linalg/transforms/TypeConversion.cpp ---- stablehlo/stablehlo/conversions/linalg/transforms/TypeConversion.cpp -+++ stablehlo/stablehlo/conversions/linalg/transforms/TypeConversion.cpp -@@ -66,17 +66,25 @@ - ->getResult(0); - } - --std::optional scalarToTensor(OpBuilder &builder, Type /*type*/, -+std::optional scalarToTensor(OpBuilder& builder, Type type, - ValueRange inputs, Location loc) { - assert(inputs.size() == 1); -- if (llvm::isa(inputs.front().getType())) { -+ if (mlir::isa(inputs.front().getType())) { - return std::nullopt; - } -- return builder -- .create( -- loc, RankedTensorType::get({}, inputs.front().getType()), -- inputs.front()) -- .getResult(); -+ Value result = -+ builder -+ .create( -+ loc, RankedTensorType::get({}, inputs.front().getType()), -+ inputs.front()) -+ .getResult(); -+ // Convert to a signed integer if necessary. -+ Type elementType = mlir::getElementTypeOrSelf(type); -+ if (elementType.isInteger() && !elementType.isSignlessInteger()) { -+ result = builder.create(loc, type, result) -+ ->getResult(0); -+ } -+ return result; - } - - } // namespace -diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeQuantToInt.cpp b/stablehlo/stablehlo/transforms/StablehloLegalizeQuantToInt.cpp ---- stablehlo/stablehlo/transforms/StablehloLegalizeQuantToInt.cpp -+++ stablehlo/stablehlo/transforms/StablehloLegalizeQuantToInt.cpp -@@ -1270,12 +1270,13 @@ - OperationState state(op->getLoc(), op->getName().getStringRef(), operands, - newResultTypes, op->getAttrs(), op->getSuccessors()); - for (Region ®ion : op->getRegions()) { -- Region &newRegion = *state.addRegion(); -- rewriter.inlineRegionBefore(region, newRegion, newRegion.begin()); -- if (failed( -- rewriter.convertRegionTypes(&newRegion, *getTypeConverter()))) { -+ auto newRegion = std::make_unique(op); -+ rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin()); -+ if (failed(rewriter.convertRegionTypes(newRegion.get(), -+ *getTypeConverter()))) { - return failure(); - } -+ state.addRegion(std::move(newRegion)); - } - Operation *newOp = rewriter.create(state); - rewriter.replaceOp(op, newOp); diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index 03e69998dfc9ee..f9c14a65d4dbb3 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "840c41ceb0d13800d286a9d76d8ad00d97838d9e" - STABLEHLO_SHA256 = "f2f92695ecdb2449a3d2316015a37301c1e4768315b9e753e18b4759eebb67e8" + STABLEHLO_COMMIT = "8555db77763fadbd6be83df0a5532828bc419cba" + STABLEHLO_SHA256 = "666a88d94e0f1b36e9e5b25411521b878320c61983214859b4e419f36acbf332" # LINT.ThenChange(Google-internal path) tf_http_archive( From 2a3141886b11e1f4bd2d30912881db3743d48ac8 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Thu, 25 Jul 2024 15:04:47 -0700 Subject: [PATCH 174/376] Create cuda::ToStatus helper function to translate CUresult codes into absl::Status objects. This creates common handling of CUresults, and the elimination of RETURN_IF_CUDA_ERROR-style macros in favor of using the common TF_CHECK_OK/TF_RETURN_IF_ERROR ones. PiperOrigin-RevId: 656106821 --- xla/stream_executor/cuda/BUILD | 18 + xla/stream_executor/cuda/cuda_asm_compiler.cc | 34 +- xla/stream_executor/cuda/cuda_driver.cc | 1182 +++++++---------- xla/stream_executor/cuda/cuda_driver.h | 24 +- xla/stream_executor/cuda/cuda_driver_test.cc | 17 +- xla/stream_executor/cuda/cuda_status.cc | 51 + xla/stream_executor/cuda/cuda_status.h | 42 + xla/stream_executor/gpu/BUILD | 1 + .../gpu/gpu_cudamallocasync_allocator.cc | 54 +- 9 files changed, 656 insertions(+), 767 deletions(-) create mode 100644 xla/stream_executor/cuda/cuda_status.cc create mode 100644 xla/stream_executor/cuda/cuda_status.h diff --git a/xla/stream_executor/cuda/BUILD b/xla/stream_executor/cuda/BUILD index a33da00c418d93..48592a1a92656b 100644 --- a/xla/stream_executor/cuda/BUILD +++ b/xla/stream_executor/cuda/BUILD @@ -143,6 +143,7 @@ cuda_only_cc_library( hdrs = ["cuda_driver.h"], deps = [ ":cuda_diagnostics", # buildcleaner: keep + ":cuda_status", "//xla/stream_executor", "//xla/stream_executor/gpu:gpu_diagnostics_header", "//xla/stream_executor/gpu:gpu_driver_header", @@ -164,10 +165,24 @@ cuda_only_cc_library( "@com_google_absl//absl/types:span", "@local_config_cuda//cuda:cuda_headers", "@tsl//tsl/platform:env", + "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:macros", "@tsl//tsl/platform:numbers", "@tsl//tsl/platform:stacktrace", + "@tsl//tsl/platform:status", + ], +) + +cuda_only_cc_library( + name = "cuda_status", + srcs = ["cuda_status.cc"], + hdrs = ["cuda_status.h"], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@local_config_cuda//cuda:cuda_headers", ], ) @@ -217,9 +232,11 @@ xla_test( ], deps = [ ":cuda_driver", + ":cuda_status", "//xla/stream_executor/gpu:gpu_driver_header", "@com_google_absl//absl/log", "@local_config_cuda//cuda:cuda_headers", + "@tsl//tsl/platform:status", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], @@ -721,6 +738,7 @@ cuda_only_cc_library( ]), deps = [ ":cuda_driver", + ":cuda_status", ":ptx_compiler", ":ptx_compiler_support", "//xla:status_macros", diff --git a/xla/stream_executor/cuda/cuda_asm_compiler.cc b/xla/stream_executor/cuda/cuda_asm_compiler.cc index caea22789a151d..01aa15313c2cd0 100644 --- a/xla/stream_executor/cuda/cuda_asm_compiler.cc +++ b/xla/stream_executor/cuda/cuda_asm_compiler.cc @@ -48,6 +48,7 @@ limitations under the License. #include "absl/types/span.h" #include "third_party/gpus/cuda/include/cuda.h" #include "xla/status_macros.h" +#include "xla/stream_executor/cuda/cuda_status.h" #include "xla/stream_executor/cuda/ptx_compiler.h" #include "xla/stream_executor/cuda/ptx_compiler_support.h" #include "xla/stream_executor/device_description.h" @@ -496,19 +497,6 @@ absl::StatusOr> BundleGpuAsm( return std::vector(result_blob.begin(), result_blob.end()); } -#define RETURN_IF_CUDA_ERROR(expr) \ - do { \ - CUresult _status = expr; \ - if (!ABSL_PREDICT_TRUE(_status == CUDA_SUCCESS)) { \ - const char* error_string; \ - cuGetErrorString(_status, &error_string); \ - std::ostringstream oss; \ - oss << error_string << "\nin " << __FILE__ << "(" << __LINE__ << "): '" \ - << #expr << "'"; \ - return absl::UnknownError(oss.str().c_str()); \ - } \ - } while (false) - static absl::StatusOr FindNvlinkExecutable( std::string_view preferred_cuda_dir) { static constexpr ToolVersion kMinimumNvlinkVersion{11, 8, 0}; @@ -630,24 +618,26 @@ absl::StatusOr> LinkGpuAsm( static_assert(sizeof(options) / sizeof(options[0]) == sizeof(option_values) / sizeof(option_values[0])); - RETURN_IF_CUDA_ERROR(cuLinkCreate(sizeof(options) / sizeof(options[0]), - options, option_values, &link_state)); + TF_RETURN_IF_ERROR( + cuda::ToStatus(cuLinkCreate(sizeof(options) / sizeof(options[0]), options, + option_values, &link_state))); for (auto& image : images) { - auto status = cuLinkAddData(link_state, CU_JIT_INPUT_CUBIN, - static_cast(image.bytes.data()), - image.bytes.size(), "", 0, nullptr, nullptr); - if (status != CUDA_SUCCESS) { + auto status = cuda::ToStatus(cuLinkAddData( + link_state, CU_JIT_INPUT_CUBIN, static_cast(image.bytes.data()), + image.bytes.size(), "", 0, nullptr, nullptr)); + if (!status.ok()) { LOG(ERROR) << "cuLinkAddData fails. This is usually caused by stale " "driver version."; + return status; } - RETURN_IF_CUDA_ERROR(status); } void* cubin_out; size_t cubin_size; - RETURN_IF_CUDA_ERROR(cuLinkComplete(link_state, &cubin_out, &cubin_size)); + TF_RETURN_IF_ERROR( + cuda::ToStatus(cuLinkComplete(link_state, &cubin_out, &cubin_size))); std::vector cubin(static_cast(cubin_out), static_cast(cubin_out) + cubin_size); - RETURN_IF_CUDA_ERROR(cuLinkDestroy(link_state)); + TF_RETURN_IF_ERROR(cuda::ToStatus(cuLinkDestroy(link_state))); return std::move(cubin); } diff --git a/xla/stream_executor/cuda/cuda_driver.cc b/xla/stream_executor/cuda/cuda_driver.cc index e498468f5e6753..763408e4c35e8b 100644 --- a/xla/stream_executor/cuda/cuda_driver.cc +++ b/xla/stream_executor/cuda/cuda_driver.cc @@ -43,44 +43,25 @@ limitations under the License. #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cuda_runtime_api.h" #include "third_party/gpus/cuda/include/driver_types.h" +#include "xla/stream_executor/cuda/cuda_status.h" #include "xla/stream_executor/gpu/gpu_diagnostics.h" #include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" #include "tsl/platform/env.h" +#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/macros.h" #include "tsl/platform/numbers.h" #include "tsl/platform/stacktrace.h" +#include "tsl/platform/status.h" #include "tsl/platform/threadpool.h" -#define RETURN_IF_CUDA_RES_ERROR(expr, ...) \ - do { \ - CUresult _res = (expr); \ - if (ABSL_PREDICT_FALSE(_res != CUDA_SUCCESS)) { \ - if (_res == CUDA_ERROR_OUT_OF_MEMORY) \ - return absl::ResourceExhaustedError(absl::StrCat( \ - __VA_ARGS__, ":", ::stream_executor::gpu::ToString(_res))); \ - else \ - return absl::InternalError(absl::StrCat( \ - __VA_ARGS__, ": ", ::stream_executor::gpu::ToString(_res))); \ - } \ - } while (0) - -#define FAIL_IF_CUDA_RES_ERROR(expr, ...) \ - do { \ - CUresult _res = (expr); \ - if (ABSL_PREDICT_FALSE(_res != CUDA_SUCCESS)) { \ - LOG(FATAL) << absl::StrCat(__VA_ARGS__) << ": " \ - << ::stream_executor::gpu::ToString(_res); \ - } \ - } while (0) - namespace stream_executor { namespace gpu { -/* static */ absl::Mutex CreatedContexts::mu_{absl::kConstInit}; +absl::Mutex CreatedContexts::mu_{absl::kConstInit}; namespace { @@ -131,8 +112,8 @@ ScopedActivateContext::ScopedActivateContext(GpuContext* cuda_context) { if (tls->depth == 0) { VLOG(3) << "ScopedActivateContext switching to " << cuda_context->device_ordinal(); - FAIL_IF_CUDA_RES_ERROR(cuCtxSetCurrent(cuda_context->context()), - "Failed setting context"); + TF_CHECK_OK(cuda::ToStatus(cuCtxSetCurrent(cuda_context->context()), + "Failed setting context")); tls->depth = 1; tls->device_ordinal = cuda_context->device_ordinal(); tls->context = cuda_context; @@ -151,8 +132,8 @@ ScopedActivateContext::ScopedActivateContext(GpuContext* cuda_context) { to_restore_ = tls->context; // Set the context and update thread local. - FAIL_IF_CUDA_RES_ERROR(cuCtxSetCurrent(cuda_context->context()), - "Failed setting context"); + TF_CHECK_OK(cuda::ToStatus(cuCtxSetCurrent(cuda_context->context()), + "Failed setting context")); tls->device_ordinal = cuda_context->device_ordinal(); tls->context = cuda_context; } @@ -168,8 +149,8 @@ ScopedActivateContext::~ScopedActivateContext() { } // Set context and update thread local. - FAIL_IF_CUDA_RES_ERROR(cuCtxSetCurrent(to_restore_->context()), - "Failed setting context"); + TF_CHECK_OK(cuda::ToStatus(cuCtxSetCurrent(to_restore_->context()), + "Failed setting context")); tls->device_ordinal = to_restore_->device_ordinal(); tls->context = to_restore_; } @@ -226,19 +207,16 @@ std::string CUDAPointersToCanAccessString(CUdeviceptr from, CUdeviceptr to) { // Actually performs the work of CUDA initialization. Wrapped up in one-time // execution guard. static absl::Status InternalInit() { - CUresult res = cuInit(0 /* = flags */); - - if (res == CUDA_SUCCESS) { - return absl::OkStatus(); - } else if (res == CUDA_ERROR_SHARED_OBJECT_INIT_FAILED) { - VLOG(1) << "failed call to cuInit: " << ToString(res); - } else { - LOG(ERROR) << "failed call to cuInit: " << ToString(res); + absl::Status status = + cuda::ToStatus(cuInit(0 /* = flags */), "Failed call to cuInit"); + if (status.ok()) { + return status; } + LOG(ERROR) << "failed call to cuInit: " << status; + Diagnostician::LogDiagnosticInformation(); - return absl::AbortedError( - absl::StrCat("failed call to cuInit: ", ToString(res))); + return status; } // Synchronize with spinlocks. @@ -280,14 +258,13 @@ absl::StatusOr QueryEvent(GpuContext* context, CUevent event) { ScopedActivateContext activated{context}; CUresult res = cuEventQuery(event); if (res != CUDA_SUCCESS && res != CUDA_ERROR_NOT_READY) { - return absl::InternalError( - absl::StrFormat("failed to query event: %s", ToString(res))); + return cuda::ToStatus(res, ("failed to query event")); } return res; } -/* static */ absl::Status GpuDriver::Init() { +absl::Status GpuDriver::Init() { // Cached return value from calling InternalInit(), as cuInit need only be // called once, but GpuDriver::Init may be called many times. static absl::Status* init_retval = [] { @@ -296,41 +273,34 @@ absl::StatusOr QueryEvent(GpuContext* context, CUevent event) { return *init_retval; } -/* static */ absl::Status GpuDriver::GetDevice(int device_ordinal, - CUdevice* device) { - RETURN_IF_CUDA_RES_ERROR(cuDeviceGet(device, device_ordinal), - "Failed call to cuDeviceGet"); - return absl::OkStatus(); +absl::Status GpuDriver::GetDevice(int device_ordinal, CUdevice* device) { + return cuda::ToStatus(cuDeviceGet(device, device_ordinal), + "Failed call to cuDeviceGet"); } -/* static */ absl::Status GpuDriver::GetDeviceName(CUdevice device, - std::string* device_name) { +absl::Status GpuDriver::GetDeviceName(CUdevice device, + std::string* device_name) { static const size_t kCharLimit = 64; absl::InlinedVector chars(kCharLimit); - RETURN_IF_CUDA_RES_ERROR( - cuDeviceGetName(chars.begin(), kCharLimit - 1, device), - "Failed to get device name"); + TF_RETURN_IF_ERROR( + cuda::ToStatus(cuDeviceGetName(chars.begin(), kCharLimit - 1, device), + "Failed to get device name")); chars[kCharLimit - 1] = '\0'; *device_name = chars.begin(); return absl::OkStatus(); } -/* static */ absl::Status GpuDriver::CreateContext(int device_ordinal, - CUdevice device, - GpuContext** context) { +absl::Status GpuDriver::CreateContext(int device_ordinal, CUdevice device, + GpuContext** context) { *context = nullptr; int flags = GetFlagsFromEnv(); - CUresult res; - CUcontext former_context; - CUcontext new_context; - unsigned int former_primary_context_flags; int former_primary_context_is_active; - CHECK_EQ(CUDA_SUCCESS, - cuDevicePrimaryCtxGetState(device, &former_primary_context_flags, - &former_primary_context_is_active)); + TF_RETURN_IF_ERROR(cuda::ToStatus( + cuDevicePrimaryCtxGetState(device, &former_primary_context_flags, + &former_primary_context_is_active))); if (former_primary_context_flags != flags) { if (former_primary_context_is_active) { LOG(ERROR) @@ -338,12 +308,15 @@ absl::StatusOr QueryEvent(GpuContext* context, CUevent event) { << former_primary_context_flags << ") than the desired flag set (" << flags << ")."; } else { - CHECK_EQ(CUDA_SUCCESS, cuDevicePrimaryCtxSetFlags(device, flags)); + TF_RETURN_IF_ERROR( + cuda::ToStatus(cuDevicePrimaryCtxSetFlags(device, flags))); } } - former_context = cuda::CurrentContextOrDie(); - res = cuDevicePrimaryCtxRetain(&new_context, device); + CUcontext former_context = cuda::CurrentContextOrDie(); + CUcontext new_context; + TF_RETURN_IF_ERROR( + cuda::ToStatus(cuDevicePrimaryCtxRetain(&new_context, device))); if (former_context != nullptr) { CUdevice former_device; if (cuCtxGetDevice(&former_device) == CUDA_SUCCESS) { @@ -365,93 +338,77 @@ absl::StatusOr QueryEvent(GpuContext* context, CUevent event) { << former_context; } } - CHECK_EQ(CUDA_SUCCESS, cuCtxSetCurrent(former_context)); - - if (res == CUDA_SUCCESS) { - *context = CreatedContexts::Add(new_context, device_ordinal); - CHECK(*context != nullptr) - << "success in this call must entail non-null result"; - VLOG(2) << "created or reused context " << new_context - << " for this thread"; - return absl::OkStatus(); - } - - std::string message = - "failed call to cuDevicePrimaryCtxRetain: " + ToString(res); - if (res == CUDA_ERROR_OUT_OF_MEMORY) { - uint64_t total_memory; - if (GetDeviceTotalMemory(device, &total_memory)) { - absl::StrAppend(&message, "; total memory reported: ", total_memory); - } else { - absl::StrAppend(&message, "; could not query total memory"); - } - } + TF_RETURN_IF_ERROR(cuda::ToStatus(cuCtxSetCurrent(former_context))); - return absl::InternalError(message); + *context = CreatedContexts::Add(new_context, device_ordinal); + CHECK(*context != nullptr) + << "success in this call must entail non-null result"; + VLOG(2) << "created or reused context " << new_context << " for this thread"; + return absl::OkStatus(); } -/* static */ void GpuDriver::DestroyContext(GpuContext* context) { +void GpuDriver::DestroyContext(GpuContext* context) { if (context == nullptr) { return; } - CUresult res = cuCtxPushCurrent(context->context()); + auto status = cuda::ToStatus(cuCtxPushCurrent(context->context())); + if (!status.ok()) { + LOG(ERROR) << "failed to Push CUDA context; leaking: " << status; + } CUdevice device; cuCtxGetDevice(&device); cuCtxPopCurrent(nullptr); - res = cuDevicePrimaryCtxRelease(device); + status = cuda::ToStatus(cuDevicePrimaryCtxRelease(device)); - if (res != CUDA_SUCCESS) { - LOG(ERROR) << "failed to release CUDA context; leaking: " << ToString(res); + if (!status.ok()) { + LOG(ERROR) << "failed to release CUDA context; leaking: " << status; } CreatedContexts::Remove(context->context()); } -/* static */ absl::Status GpuDriver::FuncGetAttribute( - CUfunction_attribute attribute, CUfunction func, int* attribute_value) { - RETURN_IF_CUDA_RES_ERROR(cuFuncGetAttribute(attribute_value, attribute, func), - "Failed to query kernel attribute: ", attribute); - return absl::OkStatus(); +absl::Status GpuDriver::FuncGetAttribute(CUfunction_attribute attribute, + CUfunction func, + int* attribute_value) { + return cuda::ToStatus( + cuFuncGetAttribute(attribute_value, attribute, func), + absl::StrCat("Failed to query kernel attribute: ", attribute)); } -/* static */ absl::Status GpuDriver::FuncSetCacheConfig( - CUfunction function, CUfunc_cache cache_config) { - RETURN_IF_CUDA_RES_ERROR(cuFuncSetCacheConfig(function, cache_config), - "Failed to set CUDA kernel cache config"); - return absl::OkStatus(); +absl::Status GpuDriver::FuncSetCacheConfig(CUfunction function, + CUfunc_cache cache_config) { + return cuda::ToStatus(cuFuncSetCacheConfig(function, cache_config), + "Failed to set CUDA kernel cache config"); } -/* static */ absl::StatusOr -GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { +absl::StatusOr GpuDriver::ContextGetSharedMemConfig( + GpuContext* context) { CUsharedconfig shared_mem_config; ScopedActivateContext activation(context); - RETURN_IF_CUDA_RES_ERROR(cuCtxGetSharedMemConfig(&shared_mem_config), - "Failed to get shared memory config"); + TF_RETURN_IF_ERROR(cuda::ToStatus(cuCtxGetSharedMemConfig(&shared_mem_config), + "Failed to get shared memory config")); return shared_mem_config; } -/* static */ absl::Status GpuDriver::ContextSetSharedMemConfig( +absl::Status GpuDriver::ContextSetSharedMemConfig( GpuContext* context, CUsharedconfig shared_mem_config) { ScopedActivateContext activation(context); - RETURN_IF_CUDA_RES_ERROR(cuCtxSetSharedMemConfig(shared_mem_config), - "Failed to set shared memory config"); - return absl::OkStatus(); + return cuda::ToStatus(cuCtxSetSharedMemConfig(shared_mem_config), + "Failed to set shared memory config"); } -/* static */ absl::Status GpuDriver::CreateGraph(CUgraph* graph) { +absl::Status GpuDriver::CreateGraph(CUgraph* graph) { VLOG(2) << "Create new CUDA graph"; - RETURN_IF_CUDA_RES_ERROR(cuGraphCreate(graph, /*flags=*/0), - "Failed to create CUDA graph"); + TF_RETURN_IF_ERROR(cuda::ToStatus(cuGraphCreate(graph, /*flags=*/0), + "Failed to create CUDA graph")); VLOG(2) << "Created CUDA graph " << *graph; return absl::OkStatus(); } -/* static */ absl::Status GpuDriver::DestroyGraph(CUgraph graph) { +absl::Status GpuDriver::DestroyGraph(CUgraph graph) { VLOG(2) << "Destroy CUDA graph " << graph; - RETURN_IF_CUDA_RES_ERROR(cuGraphDestroy(graph), - "Failed to destroy CUDA graph"); - return absl::OkStatus(); + return cuda::ToStatus(cuGraphDestroy(graph), "Failed to destroy CUDA graph"); } static std::string_view StreamCaptureModeToString( @@ -466,8 +423,8 @@ static std::string_view StreamCaptureModeToString( } } -/* static */ absl::Status GpuDriver::StreamBeginCapture( - CUstream stream, StreamCaptureMode mode) { +absl::Status GpuDriver::StreamBeginCapture(CUstream stream, + StreamCaptureMode mode) { CUstreamCaptureMode cu_mode; switch (mode) { case StreamCaptureMode::kGlobal: @@ -483,13 +440,13 @@ static std::string_view StreamCaptureModeToString( VLOG(2) << "Beginning stream " << stream << " capture in " << StreamCaptureModeToString(mode) << " mode"; - RETURN_IF_CUDA_RES_ERROR(cuStreamBeginCapture(stream, cu_mode), - "Failed to begin stream capture"); - return absl::OkStatus(); + return cuda::ToStatus(cuStreamBeginCapture(stream, cu_mode), + "Failed to begin stream capture"); } -/* static */ absl::Status GpuDriver::StreamBeginCaptureToGraph( - CUstream stream, CUgraph graph, StreamCaptureMode mode) { +absl::Status GpuDriver::StreamBeginCaptureToGraph(CUstream stream, + CUgraph graph, + StreamCaptureMode mode) { CUstreamCaptureMode cu_mode; switch (mode) { case StreamCaptureMode::kGlobal: @@ -506,31 +463,27 @@ static std::string_view StreamCaptureModeToString( #if CUDA_VERSION >= 12030 VLOG(2) << "Beginning stream " << stream << " capture in " << StreamCaptureModeToString(mode) << " mode to graph " << graph; - RETURN_IF_CUDA_RES_ERROR( + return cuda::ToStatus( cuStreamBeginCaptureToGraph(stream, graph, /*dependencies=*/nullptr, /*dependencyData=*/nullptr, /*numDependencies=*/0, cu_mode), "Failed to begin stream capture to graph"); - return absl::OkStatus(); #else return absl::UnimplementedError( "StreamBeginCaptureToGraph is not implemented"); #endif // CUDA_VERSION >= 12030 } -/* static */ absl::Status GpuDriver::StreamEndCapture(CUstream stream, - CUgraph* graph) { +absl::Status GpuDriver::StreamEndCapture(CUstream stream, CUgraph* graph) { VLOG(2) << "End stream " << stream << " capture"; - RETURN_IF_CUDA_RES_ERROR(cuStreamEndCapture(stream, graph), - "Failed to end stream capture"); - - return absl::OkStatus(); + return cuda::ToStatus(cuStreamEndCapture(stream, graph), + "Failed to end stream capture"); } -/* static */ absl::Status GpuDriver::GraphInstantiate( - CUgraphExec* exec, CUgraph graph, const GraphInstantiateFlags& flags) { +absl::Status GpuDriver::GraphInstantiate(CUgraphExec* exec, CUgraph graph, + const GraphInstantiateFlags& flags) { VLOG(2) << "Instantiate CUDA executable graph from graph " << graph << " (" << "auto_free_on_launch=" << flags.auto_free_on_launch << ", " << "device_launch=" << flags.device_launch << ", " @@ -547,39 +500,33 @@ static std::string_view StreamCaptureModeToString( cu_flags |= CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH; if (flags.upload) cu_flags |= CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD; - RETURN_IF_CUDA_RES_ERROR(cuGraphInstantiate(exec, graph, cu_flags), - "Failed to instantiate CUDA graph"); + return cuda::ToStatus(cuGraphInstantiate(exec, graph, cu_flags), + "Failed to instantiate CUDA graph"); #else - RETURN_IF_CUDA_RES_ERROR(cuGraphInstantiate(exec, graph, nullptr, nullptr, 0), + return (cuda::ToStatus(cuGraphInstantiate(exec, graph, nullptr, nullptr, 0), "Failed to instantiate CUDA graph"); #endif // CUDA_VERSION >= 12000 - - return absl::OkStatus(); } -/* static */ absl::Status GpuDriver::GraphLaunch(CUgraphExec exec, - CUstream stream) { +absl::Status GpuDriver::GraphLaunch(CUgraphExec exec, CUstream stream) { VLOG(2) << "Launching CUDA executable graph " << exec << " on a stream " << stream; - RETURN_IF_CUDA_RES_ERROR(cuGraphLaunch(exec, stream), - "Failed to launch CUDA graph"); - return absl::OkStatus(); + return cuda::ToStatus(cuGraphLaunch(exec, stream), + "Failed to launch CUDA graph"); } -/* static */ absl::Status GpuDriver::GraphNodeSetEnabled(CUgraphExec exec, - CUgraphNode node, - bool enabled) { +absl::Status GpuDriver::GraphNodeSetEnabled(CUgraphExec exec, CUgraphNode node, + bool enabled) { // Node is enabled if value != 0, otherwise the node is disabled. unsigned value = enabled ? 1 : 0; VLOG(2) << "Set CUDA executable graph " << exec << " node " << node << " enabled flag to " << value; - RETURN_IF_CUDA_RES_ERROR(cuGraphNodeSetEnabled(exec, node, value), - "Failed to set CUDA graph node enabled flag"); - return absl::OkStatus(); + return cuda::ToStatus(cuGraphNodeSetEnabled(exec, node, value), + "Failed to set CUDA graph node enabled flag"); } -/* static */ absl::Status GpuDriver::GraphExecUpdate( - CUgraphExec exec, CUgraph graph, GraphExecUpdateResultInfo* result) { +absl::Status GpuDriver::GraphExecUpdate(CUgraphExec exec, CUgraph graph, + GraphExecUpdateResultInfo* result) { VLOG(2) << "Update CUDA graph executable " << exec << " with graph " << graph; #if CUDA_VERSION >= 12000 @@ -632,17 +579,15 @@ static std::string_view StreamCaptureModeToString( default: return absl::InternalError("Unknown graph update result"); } - - RETURN_IF_CUDA_RES_ERROR(err_code, "Failed to update CUDA graph"); - return absl::OkStatus(); + return cuda::ToStatus(err_code, "Failed to update CUDA graph"); } -/* static */ absl::StatusOr -GpuDriver::GraphNodeGetType(CUgraphNode node) { +absl::StatusOr GpuDriver::GraphNodeGetType( + CUgraphNode node) { CUgraphNodeType cu_node_type; memset(&cu_node_type, 0, sizeof(cu_node_type)); - RETURN_IF_CUDA_RES_ERROR(cuGraphNodeGetType(node, &cu_node_type), - "Failed to get CUDA graph node type"); + TF_RETURN_IF_ERROR(cuda::ToStatus(cuGraphNodeGetType(node, &cu_node_type), + "Failed to get CUDA graph node type")); switch (cu_node_type) { case CU_GRAPH_NODE_TYPE_KERNEL: @@ -687,33 +632,32 @@ GpuDriver::GraphNodeGetDependencies(GpuGraphNodeHandle node) { std::vector dependencies; size_t num_dependencies = 0; - RETURN_IF_CUDA_RES_ERROR( + TF_RETURN_IF_ERROR(cuda::ToStatus( cuGraphNodeGetDependencies(node, nullptr, &num_dependencies), - "Failed to get CUDA graph node depedencies size"); + "Failed to get CUDA graph node depedencies size")); dependencies.resize(num_dependencies, nullptr); - RETURN_IF_CUDA_RES_ERROR( + TF_RETURN_IF_ERROR(cuda::ToStatus( cuGraphNodeGetDependencies(node, dependencies.data(), &num_dependencies), - "Failed to get CUDA graph node depedencies"); + "Failed to get CUDA graph node depedencies")); return dependencies; } -/* static */ absl::Status GpuDriver::DestroyGraphExec(CUgraphExec exec) { +absl::Status GpuDriver::DestroyGraphExec(CUgraphExec exec) { VLOG(2) << "Destroying CUDA executable graph " << exec; - RETURN_IF_CUDA_RES_ERROR(cuGraphExecDestroy(exec), - "Failed to destroy CUDA executable graph"); - return absl::OkStatus(); + return cuda::ToStatus(cuGraphExecDestroy(exec), + "Failed to destroy CUDA executable graph"); } -/* static */ absl::StatusOr GpuDriver::GraphDebugDotPrint( +absl::StatusOr GpuDriver::GraphDebugDotPrint( CUgraph graph, const char* path, bool return_printed_graph) { #if CUDA_VERSION >= 12000 VLOG(2) << "Print CUDA graph " << graph << " debug dot file to " << path; int flags = CU_GRAPH_DEBUG_DOT_FLAGS_VERBOSE; - RETURN_IF_CUDA_RES_ERROR(cuGraphDebugDotPrint(graph, path, flags), - "Failed to print gpu graph debug file"); + TF_RETURN_IF_ERROR(cuda::ToStatus(cuGraphDebugDotPrint(graph, path, flags), + "Failed to print gpu graph debug file")); if (return_printed_graph) { std::string data; @@ -728,25 +672,23 @@ GpuDriver::GraphNodeGetDependencies(GpuGraphNodeHandle node) { return std::string(path); } -/* static */ absl::Status GpuDriver::DeviceGraphMemTrim(CUdevice device) { +absl::Status GpuDriver::DeviceGraphMemTrim(CUdevice device) { VLOG(2) << "Trim CUDA device graph memory " << device; - RETURN_IF_CUDA_RES_ERROR(cuDeviceGraphMemTrim(device), - "Failed to trim device graph memory"); - return absl::OkStatus(); + return cuda::ToStatus(cuDeviceGraphMemTrim(device), + "Failed to trim device graph memory"); } -/* static */ absl::StatusOr GpuDriver::StreamIsCapturing( - CUstream stream) { +absl::StatusOr GpuDriver::StreamIsCapturing(CUstream stream) { VLOG(2) << "Checking if stream " << stream << " is capturing"; CUstreamCaptureStatus status; - RETURN_IF_CUDA_RES_ERROR(cuStreamIsCapturing(stream, &status), - "Failed to check stream capturing status"); + TF_RETURN_IF_ERROR(cuda::ToStatus(cuStreamIsCapturing(stream, &status), + "Failed to check stream capturing status")); return status == CU_STREAM_CAPTURE_STATUS_ACTIVE; } -/* static */ absl::Status GpuDriver::GraphConditionalHandleCreate( +absl::Status GpuDriver::GraphConditionalHandleCreate( GpuGraphConditionalHandle* handle, CUgraph graph, GpuContext* context, unsigned int default_launch_value, unsigned int flags) { VLOG(2) << "Create conditional handle for a graph " << graph @@ -755,7 +697,7 @@ GpuDriver::GraphNodeGetDependencies(GpuGraphNodeHandle node) { << "; flags: " << flags; #if CUDA_VERSION >= 12030 - RETURN_IF_CUDA_RES_ERROR( + return cuda::ToStatus( cuGraphConditionalHandleCreate(handle, graph, context->context(), default_launch_value, flags), "Failed to create conditional handle for a CUDA graph"); @@ -763,7 +705,6 @@ GpuDriver::GraphNodeGetDependencies(GpuGraphNodeHandle node) { return absl::UnimplementedError( "CUDA graph conditional nodes are not implemented"); #endif // CUDA_VERSION >= 12030 - return absl::OkStatus(); } static std::string ConditionalTypeToString( @@ -776,10 +717,9 @@ static std::string ConditionalTypeToString( } } -/* static */ absl::StatusOr -GpuDriver::GraphAddNode(CUgraphNode* node, CUgraph graph, - absl::Span deps, - const GpuGraphNodeParams& params) { +absl::StatusOr GpuDriver::GraphAddNode( + CUgraphNode* node, CUgraph graph, absl::Span deps, + const GpuGraphNodeParams& params) { #if CUDA_VERSION >= 12030 // Add conditional node to a graph. if (auto* conditional = std::get_if(¶ms)) { @@ -804,9 +744,9 @@ GpuDriver::GraphAddNode(CUgraphNode* node, CUgraph graph, break; } - RETURN_IF_CUDA_RES_ERROR( + TF_RETURN_IF_ERROR(cuda::ToStatus( cuGraphAddNode(node, graph, deps.data(), deps.size(), &cu_params), - "Failed to add conditional node to a CUDA graph"); + "Failed to add conditional node to a CUDA graph")); GpuGraphConditionalNodeParams::Result result; result.graph = cu_params.conditional.phGraph_out[0]; @@ -819,18 +759,16 @@ GpuDriver::GraphAddNode(CUgraphNode* node, CUgraph graph, return absl::UnimplementedError("unsupported node type"); } -/* static */ absl::Status GpuDriver::GraphAddEmptyNode( - CUgraphNode* node, CUgraph graph, absl::Span deps) { +absl::Status GpuDriver::GraphAddEmptyNode(CUgraphNode* node, CUgraph graph, + absl::Span deps) { VLOG(2) << "Add empty node to a graph " << graph << "; deps: " << deps.size(); - RETURN_IF_CUDA_RES_ERROR( + return cuda::ToStatus( cuGraphAddEmptyNode(node, graph, deps.data(), deps.size()), "Failed to add empty node to a CUDA graph"); - - return absl::OkStatus(); } -/* static */ absl::Status GpuDriver::GraphAddKernelNode( +absl::Status GpuDriver::GraphAddKernelNode( CUgraphNode* node, CUgraph graph, absl::Span deps, absl::string_view kernel_name, CUfunction function, unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z, unsigned int block_dim_x, @@ -861,18 +799,16 @@ GpuDriver::GraphAddNode(CUgraphNode* node, CUgraph graph, // should be moved one level up to se::Kernel level, and done just once (or // updated once we get a new larger shared memory request). if (shared_mem_bytes != 0) { - RETURN_IF_CUDA_RES_ERROR( + TF_RETURN_IF_ERROR(cuda::ToStatus( cuFuncSetAttribute(function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_mem_bytes), - "Failed to set shared memory size"); + "Failed to set shared memory size")); } - RETURN_IF_CUDA_RES_ERROR( + return cuda::ToStatus( cuGraphAddKernelNode(node, graph, deps.data(), deps.size(), ¶ms), "Failed to add kernel node to a CUDA graph"); - - return absl::OkStatus(); } /*static*/ absl::Status GpuDriver::GraphExecKernelNodeSetParams( @@ -905,17 +841,15 @@ GpuDriver::GraphAddNode(CUgraphNode* node, CUgraph graph, // should be moved one level up to se::Kernel level, and done just once (or // updated once we get a new larger shared memory request). if (shared_mem_bytes != 0) { - RETURN_IF_CUDA_RES_ERROR( + TF_RETURN_IF_ERROR(cuda::ToStatus( cuFuncSetAttribute(function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_mem_bytes), - "Failed to set shared memory size"); + "Failed to set shared memory size")); } - RETURN_IF_CUDA_RES_ERROR(cuGraphExecKernelNodeSetParams(exec, node, ¶ms), - "Failed to set CUDA graph kernel node params"); - - return absl::OkStatus(); + return cuda::ToStatus(cuGraphExecKernelNodeSetParams(exec, node, ¶ms), + "Failed to set CUDA graph kernel node params"); } static CUmemAccess_flags ToCudaMemAccessFlags( @@ -995,9 +929,9 @@ static CUmemAllocationType ToCudaAllocationType( params.accessDescs = &mem_desc; params.poolProps = mem_pool_props; - RETURN_IF_CUDA_RES_ERROR( + TF_RETURN_IF_ERROR(cuda::ToStatus( cuGraphAddMemAllocNode(node, graph, deps.data(), deps.size(), ¶ms), - "Failed to add memory allocation node to a CUDA graph"); + "Failed to add memory allocation node to a CUDA graph")); VLOG(2) << "Add MemAllocNode to a graph " << graph << " size " << size << " address " << reinterpret_cast(params.dptr); @@ -1009,21 +943,21 @@ static CUmemAllocationType ToCudaAllocationType( /*static*/ absl::StatusOr> GpuDriver::GraphGetMemAllocNodeParams(CUgraphNode node) { CUDA_MEM_ALLOC_NODE_PARAMS params; - RETURN_IF_CUDA_RES_ERROR(cuGraphMemAllocNodeGetParams(node, ¶ms), - "Failed to get memory allocation node parameter"); + TF_RETURN_IF_ERROR( + cuda::ToStatus(cuGraphMemAllocNodeGetParams(node, ¶ms), + "Failed to get memory allocation node parameter")); return std::pair{params.dptr, params.bytesize}; } /*static*/ absl::Status GpuDriver::GraphAddMemFreeNode( CUgraphNode* node, CUgraph graph, absl::Span deps, CUdeviceptr gpu_dst) { - RETURN_IF_CUDA_RES_ERROR( + return cuda::ToStatus( cuGraphAddMemFreeNode(node, graph, deps.data(), deps.size(), gpu_dst), "Failed to add memory free node to a CUDA graph"); - return absl::OkStatus(); } -/* static */ absl::Status GpuDriver::GraphAddMemcpyD2DNode( +absl::Status GpuDriver::GraphAddMemcpyD2DNode( GpuContext* context, CUgraphNode* node, CUgraph graph, absl::Span deps, CUdeviceptr gpu_dst, CUdeviceptr gpu_src, uint64_t size) { @@ -1043,15 +977,13 @@ GpuDriver::GraphGetMemAllocNodeParams(CUgraphNode node) { params.Height = 1; params.Depth = 1; - RETURN_IF_CUDA_RES_ERROR( + return cuda::ToStatus( cuGraphAddMemcpyNode(node, graph, deps.data(), deps.size(), ¶ms, context->context()), "Failed to add memcpy d2d node to a CUDA graph"); - - return absl::OkStatus(); } -/* static */ absl::Status GpuDriver::GraphExecMemcpyD2DNodeSetParams( +absl::Status GpuDriver::GraphExecMemcpyD2DNodeSetParams( GpuContext* context, GpuGraphExecHandle exec, GpuGraphNodeHandle node, GpuDevicePtr gpu_dst, GpuDevicePtr gpu_src, uint64_t size) { VLOG(2) << "Set memcpy d2d node params " << node << " in graph executable " @@ -1070,11 +1002,9 @@ GpuDriver::GraphGetMemAllocNodeParams(CUgraphNode node) { params.Height = 1; params.Depth = 1; - RETURN_IF_CUDA_RES_ERROR( + return cuda::ToStatus( cuGraphExecMemcpyNodeSetParams(exec, node, ¶ms, context->context()), "Failed to set memcpy d2d node params"); - - return absl::OkStatus(); } namespace { @@ -1109,7 +1039,7 @@ struct BitPatternToValue { } // namespace -/* static */ absl::Status GpuDriver::GraphAddMemsetNode( +absl::Status GpuDriver::GraphAddMemsetNode( GpuContext* context, CUgraphNode* node, GpuGraphHandle graph, absl::Span deps, CUdeviceptr dst, std::variant bit_pattern, @@ -1132,15 +1062,13 @@ struct BitPatternToValue { params.value = value; params.width = num_elements; - RETURN_IF_CUDA_RES_ERROR( + return cuda::ToStatus( cuGraphAddMemsetNode(node, graph, deps.data(), deps.size(), ¶ms, context->context()), "Failed to add memset node to a CUDA graph"); - - return absl::OkStatus(); } -/* static */ absl::Status GpuDriver::GraphExecMemsetNodeSetParams( +absl::Status GpuDriver::GraphExecMemsetNodeSetParams( GpuContext* context, CUgraphExec exec, CUgraphNode node, CUdeviceptr dst, std::variant bit_pattern, uint64_t num_elements) { @@ -1162,24 +1090,20 @@ struct BitPatternToValue { params.value = value; params.width = num_elements; - RETURN_IF_CUDA_RES_ERROR( + return cuda::ToStatus( cuGraphExecMemsetNodeSetParams(exec, node, ¶ms, context->context()), "Failed to set memset node params"); - - return absl::OkStatus(); } -/* static */ absl::Status GpuDriver::GraphAddChildNode( - CUgraphNode* node, CUgraph graph, absl::Span deps, - CUgraph child) { +absl::Status GpuDriver::GraphAddChildNode(CUgraphNode* node, CUgraph graph, + absl::Span deps, + CUgraph child) { VLOG(2) << "Create a new node by cloning the child graph " << child << " and add it to " << graph << "; deps: " << deps.size(); - RETURN_IF_CUDA_RES_ERROR( + return cuda::ToStatus( cuGraphAddChildGraphNode(node, graph, deps.data(), deps.size(), child), "Failed to create a child graph node and add it to a CUDA graph"); - - return absl::OkStatus(); } /*static*/ absl::Status GpuDriver::GraphExecChildNodeSetParams(CUgraphExec exec, @@ -1188,14 +1112,11 @@ struct BitPatternToValue { VLOG(2) << "Set child node params " << node << " in graph executable " << exec << "to params contained in " << child; - RETURN_IF_CUDA_RES_ERROR( - cuGraphExecChildGraphNodeSetParams(exec, node, child), - "Failed to set CUDA graph child node params"); - - return absl::OkStatus(); + return cuda::ToStatus(cuGraphExecChildGraphNodeSetParams(exec, node, child), + "Failed to set CUDA graph child node params"); } -/* static */ absl::Status GpuDriver::LaunchKernel( +absl::Status GpuDriver::LaunchKernel( GpuContext* context, absl::string_view kernel_name, CUfunction function, unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z, unsigned int block_dim_x, unsigned int block_dim_y, @@ -1212,26 +1133,25 @@ struct BitPatternToValue { // should be moved one level up to se::Kernel level, and done just once (or // updated once we get a new larger shared memory request). if (shared_mem_bytes != 0) { - RETURN_IF_CUDA_RES_ERROR( + TF_RETURN_IF_ERROR(cuda::ToStatus( cuFuncSetAttribute(function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_mem_bytes), - "Failed to set shared memory size"); + "Failed to set shared memory size")); } - RETURN_IF_CUDA_RES_ERROR( + return cuda::ToStatus( cuLaunchKernel(function, grid_dim_x, grid_dim_y, grid_dim_z, block_dim_x, block_dim_y, block_dim_z, shared_mem_bytes, stream, kernel_params, extra), - "Failed to launch CUDA kernel: ", kernel_name, - "; block dims: ", block_dim_x, "x", block_dim_y, "x", block_dim_z, - "; grid dims: ", grid_dim_x, "x", grid_dim_y, "x", grid_dim_z, - "; shared memory size: ", shared_mem_bytes); - - return absl::OkStatus(); + absl::StrCat("Failed to launch CUDA kernel: ", kernel_name, + "; block dims: ", block_dim_x, "x", block_dim_y, "x", + block_dim_z, "; grid dims: ", grid_dim_x, "x", grid_dim_y, + "x", grid_dim_z, + "; shared memory size: ", shared_mem_bytes)); } -/* static */ absl::Status GpuDriver::LaunchKernel( +absl::Status GpuDriver::LaunchKernel( GpuContext* context, absl::string_view kernel_name, GpuFunctionHandle function, unsigned int cluster_dim_x, unsigned int cluster_dim_y, unsigned int cluster_dim_z, @@ -1251,11 +1171,11 @@ struct BitPatternToValue { // should be moved one level up to se::Kernel level, and done just once (or // updated once we get a new larger shared memory request). if (shared_mem_bytes != 0) { - RETURN_IF_CUDA_RES_ERROR( + TF_RETURN_IF_ERROR(cuda::ToStatus( cuFuncSetAttribute(function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_mem_bytes), - "Failed to set shared memory size"); + "Failed to set shared memory size")); } CUlaunchConfig launch_config; @@ -1279,30 +1199,26 @@ struct BitPatternToValue { launch_config.attrs = &cluster_dims; launch_config.numAttrs = 1; - RETURN_IF_CUDA_RES_ERROR( + return cuda::ToStatus( cuLaunchKernelEx(&launch_config, function, kernel_params, extra), - "Failed to launch CUDA kernel: ", kernel_name, - "; cluster dims: ", cluster_dim_x, "x", cluster_dim_y, "x", cluster_dim_z, - "; block dims: ", block_dim_x, "x", block_dim_y, "x", block_dim_z, - "; grid dims: ", grid_dim_x, "x", grid_dim_y, "x", grid_dim_z, - "; shared memory size: ", shared_mem_bytes); - - return absl::OkStatus(); + absl::StrCat("Failed to launch CUDA kernel: ", kernel_name, + "; cluster dims: ", cluster_dim_x, "x", cluster_dim_y, "x", + cluster_dim_z, "; block dims: ", block_dim_x, "x", + block_dim_y, "x", block_dim_z, "; grid dims: ", grid_dim_x, + "x", grid_dim_y, "x", grid_dim_z, + "; shared memory size: ", shared_mem_bytes)); } -/* static */ absl::Status GpuDriver::LoadCubin(GpuContext* context, - const char* cubin_bytes, - CUmodule* module) { +absl::Status GpuDriver::LoadCubin(GpuContext* context, const char* cubin_bytes, + CUmodule* module) { ScopedActivateContext activation(context); - RETURN_IF_CUDA_RES_ERROR( + return cuda::ToStatus( cuModuleLoadFatBinary(module, cubin_bytes), "Failed to load in-memory CUBIN (compiled for a different GPU?)."); - return absl::OkStatus(); } -/* static */ absl::Status GpuDriver::LoadPtx(GpuContext* context, - const char* ptx_contents, - CUmodule* module) { +absl::Status GpuDriver::LoadPtx(GpuContext* context, const char* ptx_contents, + CUmodule* module) { absl::Notification notification; absl::Status ret = absl::OkStatus(); GetDriverExecutor()->Schedule([context, ptx_contents, module, &ret, @@ -1329,13 +1245,13 @@ struct BitPatternToValue { absl::bit_cast(uintptr_t(log_verbose))}; CHECK(TF_ARRAYSIZE(options) == TF_ARRAYSIZE(option_values)); - CUresult res; + absl::Status status; { // TODO(leary) Need to see if NVIDIA can expunge the leakiness in their // module loading: see http://b/13248943 absl::LeakCheckDisabler disabler; - res = cuModuleLoadDataEx(module, ptx_data, TF_ARRAYSIZE(options), options, - option_values); + status = cuda::ToStatus(cuModuleLoadDataEx( + module, ptx_data, TF_ARRAYSIZE(options), options, option_values)); } // The PTX JIT mutates the values in the option values array to reflect the @@ -1346,8 +1262,8 @@ struct BitPatternToValue { CHECK_LE(error_log_buffer_bytes, kLogBufferBytesLimit); CHECK_LE(info_log_buffer_bytes, kLogBufferBytesLimit); - if (res != CUDA_SUCCESS) { - LOG(ERROR) << "failed to load PTX text as a module: " << ToString(res); + if (!status.ok()) { + LOG(ERROR) << "failed to load PTX text as a module: " << status; // As a precaution for null termination of the API-provided value, ensure // that at least the last byte is null. error_log_buffer[error_log_buffer_bytes ? error_log_buffer_bytes - 1 @@ -1359,10 +1275,9 @@ struct BitPatternToValue { ret = absl::ResourceExhaustedError( absl::StrFormat("Failed to load PTX text as a module (register " "allocation failed): %s", - ToString(res))); + status.ToString())); } else { - ret = absl::InternalError(absl::StrFormat( - "Failed to load PTX text as a module: %s", ToString(res))); + ret = status; } notification.Notify(); return; @@ -1380,67 +1295,64 @@ struct BitPatternToValue { return ret; } -/* static */ absl::Status GpuDriver::LoadHsaco(GpuContext* context, - const char* hsaco_contents, - CUmodule* module) { +absl::Status GpuDriver::LoadHsaco(GpuContext* context, + const char* hsaco_contents, + CUmodule* module) { return absl::InternalError( "Feature not supported on CUDA platform (LoadHsaco)"); } -/* static */ absl::Status GpuDriver::SynchronousMemsetUint8( - GpuContext* context, CUdeviceptr location, uint8_t value, size_t size) { +absl::Status GpuDriver::SynchronousMemsetUint8(GpuContext* context, + CUdeviceptr location, + uint8_t value, size_t size) { ScopedActivateContext activation(context); - RETURN_IF_CUDA_RES_ERROR(cuMemsetD8(location, value, size), - "Failed to memset memory"); - return absl::OkStatus(); + return cuda::ToStatus(cuMemsetD8(location, value, size), + "Failed to memset memory"); } -/* static */ absl::Status GpuDriver::SynchronousMemsetUint32( - GpuContext* context, CUdeviceptr location, uint32_t value, - size_t uint32_count) { +absl::Status GpuDriver::SynchronousMemsetUint32(GpuContext* context, + CUdeviceptr location, + uint32_t value, + size_t uint32_count) { ScopedActivateContext activation(context); - RETURN_IF_CUDA_RES_ERROR(cuMemsetD32(location, value, uint32_count), - "Failed to memset memory"); - return absl::OkStatus(); + return cuda::ToStatus(cuMemsetD32(location, value, uint32_count), + "Failed to memset memory"); } -/* static */ absl::Status GpuDriver::AsynchronousMemsetUint8( - GpuContext* context, CUdeviceptr location, uint8_t value, - size_t uint32_count, CUstream stream) { +absl::Status GpuDriver::AsynchronousMemsetUint8(GpuContext* context, + CUdeviceptr location, + uint8_t value, + size_t uint32_count, + CUstream stream) { ScopedActivateContext activation(context); - RETURN_IF_CUDA_RES_ERROR( - cuMemsetD8Async(location, value, uint32_count, stream), - "Failed to enqueue async memset operation"); - return absl::OkStatus(); + return cuda::ToStatus(cuMemsetD8Async(location, value, uint32_count, stream), + "Failed to enqueue async memset operation"); } -/* static */ absl::Status GpuDriver::AsynchronousMemsetUint32( - GpuContext* context, CUdeviceptr location, uint32_t value, - size_t uint32_count, CUstream stream) { +absl::Status GpuDriver::AsynchronousMemsetUint32(GpuContext* context, + CUdeviceptr location, + uint32_t value, + size_t uint32_count, + CUstream stream) { ScopedActivateContext activation(context); - RETURN_IF_CUDA_RES_ERROR( - cuMemsetD32Async(location, value, uint32_count, stream), - "Failed to enqueue async memset operation"); - return absl::OkStatus(); + return cuda::ToStatus(cuMemsetD32Async(location, value, uint32_count, stream), + "Failed to enqueue async memset operation"); } -/* static */ bool GpuDriver::AddStreamCallback(GpuContext* context, - CUstream stream, - StreamCallback callback, - void* data) { +bool GpuDriver::AddStreamCallback(GpuContext* context, CUstream stream, + StreamCallback callback, void* data) { // Note: flags param is required to be zero according to CUDA 6.0. - CUresult res = cuLaunchHostFunc(stream, callback, data); - if (res != CUDA_SUCCESS) { - LOG(ERROR) << "unable to add host callback: " << ToString(res); + auto status = cuda::ToStatus(cuLaunchHostFunc(stream, callback, data)); + if (!status.ok()) { + LOG(ERROR) << "unable to add host callback: " << status; return false; } return true; } -/* static */ absl::Status GpuDriver::GetModuleFunction(GpuContext* context, - CUmodule module, - const char* kernel_name, - CUfunction* function) { +absl::Status GpuDriver::GetModuleFunction(GpuContext* context, CUmodule module, + const char* kernel_name, + CUfunction* function) { ScopedActivateContext activated{context}; CHECK(module != nullptr && kernel_name != nullptr); cudaError_t cuda_error = cudaPeekAtLastError(); @@ -1450,63 +1362,57 @@ struct BitPatternToValue { cuda_error, "): ", cudaGetErrorName(cuda_error), " : ", cudaGetErrorString(cuda_error))); } - RETURN_IF_CUDA_RES_ERROR(cuModuleGetFunction(function, module, kernel_name), - "Failed to get module function"); - return absl::OkStatus(); + return cuda::ToStatus(cuModuleGetFunction(function, module, kernel_name), + "Failed to get module function"); } -/* static */ absl::Status GpuDriver::GetModuleSymbol(GpuContext* context, - CUmodule module, - const char* symbol_name, - CUdeviceptr* dptr, - size_t* bytes) { +absl::Status GpuDriver::GetModuleSymbol(GpuContext* context, CUmodule module, + const char* symbol_name, + CUdeviceptr* dptr, size_t* bytes) { ScopedActivateContext activated{context}; CHECK(module != nullptr && symbol_name != nullptr && (dptr != nullptr || bytes != nullptr)); - RETURN_IF_CUDA_RES_ERROR( + return cuda::ToStatus( cuModuleGetGlobal(dptr, bytes, module, symbol_name), absl::StrCat("Failed to get symbol '", symbol_name, "'")); - return absl::OkStatus(); } -/* static */ void GpuDriver::UnloadModule(GpuContext* context, - CUmodule module) { +void GpuDriver::UnloadModule(GpuContext* context, CUmodule module) { ScopedActivateContext activated{context}; - CUresult res = cuModuleUnload(module); - if (res != CUDA_SUCCESS) { + auto status = cuda::ToStatus(cuModuleUnload(module)); + if (!status.ok()) { LOG(ERROR) << "failed to unload module " << module - << "; leaking: " << ToString(res); + << "; leaking: " << status; } } -/* static */ absl::StatusOr GpuDriver::DeviceFromContext( - GpuContext* context) { +absl::StatusOr GpuDriver::DeviceFromContext(GpuContext* context) { ScopedActivateContext activated{context}; CUdevice device = -1; - CUresult result = cuCtxGetDevice(&device); - if (result == CUDA_SUCCESS) { + auto status = cuda::ToStatus(cuCtxGetDevice(&device)); + if (status.ok()) { return device; } - return absl::InternalError( - absl::StrCat("failed to get device for context: ", ToString(result))); + return status; } -/* static */ bool GpuDriver::CreateStream(GpuContext* context, CUstream* stream, - int priority) { +bool GpuDriver::CreateStream(GpuContext* context, CUstream* stream, + int priority) { ScopedActivateContext activated{context}; - CUresult res; + absl::Status status; // If the priority is 0, then use the previous api to create the stream with // the default priority for backward compatibility. Probably there is no // difference in using the new api call but leaving it as is for now. if (priority == 0) { - res = cuStreamCreate(stream, CU_STREAM_NON_BLOCKING); + status = cuda::ToStatus(cuStreamCreate(stream, CU_STREAM_NON_BLOCKING)); } else { - res = cuStreamCreateWithPriority(stream, CU_STREAM_NON_BLOCKING, priority); + status = cuda::ToStatus( + cuStreamCreateWithPriority(stream, CU_STREAM_NON_BLOCKING, priority)); } - if (res != CUDA_SUCCESS) { + if (!status.ok()) { LOG(ERROR) << "could not allocate CUDA stream for context " - << context->context() << ": " << ToString(res); + << context->context() << ": " << status; return false; } @@ -1515,17 +1421,16 @@ struct BitPatternToValue { return true; } -/* static */ void GpuDriver::DestroyStream(GpuContext* context, - CUstream* stream) { +void GpuDriver::DestroyStream(GpuContext* context, CUstream* stream) { if (*stream == nullptr) { return; } ScopedActivateContext activated{context}; - CUresult res = cuStreamDestroy(*stream); - if (res != CUDA_SUCCESS) { + auto status = cuda::ToStatus(cuStreamDestroy(*stream)); + if (!status.ok()) { LOG(ERROR) << "failed to destroy CUDA stream for context " - << context->context() << ": " << ToString(res); + << context->context() << ": " << status; } else { VLOG(2) << "successfully destroyed stream " << *stream << " for context " << context->context(); @@ -1533,21 +1438,20 @@ struct BitPatternToValue { } } -/* static */ void* GpuDriver::DeviceAllocate(GpuContext* context, - uint64_t bytes) { +void* GpuDriver::DeviceAllocate(GpuContext* context, uint64_t bytes) { if (bytes == 0) { return nullptr; } ScopedActivateContext activated{context}; CUdeviceptr result = 0; - CUresult res = cuMemAlloc(&result, bytes); - if (res != CUDA_SUCCESS) { + auto status = cuda::ToStatus(cuMemAlloc(&result, bytes)); + if (!status.ok()) { // LOG(INFO) because this isn't always important to users (e.g. BFCAllocator // implements a retry if the first allocation fails). LOG(INFO) << "failed to allocate " << tsl::strings::HumanReadableNumBytes(bytes) << " (" << bytes - << " bytes) from device: " << ToString(res); + << " bytes) from device: " << status; return nullptr; } void* ptr = reinterpret_cast(result); @@ -1556,29 +1460,28 @@ struct BitPatternToValue { return ptr; } -/* static */ void GpuDriver::DeviceDeallocate(GpuContext* context, - void* location) { +void GpuDriver::DeviceDeallocate(GpuContext* context, void* location) { ScopedActivateContext activation(context); CUdeviceptr pointer = absl::bit_cast(location); - CUresult res = cuMemFree(pointer); - if (res != CUDA_SUCCESS) { + auto status = cuda::ToStatus(cuMemFree(pointer)); + if (!status.ok()) { LOG(ERROR) << "failed to free device memory at " << location - << "; result: " << ToString(res); + << "; result: " << status; } else { VLOG(2) << "deallocated " << location << " for context " << context->context(); } } -/* static */ void* GpuDriver::UnifiedMemoryAllocate(GpuContext* context, - uint64_t bytes) { +void* GpuDriver::UnifiedMemoryAllocate(GpuContext* context, uint64_t bytes) { ScopedActivateContext activation(context); CUdeviceptr result = 0; // "Portable" memory is visible to all CUDA contexts. Safe for our use model. - CUresult res = cuMemAllocManaged(&result, bytes, CU_MEM_ATTACH_GLOBAL); - if (res != CUDA_SUCCESS) { + auto status = + cuda::ToStatus(cuMemAllocManaged(&result, bytes, CU_MEM_ATTACH_GLOBAL)); + if (!status.ok()) { LOG(ERROR) << "failed to alloc " << bytes - << " bytes unified memory; result: " << ToString(res); + << " bytes unified memory; result: " << status; return nullptr; } void* ptr = reinterpret_cast(result); @@ -1587,78 +1490,74 @@ struct BitPatternToValue { return ptr; } -/* static */ void GpuDriver::UnifiedMemoryDeallocate(GpuContext* context, - void* location) { +void GpuDriver::UnifiedMemoryDeallocate(GpuContext* context, void* location) { ScopedActivateContext activation(context); CUdeviceptr pointer = absl::bit_cast(location); - CUresult res = cuMemFree(pointer); - if (res != CUDA_SUCCESS) { + auto status = cuda::ToStatus(cuMemFree(pointer)); + if (!status.ok()) { LOG(ERROR) << "failed to free unified memory at " << location - << "; result: " << ToString(res); + << "; result: " << status; } else { VLOG(2) << "deallocated unified memory at " << location << " for context " << context->context(); } } -/* static */ void* GpuDriver::HostAllocate(GpuContext* context, - uint64_t bytes) { +void* GpuDriver::HostAllocate(GpuContext* context, uint64_t bytes) { ScopedActivateContext activation(context); void* host_mem = nullptr; // "Portable" memory is visible to all CUDA contexts. Safe for our use model. - CUresult res = cuMemHostAlloc(&host_mem, bytes, CU_MEMHOSTALLOC_PORTABLE); - if (res != CUDA_SUCCESS) { - LOG(ERROR) << "failed to alloc " << bytes - << " bytes on host: " << ToString(res); + auto status = cuda::ToStatus( + cuMemHostAlloc(&host_mem, bytes, CU_MEMHOSTALLOC_PORTABLE)); + if (!status.ok()) { + LOG(ERROR) << "failed to alloc " << bytes << " bytes on host: " << status; } return host_mem; } -/* static */ void GpuDriver::HostDeallocate(GpuContext* context, - void* location) { +void GpuDriver::HostDeallocate(GpuContext* context, void* location) { ScopedActivateContext activation(context); - CUresult res = cuMemFreeHost(location); - if (res != CUDA_SUCCESS) { + auto status = cuda::ToStatus(cuMemFreeHost(location)); + if (!status.ok()) { LOG(ERROR) << "error deallocating host memory at " << location << ": " - << ToString(res); + << status; } } -/* static */ bool GpuDriver::HostRegister(GpuContext* context, void* location, - uint64_t bytes) { +bool GpuDriver::HostRegister(GpuContext* context, void* location, + uint64_t bytes) { ScopedActivateContext activation(context); // "Portable" memory is visible to all CUDA contexts. Safe for our use model. - CUresult res = - cuMemHostRegister(location, bytes, CU_MEMHOSTREGISTER_PORTABLE); - if (res != CUDA_SUCCESS) { + auto status = cuda::ToStatus( + cuMemHostRegister(location, bytes, CU_MEMHOSTREGISTER_PORTABLE)); + if (!status.ok()) { LOG(ERROR) << "error registering host memory at " << location << ": " - << ToString(res); + << status; return false; } return true; } -/* static */ bool GpuDriver::HostUnregister(GpuContext* context, - void* location) { +bool GpuDriver::HostUnregister(GpuContext* context, void* location) { ScopedActivateContext activation(context); - CUresult res = cuMemHostUnregister(location); - if (res != CUDA_SUCCESS) { + auto status = cuda::ToStatus(cuMemHostUnregister(location)); + if (!status.ok()) { LOG(ERROR) << "error unregistering host memory at " << location << ": " - << ToString(res); + << status; return false; } return true; } -/* static */ int GpuDriver::GetGpuStreamPriority( +int GpuDriver::GetGpuStreamPriority( GpuContext* context, stream_executor::StreamPriority stream_priority) { ScopedActivateContext activation(context); if (stream_priority == stream_executor::StreamPriority::Default) { return 0; } int lowest, highest; - CUresult res = cuCtxGetStreamPriorityRange(&lowest, &highest); - if (res != CUDA_SUCCESS) { + auto status = cuda::ToStatus(cuCtxGetStreamPriorityRange(&lowest, &highest)); + if (!status.ok()) { LOG(ERROR) << "Could not query stream priority range. Returning default priority."; return 0; @@ -1667,31 +1566,28 @@ struct BitPatternToValue { : lowest; } -/* static */ absl::StatusOr -GpuDriver::ReserveVirtualMemory(GpuContext* context, uint64_t bytes) { +absl::StatusOr GpuDriver::ReserveVirtualMemory( + GpuContext* context, uint64_t bytes) { ScopedActivateContext activation(context); CUdeviceptr base; - CUresult res = cuMemAddressReserve(&base, bytes, /*alignment=*/0, - /*addr=*/0, /*flags=*/0); - if (res != CUDA_SUCCESS) { - return absl::InternalError( - absl::StrFormat("error reserving %d bytes of virtual GPU memory: %s", - bytes, ToString(res))); - } - return {{base, bytes}}; + return cuda::ToStatus( + cuMemAddressReserve(&base, bytes, /*alignment=*/0, + /*addr=*/0, /*flags=*/0), + absl::StrFormat("error reserving %d bytes of virtual GPU memory", bytes)); } -/* static */ void GpuDriver::FreeVirtualMemory( - GpuContext* context, GpuDriver::VmemSpan reservation) { +void GpuDriver::FreeVirtualMemory(GpuContext* context, + GpuDriver::VmemSpan reservation) { ScopedActivateContext activation(context); - CUresult res = cuMemAddressFree(reservation.base, reservation.size_bytes); - if (res != CUDA_SUCCESS) { + auto status = cuda::ToStatus( + cuMemAddressFree(reservation.base, reservation.size_bytes)); + if (!status.ok()) { LOG(ERROR) << "error freeing vmem reservation of size " << reservation.size_bytes << " at address " << reservation.base; } } -/* static */ absl::StatusOr GpuDriver::GetMinAllocationGranularity( +absl::StatusOr GpuDriver::GetMinAllocationGranularity( GpuDeviceHandle device) { CUmemAllocationProp props = {}; props.type = CU_MEM_ALLOCATION_TYPE_PINNED; @@ -1699,17 +1595,15 @@ GpuDriver::ReserveVirtualMemory(GpuContext* context, uint64_t bytes) { props.location.id = device; size_t granularity; - CUresult res = cuMemGetAllocationGranularity( - &granularity, &props, CU_MEM_ALLOC_GRANULARITY_MINIMUM); - if (res != CUDA_SUCCESS) { - return absl::InternalError(absl::StrCat( - "failed to get min allocation granularity: ", ToString(res))); - } + TF_RETURN_IF_ERROR(cuda::ToStatus( + cuMemGetAllocationGranularity(&granularity, &props, + CU_MEM_ALLOC_GRANULARITY_MINIMUM), + "failed to get min allocation granularity")); return granularity; } -/* static */ absl::StatusOr -GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { +absl::StatusOr GpuDriver::CreateMemoryHandle( + GpuContext* context, uint64_t bytes) { ScopedActivateContext activation(context); auto device = DeviceFromContext(context); if (!device.ok()) { @@ -1723,27 +1617,24 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { props.location.id = device.value(); CUmemGenericAllocationHandle mem_handle; - CUresult res = cuMemCreate(&mem_handle, bytes, &props, 0); - if (res != CUDA_SUCCESS) { - return absl::InternalError( - absl::StrFormat("failed to create memory allocation of size %d: %s", - bytes, ToString(res))); - } + TF_RETURN_IF_ERROR(cuda::ToStatus( + cuMemCreate(&mem_handle, bytes, &props, 0), + absl::StrFormat("failed to create memory allocation of size %d", bytes))); return GpuDriver::GenericMemoryHandle{mem_handle, bytes}; } -/* static */ void GpuDriver::ReleaseMemoryHandle( - GpuContext* context, GpuDriver::GenericMemoryHandle handle) { +void GpuDriver::ReleaseMemoryHandle(GpuContext* context, + GpuDriver::GenericMemoryHandle handle) { ScopedActivateContext activation(context); - CUresult res = cuMemRelease(handle.handle); - if (res != CUDA_SUCCESS) { + auto status = cuda::ToStatus(cuMemRelease(handle.handle)); + if (!status.ok()) { LOG(ERROR) << "Failed to release memory handle " << handle.handle - << " of size " << handle.bytes << ": " << ToString(res); + << " of size " << handle.bytes << ": " << status; } } -/* static */ absl::Status GpuDriver::MapMemory( +absl::Status GpuDriver::MapMemory( GpuContext* context, CUdeviceptr va, const GpuDriver::GenericMemoryHandle& handle, const std::vector& device_handles) { @@ -1755,12 +1646,8 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { } // NB: Zero is the only valid value for both flags and offset. - CUresult res = - cuMemMap(va, handle.bytes, /*offset=*/0, handle.handle, /*flags=*/0); - if (res != CUDA_SUCCESS) { - return absl::InternalError(absl::StrFormat( - "Failed to map %d bytes at %d: %s", handle.bytes, va, ToString(res))); - } + TF_RETURN_IF_ERROR(cuda::ToStatus( + cuMemMap(va, handle.bytes, /*offset=*/0, handle.handle, /*flags=*/0))); std::vector access_descriptors(device_handles.size()); for (int i = 0; i < access_descriptors.size(); ++i) { @@ -1769,91 +1656,85 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { access_descriptors[i].flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; } - res = cuMemSetAccess(va, handle.bytes, access_descriptors.data(), - access_descriptors.size()); - if (res != CUDA_SUCCESS) { + auto status = cuda::ToStatus(cuMemSetAccess( + va, handle.bytes, access_descriptors.data(), access_descriptors.size())); + if (!status.ok()) { // Unmap the memory that we failed to set access for. - if (cuMemUnmap(va, handle.bytes) != CUDA_SUCCESS) { + if (!cuda::ToStatus(cuMemUnmap(va, handle.bytes)).ok()) { LOG(ERROR) << "Failed to unmap memory in GpuDriver::MapMemory error path."; } - return absl::InternalError(absl::StrFormat( - "Failed to set read/write access on memory mapped at %d: %s", va, - ToString(res))); + return status; } return absl::OkStatus(); } -/* static */ void GpuDriver::UnmapMemory(GpuContext* context, CUdeviceptr va, - uint64_t bytes) { +void GpuDriver::UnmapMemory(GpuContext* context, CUdeviceptr va, + uint64_t bytes) { ScopedActivateContext activation(context); - CUresult res = cuMemUnmap(va, bytes); - if (res != CUDA_SUCCESS) { + auto status = cuda::ToStatus(cuMemUnmap(va, bytes)); + if (!status.ok()) { LOG(ERROR) << "Failed to unmap memory at " << va << " of size " << bytes - << ": " << ToString(res); + << ": " << status; } } -/* static */ absl::Status GpuDriver::DestroyEvent(GpuContext* context, - CUevent* event) { +absl::Status GpuDriver::DestroyEvent(GpuContext* context, CUevent* event) { if (*event == nullptr) { return absl::InvalidArgumentError("input event cannot be null"); } ScopedActivateContext activated{context}; - RETURN_IF_CUDA_RES_ERROR(cuEventDestroy(*event), - "Error destroying CUDA event"); - return absl::OkStatus(); + return cuda::ToStatus(cuEventDestroy(*event), "Error destroying CUDA event"); } -/* static */ absl::Status GpuDriver::RecordEvent(GpuContext* context, - CUevent event, - CUstream stream) { +absl::Status GpuDriver::RecordEvent(GpuContext* context, CUevent event, + CUstream stream) { ScopedActivateContext activated{context}; - RETURN_IF_CUDA_RES_ERROR(cuEventRecord(event, stream), - "Error recording CUDA event"); - return absl::OkStatus(); + return cuda::ToStatus(cuEventRecord(event, stream), + "Error recording CUDA event"); } -/* static */ bool GpuDriver::GetEventElapsedTime(GpuContext* context, - float* elapsed_milliseconds, - CUevent start, CUevent stop) { +bool GpuDriver::GetEventElapsedTime(GpuContext* context, + float* elapsed_milliseconds, CUevent start, + CUevent stop) { ScopedActivateContext activated{context}; // The stop event must have completed in order for cuEventElapsedTime to // work. - CUresult res = cuEventSynchronize(stop); - if (res != CUDA_SUCCESS) { - LOG(ERROR) << "failed to synchronize the stop event: " << ToString(res); + auto status = cuda::ToStatus(cuEventSynchronize(stop)); + if (!status.ok()) { + LOG(ERROR) << "failed to synchronize the stop event: " << status; return false; } - res = cuEventElapsedTime(elapsed_milliseconds, start, stop); - if (res != CUDA_SUCCESS) { - LOG(ERROR) << "failed to get elapsed time between events: " - << ToString(res); + status = + cuda::ToStatus(cuEventElapsedTime(elapsed_milliseconds, start, stop)); + if (!status.ok()) { + LOG(ERROR) << "failed to get elapsed time between events: " << status; return false; } return true; } -/* static */ bool GpuDriver::WaitStreamOnEvent(GpuContext* context, - CUstream stream, CUevent event) { +bool GpuDriver::WaitStreamOnEvent(GpuContext* context, CUstream stream, + CUevent event) { ScopedActivateContext activation(context); - CUresult res = cuStreamWaitEvent(stream, event, 0 /* = flags */); - if (res != CUDA_SUCCESS) { - LOG(ERROR) << "could not wait stream on event: " << ToString(res); + auto status = + cuda::ToStatus(cuStreamWaitEvent(stream, event, 0 /* = flags */)); + if (!status.ok()) { + LOG(ERROR) << "could not wait stream on event: " << status; return false; } return true; } -/* static */ bool GpuDriver::SynchronizeContext(GpuContext* context) { +bool GpuDriver::SynchronizeContext(GpuContext* context) { ScopedActivateContext activation(context); - CUresult res = cuCtxSynchronize(); - if (res != CUDA_SUCCESS) { - LOG(ERROR) << "could not synchronize on CUDA context: " << ToString(res) + auto status = cuda::ToStatus(cuCtxSynchronize()); + if (!status.ok()) { + LOG(ERROR) << "could not synchronize on CUDA context: " << status << " :: " << tsl::CurrentStackTrace(); return false; } @@ -1861,17 +1742,15 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { return true; } -/* static */ absl::Status GpuDriver::SynchronizeStream(GpuContext* context, - CUstream stream) { +absl::Status GpuDriver::SynchronizeStream(GpuContext* context, + CUstream stream) { ScopedActivateContext activated{context}; CHECK(stream != nullptr); - RETURN_IF_CUDA_RES_ERROR(cuStreamSynchronize(stream), - "Could not synchronize CUDA stream"); - return absl::OkStatus(); + return cuda::ToStatus(cuStreamSynchronize(stream), + "Could not synchronize CUDA stream"); } -/* static */ bool GpuDriver::IsStreamIdle(GpuContext* context, - CUstream stream) { +bool GpuDriver::IsStreamIdle(GpuContext* context, CUstream stream) { ScopedActivateContext activated{context}; CHECK(stream != nullptr); CUresult res = cuStreamQuery(stream); @@ -1880,45 +1759,46 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { } if (res != CUDA_ERROR_NOT_READY) { - LOG(ERROR) << "stream in bad state on status query: " << ToString(res); + LOG(ERROR) << "stream in bad state on status query: " + << cuda::ToStatus(res); } return false; } -/* static */ absl::Status GpuDriver::SynchronousMemcpyD2H(GpuContext* context, - void* host_dst, - CUdeviceptr gpu_src, - uint64_t size) { +absl::Status GpuDriver::SynchronousMemcpyD2H(GpuContext* context, + void* host_dst, + CUdeviceptr gpu_src, + uint64_t size) { ScopedActivateContext activation(context); - RETURN_IF_CUDA_RES_ERROR( + TF_RETURN_IF_ERROR(cuda::ToStatus( cuMemcpyDtoH(host_dst, gpu_src, size), absl::StrFormat("failed to synchronous memcpy from device to host " "host dst: %p; GPU src: %p; size: %u=0x%x", - host_dst, absl::bit_cast(gpu_src), size, size)); + host_dst, absl::bit_cast(gpu_src), size, size))); VLOG(2) << "successfully sync memcpy'd d2h of " << size << " bytes to " << host_dst; return absl::OkStatus(); } -/* static */ absl::Status GpuDriver::SynchronousMemcpyH2D(GpuContext* context, - CUdeviceptr gpu_dst, - const void* host_src, - uint64_t size) { +absl::Status GpuDriver::SynchronousMemcpyH2D(GpuContext* context, + CUdeviceptr gpu_dst, + const void* host_src, + uint64_t size) { ScopedActivateContext activation(context); - RETURN_IF_CUDA_RES_ERROR( + TF_RETURN_IF_ERROR(cuda::ToStatus( cuMemcpyHtoD(gpu_dst, host_src, size), absl::StrFormat( "failed to synchronous memcpy from host to device: GPU dst: %p;" " host src: %p; size: %u=0x%x", - absl::bit_cast(gpu_dst), host_src, size, size)); + absl::bit_cast(gpu_dst), host_src, size, size))); VLOG(2) << "successfully enqueued sync memcpy h2d of " << size << " bytes"; return absl::OkStatus(); } -/* static */ absl::Status GpuDriver::SynchronousMemcpyD2D(GpuContext* context, - CUdeviceptr gpu_dst, - CUdeviceptr gpu_src, - uint64_t size) { +absl::Status GpuDriver::SynchronousMemcpyD2D(GpuContext* context, + CUdeviceptr gpu_dst, + CUdeviceptr gpu_src, + uint64_t size) { ScopedActivateContext activation(context); CUresult result; @@ -1950,29 +1830,29 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { result = cuMemcpyPeer(gpu_dst, dst_context, gpu_src, src_context, size); } - RETURN_IF_CUDA_RES_ERROR( + TF_RETURN_IF_ERROR(cuda::ToStatus( result, absl::StrFormat( "failed to synchronous memcpy from host to device: GPU dst: %p; " "GPU src: %p; size: %u=0x%x", absl::bit_cast(gpu_dst), absl::bit_cast(gpu_src), size, - size)); + size))); VLOG(2) << "successfully sync memcpy'd d2d of " << size << " bytes"; return absl::OkStatus(); } -/* static */ bool GpuDriver::AsynchronousMemcpyD2H(GpuContext* context, - void* host_dst, - CUdeviceptr gpu_src, - uint64_t size, - CUstream stream) { +bool GpuDriver::AsynchronousMemcpyD2H(GpuContext* context, void* host_dst, + CUdeviceptr gpu_src, uint64_t size, + CUstream stream) { ScopedActivateContext activation(context); - CUresult res = cuMemcpyDtoHAsync(host_dst, gpu_src, size, stream); - if (res != CUDA_SUCCESS) { + auto status = + cuda::ToStatus(cuMemcpyDtoHAsync(host_dst, gpu_src, size, stream)); + if (!status.ok()) { LOG(ERROR) << absl::StrFormat( "failed to enqueue async memcpy from device to host: %s; host dst: %p; " "GPU src: %p; size: %u=0x%x", - ToString(res), host_dst, absl::bit_cast(gpu_src), size, size); + status.ToString(), host_dst, absl::bit_cast(gpu_src), size, + size); return false; } VLOG(2) << "successfully enqueued async memcpy d2h of " << size @@ -1981,18 +1861,18 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { return true; } -/* static */ bool GpuDriver::AsynchronousMemcpyH2D(GpuContext* context, - CUdeviceptr gpu_dst, - const void* host_src, - uint64_t size, - CUstream stream) { +bool GpuDriver::AsynchronousMemcpyH2D(GpuContext* context, CUdeviceptr gpu_dst, + const void* host_src, uint64_t size, + CUstream stream) { ScopedActivateContext activation(context); - CUresult res = cuMemcpyHtoDAsync(gpu_dst, host_src, size, stream); - if (res != CUDA_SUCCESS) { + auto status = + cuda::ToStatus(cuMemcpyHtoDAsync(gpu_dst, host_src, size, stream)); + if (!status.ok()) { LOG(ERROR) << absl::StrFormat( "failed to enqueue async memcpy from host to device: %s; GPU dst: %p; " "host src: %p; size: %u=0x%x", - ToString(res), absl::bit_cast(gpu_dst), host_src, size, size); + status.ToString(), absl::bit_cast(gpu_dst), host_src, size, + size); return false; } VLOG(2) << "successfully enqueued async memcpy h2d of " << size << " bytes" @@ -2001,11 +1881,9 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { return true; } -/* static */ bool GpuDriver::AsynchronousMemcpyD2D(GpuContext* context, - CUdeviceptr gpu_dst, - CUdeviceptr gpu_src, - uint64_t size, - CUstream stream) { +bool GpuDriver::AsynchronousMemcpyD2D(GpuContext* context, CUdeviceptr gpu_dst, + CUdeviceptr gpu_src, uint64_t size, + CUstream stream) { ScopedActivateContext activation(context); CUresult result; @@ -2057,7 +1935,7 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { "; GPU dst: %p on %s %s" "; GPU src: %p on %s %s" "; can access? %s; size: %u=0x%x", - ToString(result), absl::bit_cast(gpu_dst), + cuda::ToStatus(result).ToString(), absl::bit_cast(gpu_dst), CUDAPointerToMemorySpaceString(gpu_dst), CUDAPointerToDeviceString(gpu_dst), absl::bit_cast(gpu_src), CUDAPointerToMemorySpaceString(gpu_src), @@ -2072,9 +1950,8 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { return true; } -/* static */ absl::Status GpuDriver::InitEvent(GpuContext* context, - CUevent* result, - EventFlags flags) { +absl::Status GpuDriver::InitEvent(GpuContext* context, CUevent* result, + EventFlags flags) { int cuflags; switch (flags) { case EventFlags::kDefault: @@ -2088,36 +1965,25 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { } ScopedActivateContext activated{context}; - CUresult res = cuEventCreate(result, cuflags); - - if (res == CUDA_SUCCESS) { - return absl::OkStatus(); - } else if (res == CUDA_ERROR_OUT_OF_MEMORY) { - return absl::ResourceExhaustedError( - "could not create CUDA event: out of device memory"); - } else { - return absl::FailedPreconditionError( - absl::StrCat("could not create CUDA event: ", ToString(res))); - } + return cuda::ToStatus(cuEventCreate(result, cuflags)); } -/* static */ int GpuDriver::GetDeviceCount() { +int GpuDriver::GetDeviceCount() { int device_count = 0; - CUresult res = cuDeviceGetCount(&device_count); - if (res != CUDA_SUCCESS) { - LOG(ERROR) << "could not retrieve CUDA device count: " << ToString(res); + auto status = cuda::ToStatus(cuDeviceGetCount(&device_count)); + if (!status.ok()) { + LOG(ERROR) << "could not retrieve CUDA device count: " << status; return 0; } return device_count; } -/* static */ absl::StatusOr GpuDriver::GetPointerContext( - CUdeviceptr pointer) { +absl::StatusOr GpuDriver::GetPointerContext(CUdeviceptr pointer) { GpuContext* context = nullptr; - CUresult result = - cuPointerGetAttribute(&context, CU_POINTER_ATTRIBUTE_CONTEXT, pointer); - if (result == CUDA_SUCCESS) { + auto status = cuda::ToStatus( + cuPointerGetAttribute(&context, CU_POINTER_ATTRIBUTE_CONTEXT, pointer)); + if (status.ok()) { // For cudaMallocAsync, the context returned is null. For now // return not-available. But how to manage that correctly // everywhere in TF? Currently this is only used during error @@ -2130,53 +1996,32 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { return context; } - return absl::InternalError(absl::StrCat( - "failed to query context for device pointer: ", ToString(result))); + return status; } -/* static */ absl::StatusOr GpuDriver::GetPointerMemorySpace( +absl::StatusOr GpuDriver::GetPointerMemorySpace( CUdeviceptr pointer) { unsigned int value; - CUresult result = - cuPointerGetAttribute(&value, CU_POINTER_ATTRIBUTE_MEMORY_TYPE, pointer); - if (result == CUDA_SUCCESS) { - switch (value) { - case CU_MEMORYTYPE_DEVICE: - return MemoryType::kDevice; - case CU_MEMORYTYPE_HOST: - return MemoryType::kHost; - default: - return absl::InternalError( - absl::StrCat("unknown memory space provided by CUDA API: ", value)); - } + TF_RETURN_IF_ERROR(cuda::ToStatus(cuPointerGetAttribute( + &value, CU_POINTER_ATTRIBUTE_MEMORY_TYPE, pointer))); + switch (value) { + case CU_MEMORYTYPE_DEVICE: + return MemoryType::kDevice; + case CU_MEMORYTYPE_HOST: + return MemoryType::kHost; + default: + return absl::InternalError( + absl::StrCat("unknown memory space provided by CUDA API: ", value)); } - - return absl::InternalError(absl::StrCat( - "failed to query pointer for memory space: ", ToString(result))); } -/* static */ absl::Status GpuDriver::GetPointerAddressRange(CUdeviceptr dptr, - CUdeviceptr* base, - size_t* size) { - CUresult result = cuMemGetAddressRange(base, size, dptr); - if (result == CUDA_SUCCESS) { - return absl::OkStatus(); - } else if (result == CUDA_ERROR_NOT_FOUND) { - // We differentiate between "this pointer is unknown" (return here) and - // "there was an internal error while performing this operation" (return - // below). - return absl::NotFoundError(absl::StrFormat("not a device pointer %p; %s", - reinterpret_cast(dptr), - ToString(result))); - } - - return absl::InternalError( - absl::StrFormat("failed to get pointer into for device pointer %p; %s", - reinterpret_cast(dptr), ToString(result))); +absl::Status GpuDriver::GetPointerAddressRange(CUdeviceptr dptr, + CUdeviceptr* base, + size_t* size) { + return cuda::ToStatus(cuMemGetAddressRange(base, size, dptr)); } -/* static */ absl::StatusOr GpuDriver::GetPointerDevice( - CUdeviceptr pointer) { +absl::StatusOr GpuDriver::GetPointerDevice(CUdeviceptr pointer) { auto result = GetPointerContext(pointer); if (!result.ok()) { return result.status(); @@ -2185,39 +2030,25 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) { return DeviceFromContext(result.value()); } -/* static */ absl::Status GpuDriver::GetComputeCapability(int* cc_major, - int* cc_minor, - CUdevice device) { +absl::Status GpuDriver::GetComputeCapability(int* cc_major, int* cc_minor, + CUdevice device) { *cc_major = 0; *cc_minor = 0; - CUresult res = cuDeviceGetAttribute( - cc_major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device); - if (res != CUDA_SUCCESS) { - return absl::InternalError(absl::StrFormat( - "failed to get compute capability major for device: %s; %d", - ToString(res), device)); - } - - res = cuDeviceGetAttribute( - cc_minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device); - if (res != CUDA_SUCCESS) { - return absl::InternalError(absl::StrFormat( - "failed to get compute capability minor for device: %s; %d", - ToString(res), device)); - } + TF_RETURN_IF_ERROR(cuda::ToStatus(cuDeviceGetAttribute( + cc_major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device))); - return absl::OkStatus(); + return cuda::ToStatus(cuDeviceGetAttribute( + cc_minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device)); } -/* static */ absl::Status GpuDriver::GetGpuISAVersion(int* version, - CUdevice device) { +absl::Status GpuDriver::GetGpuISAVersion(int* version, CUdevice device) { return absl::Status{ absl::StatusCode::kInternal, "Feature not supported on CUDA platform (GetGpuISAVersion)"}; } -/* static */ absl::Status GpuDriver::GetGpuGCNArchName(CUdevice, std::string*) { +absl::Status GpuDriver::GetGpuGCNArchName(CUdevice, std::string*) { return absl::Status{ absl::StatusCode::kInternal, "Feature not supported on CUDA platform (GetGpuGCNArchName)"}; @@ -2229,27 +2060,24 @@ template static absl::StatusOr GetSimpleAttribute(CUdevice device, CUdevice_attribute attribute) { int value = -1; - RETURN_IF_CUDA_RES_ERROR(cuDeviceGetAttribute(&value, attribute, device), - "Could not retrieve CUDA device attribute (", - attribute); + TF_RETURN_IF_ERROR(cuda::ToStatus( + cuDeviceGetAttribute(&value, attribute, device), + absl::StrCat("Could not retrieve CUDA device attribute (", attribute))); T converted = value; return converted; } -/* static */ absl::StatusOr GpuDriver::GetMultiprocessorCount( - CUdevice device) { +absl::StatusOr GpuDriver::GetMultiprocessorCount(CUdevice device) { return GetSimpleAttribute(device, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT); } -/* static */ absl::StatusOr GpuDriver::GetMaxSharedMemoryPerCore( - CUdevice device) { +absl::StatusOr GpuDriver::GetMaxSharedMemoryPerCore(CUdevice device) { return GetSimpleAttribute( device, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR); } -/* static */ absl::StatusOr GpuDriver::GetMaxSharedMemoryPerBlock( - CUdevice device) { +absl::StatusOr GpuDriver::GetMaxSharedMemoryPerBlock(CUdevice device) { return GetSimpleAttribute( device, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK); } @@ -2260,85 +2088,73 @@ absl::StatusOr GpuDriver::GetMaxSharedMemoryPerBlockOptin( device, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN); } -/* static */ absl::StatusOr GpuDriver::GetMaxThreadsPerMultiprocessor( +absl::StatusOr GpuDriver::GetMaxThreadsPerMultiprocessor( CUdevice device) { return GetSimpleAttribute( device, CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR); } -/* static */ absl::StatusOr GpuDriver::GetMaxThreadsPerBlock( - CUdevice device) { +absl::StatusOr GpuDriver::GetMaxThreadsPerBlock(CUdevice device) { return GetSimpleAttribute(device, CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK); } -/* static */ absl::StatusOr GpuDriver::GetMaxRegistersPerBlock( - CUdevice device) { +absl::StatusOr GpuDriver::GetMaxRegistersPerBlock(CUdevice device) { return GetSimpleAttribute( device, CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK); } -/* static */ absl::StatusOr GpuDriver::GetThreadsPerWarp( - CUdevice device) { +absl::StatusOr GpuDriver::GetThreadsPerWarp(CUdevice device) { return GetSimpleAttribute(device, CU_DEVICE_ATTRIBUTE_WARP_SIZE); } -/* static */ absl::Status GpuDriver::GetGridLimits(int* x, int* y, int* z, - CUdevice device) { +absl::Status GpuDriver::GetGridLimits(int* x, int* y, int* z, CUdevice device) { int value; - RETURN_IF_CUDA_RES_ERROR( + TF_RETURN_IF_ERROR(cuda::ToStatus( cuDeviceGetAttribute(&value, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X, device), - "Could not get device attribute"); + "Could not get device attribute")); *x = value; - RETURN_IF_CUDA_RES_ERROR( + TF_RETURN_IF_ERROR(cuda::ToStatus( cuDeviceGetAttribute(&value, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y, device), - "Could not get device attribute"); + "Could not get device attribute")); *y = value; - RETURN_IF_CUDA_RES_ERROR( + TF_RETURN_IF_ERROR(cuda::ToStatus( cuDeviceGetAttribute(&value, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z, device), - "Could not get device attribute"); + "Could not get device attribute")); *z = value; return absl::OkStatus(); } -/* static */ absl::StatusOr GpuDriver::GetDriverVersion() { +absl::StatusOr GpuDriver::GetDriverVersion() { int32_t version; - RETURN_IF_CUDA_RES_ERROR(cuDriverGetVersion(&version), - "Could not get driver version"); + TF_RETURN_IF_ERROR(cuda::ToStatus(cuDriverGetVersion(&version), + "Could not get driver version")); return version; } -/* static */ bool GpuDriver::GetDeviceProperties(CUdevprop* device_properties, - int device_ordinal) { - CUresult res = cuDeviceGetProperties(device_properties, device_ordinal); - if (res != CUDA_SUCCESS) { - LOG(ERROR) << "failed to query device properties: " << ToString(res); - return false; - } - - return true; +bool GpuDriver::GetDeviceProperties(CUdevprop* device_properties, + int device_ordinal) { + auto status = + cuda::ToStatus(cuDeviceGetProperties(device_properties, device_ordinal)); + return status.ok(); } -/* static */ absl::StatusOr GpuDriver::GetDeviceAttribute( - CUdevice_attribute attribute, CUdevice device) { +absl::StatusOr GpuDriver::GetDeviceAttribute(CUdevice_attribute attribute, + CUdevice device) { int val; - CUresult res = cuDeviceGetAttribute(&val, attribute, device); - if (res != CUDA_SUCCESS) { - return absl::InternalError( - absl::StrFormat("failed to get device attribute %d for device %d: %s", - attribute, device, ToString(res))); - } + TF_RETURN_IF_ERROR( + cuda::ToStatus(cuDeviceGetAttribute(&val, attribute, device))); return val; } -/* static */ bool GpuDriver::IsEccEnabled(CUdevice device, bool* result) { +bool GpuDriver::IsEccEnabled(CUdevice device, bool* result) { int value = -1; - CUresult res = - cuDeviceGetAttribute(&value, CU_DEVICE_ATTRIBUTE_ECC_ENABLED, device); - if (res != CUDA_SUCCESS) { - LOG(ERROR) << "failed to query ECC status: " << ToString(res); + auto status = cuda::ToStatus( + cuDeviceGetAttribute(&value, CU_DEVICE_ATTRIBUTE_ECC_ENABLED, device)); + if (!status.ok()) { + LOG(ERROR) << "failed to query ECC status: " << status; return false; } @@ -2346,15 +2162,14 @@ absl::StatusOr GpuDriver::GetMaxSharedMemoryPerBlockOptin( return true; } -/* static */ bool GpuDriver::GetDeviceMemoryInfo(GpuContext* context, - int64_t* free_out, - int64_t* total_out) { +bool GpuDriver::GetDeviceMemoryInfo(GpuContext* context, int64_t* free_out, + int64_t* total_out) { ScopedActivateContext activation(context); size_t free = 0; size_t total = 0; - CUresult res = cuMemGetInfo(&free, &total); - if (res != CUDA_SUCCESS) { - LOG(ERROR) << "failed to query device memory info: " << ToString(res); + auto status = cuda::ToStatus(cuMemGetInfo(&free, &total)); + if (!status.ok()) { + LOG(ERROR) << "failed to query device memory info: " << status; return false; } @@ -2363,12 +2178,11 @@ absl::StatusOr GpuDriver::GetMaxSharedMemoryPerBlockOptin( return true; } -/* static */ bool GpuDriver::GetDeviceTotalMemory(CUdevice device, - uint64_t* result) { +bool GpuDriver::GetDeviceTotalMemory(CUdevice device, uint64_t* result) { size_t value{}; - CUresult res = cuDeviceTotalMem(&value, device); - if (res != CUDA_SUCCESS) { - LOG(ERROR) << "failed to query total available memory: " << ToString(res); + auto status = cuda::ToStatus(cuDeviceTotalMem(&value, device)); + if (!status.ok()) { + LOG(ERROR) << "failed to query total available memory: " << status; return false; } @@ -2376,22 +2190,22 @@ absl::StatusOr GpuDriver::GetMaxSharedMemoryPerBlockOptin( return true; } -/* static */ std::string GpuDriver::GetPCIBusID(CUdevice device) { +std::string GpuDriver::GetPCIBusID(CUdevice device) { std::string pci_bus_id; static const int kBufferSize = 64; absl::InlinedVector chars(kBufferSize); chars[kBufferSize - 1] = '\0'; - CUresult res = cuDeviceGetPCIBusId(chars.begin(), kBufferSize - 1, device); - if (res != CUDA_SUCCESS) { - LOG(ERROR) << "failed to query PCI bus id for device: " << ToString(res); + auto status = cuda::ToStatus( + cuDeviceGetPCIBusId(chars.begin(), kBufferSize - 1, device)); + if (!status.ok()) { + LOG(ERROR) << "failed to query PCI bus id for device: " << status; return pci_bus_id; } pci_bus_id = chars.begin(); return pci_bus_id; } -/* static */ bool GpuDriver::CanEnablePeerAccess(GpuContext* from, - GpuContext* to) { +bool GpuDriver::CanEnablePeerAccess(GpuContext* from, GpuContext* to) { if (from == to) { return true; // A context can always access its own memory. } @@ -2411,20 +2225,18 @@ absl::StatusOr GpuDriver::GetMaxSharedMemoryPerBlockOptin( return CanEnablePeerAccess(from_device.value(), to_device.value()); } -/* static */ bool GpuDriver::CanEnablePeerAccess(GpuDeviceHandle from, - GpuDeviceHandle to) { +bool GpuDriver::CanEnablePeerAccess(GpuDeviceHandle from, GpuDeviceHandle to) { int can_access_peer = -1; - CUresult result = cuDeviceCanAccessPeer(&can_access_peer, from, to); - if (result != CUDA_SUCCESS) { - LOG(ERROR) << "failed to detect peer access capability: " - << ToString(result); + auto status = + cuda::ToStatus(cuDeviceCanAccessPeer(&can_access_peer, from, to)); + if (!status.ok()) { + LOG(ERROR) << "failed to detect peer access capability: " << status; return false; } return can_access_peer; } -/* static */ absl::Status GpuDriver::EnablePeerAccess(GpuContext* from, - GpuContext* to) { +absl::Status GpuDriver::EnablePeerAccess(GpuContext* from, GpuContext* to) { if (from == to) { return absl::OkStatus(); // A context can always access its own // memory. @@ -2436,23 +2248,23 @@ absl::StatusOr GpuDriver::GetMaxSharedMemoryPerBlockOptin( result != CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED) { return absl::InternalError( absl::StrFormat("failed to enable peer access from %p to %p: %s", from, - to, ToString(result))); + to, cuda::ToStatus(result).ToString())); } return absl::OkStatus(); } -/* static */ absl::StatusOr GpuDriver::GetMaxOccupiedBlocksPerCore( +absl::StatusOr GpuDriver::GetMaxOccupiedBlocksPerCore( GpuContext* context, CUfunction kernel, int threads_per_block, size_t dynamic_shared_memory_bytes) { ScopedActivateContext activation(context); int max_blocks; - RETURN_IF_CUDA_RES_ERROR( + TF_RETURN_IF_ERROR(cuda::ToStatus( cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( &max_blocks, kernel, threads_per_block, dynamic_shared_memory_bytes, CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE), - absl::StrFormat("Failed to calculate occupancy of kernel %p", kernel)); + absl::StrFormat("Failed to calculate occupancy of kernel %p", kernel))); return max_blocks; } @@ -2462,8 +2274,8 @@ namespace cuda { CUcontext CurrentContextOrDie() { CUcontext current = nullptr; - FAIL_IF_CUDA_RES_ERROR(cuCtxGetCurrent(¤t), - "Failed to query current context"); + TF_CHECK_OK(cuda::ToStatus(cuCtxGetCurrent(¤t), + "Failed to query current context")); return current; } diff --git a/xla/stream_executor/cuda/cuda_driver.h b/xla/stream_executor/cuda/cuda_driver.h index 96b3428e2c9a94..5c04ab6ccbee02 100644 --- a/xla/stream_executor/cuda/cuda_driver.h +++ b/xla/stream_executor/cuda/cuda_driver.h @@ -31,22 +31,11 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" #include "third_party/gpus/cuda/include/cuda.h" +#include "xla/stream_executor/cuda/cuda_status.h" #include "xla/stream_executor/gpu/gpu_driver.h" namespace stream_executor { namespace gpu { -// Formats CUresult to output prettified values into a log stream. -static std::string ToString(CUresult result) { - const char* error_name; - if (cuGetErrorName(result, &error_name)) { - return absl::StrCat("UNKNOWN ERROR (", static_cast(result), ")"); - } - const char* error_string; - if (cuGetErrorString(result, &error_string)) { - return error_name; - } - return absl::StrCat(error_name, ": ", error_string); -} // Polls (without blocking) to determine the status of an event - pending or // complete (or an error status). @@ -127,12 +116,13 @@ class CreatedContexts { // Find device id from cuda pointer value. static int GetDeviceOrdinal(void* ptr) { int device_ordinal; - CUresult result = cuPointerGetAttribute(static_cast(&device_ordinal), - CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL, - reinterpret_cast(ptr)); - if (result != CUDA_SUCCESS) { + absl::Status status = cuda::ToStatus( + cuPointerGetAttribute(static_cast(&device_ordinal), + CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL, + reinterpret_cast(ptr))); + if (!status.ok()) { LOG(FATAL) << "Not able to get the device_ordinal for ptr: " << ptr - << ". Error: " << ToString(result); + << ". Error: " << status; } return device_ordinal; } diff --git a/xla/stream_executor/cuda/cuda_driver_test.cc b/xla/stream_executor/cuda/cuda_driver_test.cc index da4d78118f51d2..7cb402a91ca43a 100644 --- a/xla/stream_executor/cuda/cuda_driver_test.cc +++ b/xla/stream_executor/cuda/cuda_driver_test.cc @@ -13,27 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "xla/stream_executor/cuda/cuda_driver.h" + #include "absl/log/log.h" #include "third_party/gpus/cuda/include/cuda.h" +#include "third_party/gpus/cuda/include/cuda_runtime_api.h" #include "third_party/gpus/cuda/include/driver_types.h" +#include "xla/stream_executor/cuda/cuda_status.h" #include "xla/stream_executor/gpu/gpu_driver.h" -#include "xla/stream_executor/cuda/cuda_driver.h" - -#include "third_party/gpus/cuda/include/cuda_runtime_api.h" +#include "tsl/platform/status.h" #include "tsl/platform/test.h" namespace stream_executor { namespace gpu { void CheckCuda(CUresult result, const char* file, int line) { - if (result == CUDA_SUCCESS) { - return; - } - const char* name; - cuGetErrorName(result, &name); - const char* message; - cuGetErrorString(result, &message); - LOG(FATAL) << file << "(" << line << "): " << name << ", " << message; + TF_CHECK_OK(cuda::ToStatus(result)); } void CheckCuda(cudaError_t result, const char* file, int line) { diff --git a/xla/stream_executor/cuda/cuda_status.cc b/xla/stream_executor/cuda/cuda_status.cc new file mode 100644 index 00000000000000..aa6fdd1498a311 --- /dev/null +++ b/xla/stream_executor/cuda/cuda_status.cc @@ -0,0 +1,51 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/cuda/cuda_status.h" + +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "third_party/gpus/cuda/include/cuda.h" + +namespace stream_executor::cuda::internal { + +absl::Status ToStatusSlow(CUresult result, absl::string_view detail) { + const char* error_name; + std::string error_detail; + if (cuGetErrorName(result, &error_name)) { + error_detail = absl::StrCat(detail, ": UNKNOWN ERROR (", + static_cast(result), ")"); + } else { + const char* error_string; + if (cuGetErrorString(result, &error_string)) { + error_detail = absl::StrCat(detail, ": ", error_name); + } else { + error_detail = absl::StrCat(detail, ": ", error_name, ": ", error_string); + } + } + + if (result == CUDA_ERROR_OUT_OF_MEMORY) { + return absl::ResourceExhaustedError(error_detail); + } else if (result == CUDA_ERROR_NOT_FOUND) { + return absl::NotFoundError(error_detail); + } else { + return absl::InternalError(absl::StrCat("CUDA error: ", error_detail)); + } +} + +} // namespace stream_executor::cuda::internal diff --git a/xla/stream_executor/cuda/cuda_status.h b/xla/stream_executor/cuda/cuda_status.h new file mode 100644 index 00000000000000..ac2013860afc13 --- /dev/null +++ b/xla/stream_executor/cuda/cuda_status.h @@ -0,0 +1,42 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_STATUS_H_ +#define XLA_STREAM_EXECUTOR_CUDA_CUDA_STATUS_H_ + +#include "absl/base/optimization.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "third_party/gpus/cuda/include/cuda.h" + +namespace stream_executor::cuda { + +namespace internal { +// Helper method to handle the slow path of ToStatus. Assumes a non-successful +// result code. +absl::Status ToStatusSlow(CUresult result, absl::string_view detail); +} // namespace internal + +// Returns an absl::Status corresponding to the CUresult. +inline absl::Status ToStatus(CUresult result, absl::string_view detail = "") { + if (ABSL_PREDICT_TRUE(result == CUDA_SUCCESS)) { + return absl::OkStatus(); + } + return internal::ToStatusSlow(result, detail); +} + +} // namespace stream_executor::cuda + +#endif // XLA_STREAM_EXECUTOR_CUDA_CUDA_STATUS_H_ diff --git a/xla/stream_executor/gpu/BUILD b/xla/stream_executor/gpu/BUILD index 8e918262abfc55..1da212056f868e 100644 --- a/xla/stream_executor/gpu/BUILD +++ b/xla/stream_executor/gpu/BUILD @@ -584,6 +584,7 @@ cc_library( "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/cuda:cuda_activation", "//xla/stream_executor/cuda:cuda_executor", + "//xla/stream_executor/cuda:cuda_status", "//xla/tsl/framework:allocator", "//xla/tsl/framework:device_id", "//xla/tsl/util:env_var", diff --git a/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc b/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc index ae567a0c46bce0..7296519fbc40c6 100644 --- a/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc +++ b/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "third_party/gpus/cuda/include/cuda.h" #include "xla/stream_executor/cuda/cuda_activation.h" +#include "xla/stream_executor/cuda/cuda_status.h" #include "xla/stream_executor/gpu/gpu_init.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/framework/allocator.h" @@ -41,15 +42,6 @@ limitations under the License. namespace stream_executor { -static std::string GetCudaErrorMessage(CUresult result) { - const char* error; - cuGetErrorString(result, &error); - const char* name; - cuGetErrorName(result, &name); - return absl::StrCat("CUDA error: ", error ? error : "", " (", - name ? name : "Unknown", ")"); -} - struct GpuCudaMallocAsyncAllocator::CudaState { // cudaMallocAsync is stream aware. But TF StreamExecutor use only 1 // compute stream and already synchronize with the h2d, d2h and d2d @@ -88,27 +80,27 @@ void GpuCudaMallocAsyncAllocator::PrintAllocatorStatisticsNoLock() { CU_MEMPOOL_ATTR_RESERVED_MEM_CURRENT, &mem_reserved_current)) { LOG(ERROR) << "Error while fetching extra cudaMallocAsync pool attribute: " - << GetCudaErrorMessage(result); + << cuda::ToStatus(result); } cuuint64_t mem_used_current; if (auto result = cuMemPoolGetAttribute(cuda_state_->pool, CU_MEMPOOL_ATTR_USED_MEM_CURRENT, &mem_used_current)) { LOG(ERROR) << "Error while fetching extra cudaMallocAsync pool attribute: " - << GetCudaErrorMessage(result); + << cuda::ToStatus(result); } cuuint64_t mem_reserved_high; if (auto result = cuMemPoolGetAttribute(cuda_state_->pool, CU_MEMPOOL_ATTR_RESERVED_MEM_HIGH, &mem_reserved_high)) { LOG(ERROR) << "Error while fetching extra cudaMallocAsync pool attribute: " - << GetCudaErrorMessage(result); + << cuda::ToStatus(result); } cuuint64_t mem_used_high; if (auto result = cuMemPoolGetAttribute( cuda_state_->pool, CU_MEMPOOL_ATTR_USED_MEM_HIGH, &mem_used_high)) { LOG(ERROR) << "Error while fetching extra cudaMallocAsync pool attribute: " - << GetCudaErrorMessage(result); + << cuda::ToStatus(result); } LOG(ERROR) << "CU_MEMPOOL_ATTR_RESERVED_MEM_CURRENT: " << mem_reserved_current; @@ -151,7 +143,7 @@ GpuCudaMallocAsyncAllocator::GpuCudaMallocAsyncAllocator( CUcontext pctx; // We loose track of it. But this is fine. if (auto result = cuDevicePrimaryCtxRetain(&pctx, 0)) LOG(FATAL) // Crash OK. - << "Failed to retain context: " << GetCudaErrorMessage(result); + << "Failed to retain context: " << cuda::ToStatus(result); } cuda::ScopedActivateExecutorContext scoped_activation{stream_exec_}; @@ -159,8 +151,7 @@ GpuCudaMallocAsyncAllocator::GpuCudaMallocAsyncAllocator( // Check the CUDA runtime is recent enough. if (auto status2 = cuDriverGetVersion(&driverVersion)) { LOG(FATAL) // Crash OK. - << "Error while fetching driver version: " - << GetCudaErrorMessage(status2); + << "Error while fetching driver version: " << cuda::ToStatus(status2); } // Check that cudaMallocAsync is supported. @@ -172,7 +163,7 @@ GpuCudaMallocAsyncAllocator::GpuCudaMallocAsyncAllocator( LOG(FATAL) // Crash OK. << "On device: " << platform_device_id.value() << " Current driver: " << driverVersion - << ". Failed to get device attribute : " << GetCudaErrorMessage(status); + << ". Failed to get device attribute : " << cuda::ToStatus(status); } if (!cuda_malloc_async_supported) LOG(FATAL) // Crash OK. @@ -196,13 +187,13 @@ GpuCudaMallocAsyncAllocator::GpuCudaMallocAsyncAllocator( #endif // CUDA_VERSION >= 12030 if (auto status = cuMemPoolCreate(&cuda_state_->pool, &pool_props)) LOG(FATAL) << // Crash OK. - "Failed to create CUDA pool: " << GetCudaErrorMessage(status); + "Failed to create CUDA pool: " << cuda::ToStatus(status); } else { pool_size = reserve_memory_size; if (auto status = cuDeviceGetDefaultMemPool(&cuda_state_->pool, platform_device_id.value())) LOG(FATAL) << // Crash OK. - "Failed to get default CUDA pool: " << GetCudaErrorMessage(status); + "Failed to get default CUDA pool: " << cuda::ToStatus(status); VLOG(2) << "using default memory pool " << cuda_state_->pool; } @@ -214,7 +205,7 @@ GpuCudaMallocAsyncAllocator::GpuCudaMallocAsyncAllocator( CU_MEMPOOL_ATTR_RELEASE_THRESHOLD, &release_threshold_64)) LOG(FATAL) << // Crash OK. - "Failed to set CUDA pool attribute: " << GetCudaErrorMessage(status); + "Failed to set CUDA pool attribute: " << cuda::ToStatus(status); if (compute_stats) { stats_ = std::make_unique(); @@ -232,13 +223,13 @@ GpuCudaMallocAsyncAllocator::GpuCudaMallocAsyncAllocator( cuda_state_->pool, CU_MEMPOOL_ATTR_REUSE_ALLOW_OPPORTUNISTIC, &disable)) { LOG(FATAL) << // Crash OK. - "Failed to set CUDA pool attribute: " << GetCudaErrorMessage(status); + "Failed to set CUDA pool attribute: " << cuda::ToStatus(status); } if (auto status = cuMemPoolSetAttribute( cuda_state_->pool, CU_MEMPOOL_ATTR_REUSE_ALLOW_INTERNAL_DEPENDENCIES, &disable)) { LOG(FATAL) << // Crash OK. - "Failed to set CUDA pool attribute: " << GetCudaErrorMessage(status); + "Failed to set CUDA pool attribute: " << cuda::ToStatus(status); } } @@ -278,7 +269,7 @@ GpuCudaMallocAsyncAllocator::GpuCudaMallocAsyncAllocator( LOG(FATAL) // Crash OK. << "cuDeviceCanAccessPeer failed to know if GPU id " << map.location.id << " can access GPU id " - << platform_device_id.value() << ": " << GetCudaErrorMessage(status); + << platform_device_id.value() << ": " << cuda::ToStatus(status); } if (canAccessPeer == 1) { if (auto status = cuMemPoolSetAccess(cuda_state_->pool, &map, 1)) { @@ -286,7 +277,7 @@ GpuCudaMallocAsyncAllocator::GpuCudaMallocAsyncAllocator( LOG(FATAL) // Crash OK. << "Error when setting access to the pool id: " << i << " location id: " << map.location.id - << " error: " << GetCudaErrorMessage(status); + << " error: " << cuda::ToStatus(status); } } @@ -300,7 +291,7 @@ GpuCudaMallocAsyncAllocator::GpuCudaMallocAsyncAllocator( platform_device_id.value())) { cuda_state_->pool = nullptr; LOG(FATAL) // Crash OK. - << "cuDeviceCanAccessPeer failed: " << GetCudaErrorMessage(status); + << "cuDeviceCanAccessPeer failed: " << cuda::ToStatus(status); } if (canAccessPeer == 1) { if (auto status = cuMemPoolSetAccess((*all_pools_)[i], &map, 1)) { @@ -308,7 +299,7 @@ GpuCudaMallocAsyncAllocator::GpuCudaMallocAsyncAllocator( LOG(FATAL) // Crash OK. << "Error when setting access to the pool id: " << previous_pool_id << " location id: " << map.location.id - << " error: " << GetCudaErrorMessage(status); + << " error: " << cuda::ToStatus(status); } } } @@ -329,8 +320,7 @@ GpuCudaMallocAsyncAllocator::~GpuCudaMallocAsyncAllocator() { VLOG(2) << "Delete memory pool " << reinterpret_cast(cuda_state_->pool); if (auto status = cuMemPoolDestroy(cuda_state_->pool)) - LOG(FATAL) << "Failed to destroy memory pool:" - << GetCudaErrorMessage(status); + LOG(FATAL) << "Failed to destroy memory pool:" << cuda::ToStatus(status); } } @@ -368,7 +358,7 @@ void* GpuCudaMallocAsyncAllocator::AllocateRaw(size_t alignment, size_t free, total; cuMemGetInfo(&free, &total); LOG(ERROR) << Name() << " cuMemAllocAsync failed to allocate " << num_bytes - << " bytes: " << GetCudaErrorMessage(result) + << " bytes: " << cuda::ToStatus(result) << "\n Reported by CUDA: Free memory/Total memory: " << free << "/" << total; if (stats_) { @@ -416,13 +406,13 @@ void GpuCudaMallocAsyncAllocator::DeallocateRaw(void* ptr) { // It happens with multi-GPU that TF free the GPU allocation after // the driver is unloaded. It is safe to ignore this error here. // TODO: Find how to fix the shutdown steps in TF. - VLOG(1) << "Ignoring CUDA error: " << GetCudaErrorMessage(result); + VLOG(1) << "Ignoring CUDA error: " << cuda::ToStatus(result); } else { size_t free, total; cuda::ScopedActivateExecutorContext scoped_activation{stream_exec_}; cuMemGetInfo(&free, &total); LOG(ERROR) << "cudaFreeAsync failed to free " << ptr << ": " - << GetCudaErrorMessage(result) + << cuda::ToStatus(result) << "\n Free memory/Total memory: " << free << "/" << total; if (stats_) { LOG(ERROR) << "Stats: " << stats_->DebugString(); @@ -490,7 +480,7 @@ void GpuCudaMallocAsyncAllocator::SetStreamAndPreallocateMemory(void* stream) { CU_MEMPOOL_ATTR_RELEASE_THRESHOLD, &pool_size_64)) { LOG(FATAL) << // Crash OK. - "Failed to get CUDA pool attribute: " << GetCudaErrorMessage(status); + "Failed to get CUDA pool attribute: " << cuda::ToStatus(status); } cuda_state_->cuda_stream = new_cuda_stream; int64_t prealloc_size = 0; From dd0395e99412bd7147af709c9d9b4215f84551f2 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Thu, 25 Jul 2024 16:00:00 -0700 Subject: [PATCH 175/376] Remove deleted patch file from `third_party/triton/llvm_integration/series.bzl` PiperOrigin-RevId: 656126198 --- third_party/triton/llvm_integration/series.bzl | 1 - 1 file changed, 1 deletion(-) diff --git a/third_party/triton/llvm_integration/series.bzl b/third_party/triton/llvm_integration/series.bzl index 9d0e1204ba527f..656b9c894904d8 100644 --- a/third_party/triton/llvm_integration/series.bzl +++ b/third_party/triton/llvm_integration/series.bzl @@ -8,6 +8,5 @@ LLVM nor MLIR integrator, please do not add any patches to this list. """ llvm_patch_list = [ - "//third_party/triton/llvm_integration:cl656020169.patch", # Add new patches just above this line ] From bebc10ce8522b4ec4d78d318a6d56fd40f108830 Mon Sep 17 00:00:00 2001 From: Anlun Xu Date: Thu, 25 Jul 2024 17:04:25 -0700 Subject: [PATCH 176/376] [xla:gpu] NFC: Do not modify command buffer flag in gemm fusion autotuner PiperOrigin-RevId: 656148015 --- xla/service/gpu/gemm_fusion_autotuner.cc | 3 --- 1 file changed, 3 deletions(-) diff --git a/xla/service/gpu/gemm_fusion_autotuner.cc b/xla/service/gpu/gemm_fusion_autotuner.cc index 8bc3650c3325df..0a6188495febf2 100644 --- a/xla/service/gpu/gemm_fusion_autotuner.cc +++ b/xla/service/gpu/gemm_fusion_autotuner.cc @@ -328,9 +328,6 @@ absl::StatusOr> TritonGemmAutotuneExtractor( bool allow_filtering_kernels_spilling_registers) { std::unique_ptr new_module = ExtractInstructionIntoNewModule(*fusion); - // TODO(anlunx): Disable command buffers for now because it breaks triton - // autotuner test. Enable this when the function of command buffers is stable. - debug_opts.clear_xla_gpu_enable_command_buffer(); if (!allow_filtering_kernels_spilling_registers) { debug_opts.set_xla_gpu_filter_kernels_spilling_registers_on_autotuning( false); From c2b48b8219de2ce6ee90216f6a2239b319307d12 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 25 Jul 2024 17:38:06 -0700 Subject: [PATCH 177/376] [XLA:MSA] Support slice instruction in runtime simulator This patch includes the slice instruction overhead into the runtime simulator. The slice instruction introduces async transfer between default<->alternate memory space, which is the same as the copy-start copy-done instruction, except the slice instructions have different way to calculate the transfer bytes. PiperOrigin-RevId: 656158889 --- xla/service/memory_space_assignment/BUILD | 3 +- .../memory_space_assignment.cc | 4 +- .../memory_space_assignment/simulator.cc | 114 ++++++----- .../memory_space_assignment/simulator.h | 60 +++--- .../memory_space_assignment/simulator_test.cc | 185 +++++++++++++----- 5 files changed, 235 insertions(+), 131 deletions(-) diff --git a/xla/service/memory_space_assignment/BUILD b/xla/service/memory_space_assignment/BUILD index 56bf642315feca..a8efb1d1050c62 100644 --- a/xla/service/memory_space_assignment/BUILD +++ b/xla/service/memory_space_assignment/BUILD @@ -332,7 +332,6 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_live_range", "//xla/service:hlo_alias_analysis", - "//xla/service:hlo_value", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", @@ -353,6 +352,8 @@ xla_cc_test( "//xla/hlo/utils:hlo_live_range", "//xla/service:hlo_alias_analysis", "//xla/service:hlo_cost_analysis", + "//xla/service:hlo_value", + "//xla/service/heap_simulator", "//xla/tests:hlo_test_base", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", diff --git a/xla/service/memory_space_assignment/memory_space_assignment.cc b/xla/service/memory_space_assignment/memory_space_assignment.cc index f2781c6c0f9a41..753a348e14a76f 100644 --- a/xla/service/memory_space_assignment/memory_space_assignment.cc +++ b/xla/service/memory_space_assignment/memory_space_assignment.cc @@ -358,8 +358,8 @@ MemorySpaceAssignment::RunMemorySpaceAssignment( runtime_simulator.emplace(options_.cost_analysis, options_.alternate_memory_space); float estimated_time = - runtime_simulator->SimulateElapsedTimeWithoutAsyncCopies(hlo_live_range, - allocations_); + runtime_simulator->SimulateElapsedTimeWithoutAsyncCopyLikes( + hlo_live_range, allocations_); VLOG(1) << "Estimated elapsed time without async copies (sec): " << estimated_time; } diff --git a/xla/service/memory_space_assignment/simulator.cc b/xla/service/memory_space_assignment/simulator.cc index b618061ec1c6ef..d547c1e65eb998 100644 --- a/xla/service/memory_space_assignment/simulator.cc +++ b/xla/service/memory_space_assignment/simulator.cc @@ -28,12 +28,12 @@ limitations under the License. #include "absl/log/log.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/layout.h" #include "xla/service/hlo_alias_analysis.h" -#include "xla/service/hlo_value.h" #include "xla/service/memory_space_assignment/allocation.h" #include "xla/shape_util.h" #include "xla/util.h" @@ -62,7 +62,7 @@ void RuntimeSimulator::InitializeAlternateMemoryMap( } } -float RuntimeSimulator::SimulateElapsedTimeWithoutAsyncCopies( +float RuntimeSimulator::SimulateElapsedTimeWithoutAsyncCopyLikes( const HloLiveRange& hlo_live_range, const AllocationSequence& allocations) { InitializeAlternateMemoryMap(allocations); const auto& instruction_sequence = @@ -99,17 +99,32 @@ float RuntimeSimulator::SimulateElapsedTimeWithoutAsyncCopies( return total_elapsed; } -MemoryTransferDirection GetAsyncCopyDirection( - const HloInstruction* async_copy_start, int64_t alternate_memory_space) { - CHECK_EQ(async_copy_start->opcode(), HloOpcode::kCopyStart); +bool IsAsyncCopyLikeStart(const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kCopyStart || + (instruction->opcode() == HloOpcode::kAsyncStart && + instruction->async_wrapped_instruction()->opcode() == + HloOpcode::kSlice); +} - int64_t operand_memory_space = - async_copy_start->operand(0)->shape().layout().memory_space(); +bool IsAsyncCopyLikeDone(const HloInstruction* instruction) { + return (instruction->opcode() == HloOpcode::kCopyDone || + (instruction->opcode() == HloOpcode::kAsyncDone && + instruction->async_wrapped_instruction()->opcode() == + HloOpcode::kSlice)); +} +MemoryTransferDirection GetAsyncCopyLikeDirection( + const HloInstruction* async_copy_like_start, + int64_t alternate_memory_space) { + CHECK(IsAsyncCopyLikeStart(async_copy_like_start)); + + int64_t operand_memory_space = + async_copy_like_start->operand(0)->shape().layout().memory_space(); // Get all users std::optional output_memory_space; - for (const HloInstruction* user : async_copy_start->users()) { - if (user->opcode() == HloOpcode::kCopyDone) { + for (const HloInstruction* user : async_copy_like_start->users()) { + if (user->opcode() == HloOpcode::kCopyDone || + user->opcode() == HloOpcode::kAsyncDone) { output_memory_space.emplace(user->shape().layout().memory_space()); break; } @@ -129,56 +144,58 @@ MemoryTransferDirection GetAsyncCopyDirection( return MemoryTransferDirection::kUnsupported; } -const std::list& +const std::list& RuntimeSimulator::GetOutstandingReadDefaultQueue() const { return outstanding_read_default_queue_; } -const std::list& +const std::list& RuntimeSimulator::GetOutstandingWriteDefaultQueue() const { return outstanding_write_default_queue_; } const HloInstruction* RuntimeSimulator::RemoveBytesFromQueueIfNotEmpty( - std::list& async_copy_queue, float processed_bytes) { - if (async_copy_queue.empty()) return nullptr; - CHECK_GE(async_copy_queue.front().remaining_bytes_to_transfer, + std::list& async_copy_like_queue, + float processed_bytes) { + if (async_copy_like_queue.empty()) return nullptr; + CHECK_GE(async_copy_like_queue.front().remaining_bytes_to_transfer, processed_bytes); - async_copy_queue.front().remaining_bytes_to_transfer -= processed_bytes; - if (async_copy_queue.front().remaining_bytes_to_transfer == 0.0) { + async_copy_like_queue.front().remaining_bytes_to_transfer -= processed_bytes; + if (async_copy_like_queue.front().remaining_bytes_to_transfer == 0.0) { const HloInstruction* retired_instruction = - async_copy_queue.front().copy_start_inst; - async_copy_queue.pop_front(); + async_copy_like_queue.front().copy_like_start_inst; + async_copy_like_queue.pop_front(); return retired_instruction; } return nullptr; } -float RuntimeSimulator::SimulateAsyncCopyDone( - const HloInstruction* copy_done_instruction) { - const HloInstruction* copy_start_instruction = - copy_done_instruction->operand(0); - MemoryTransferDirection direction = - GetAsyncCopyDirection(copy_start_instruction, alternate_memory_space_); +float RuntimeSimulator::SimulateAsyncCopyLikeDone( + const HloInstruction* copy_like_done_instruction) { + const HloInstruction* copy_like_start_instruction = + copy_like_done_instruction->operand(0); + MemoryTransferDirection direction = GetAsyncCopyLikeDirection( + copy_like_start_instruction, alternate_memory_space_); if (direction == MemoryTransferDirection::kUnsupported) { // The memory access is not a default <-> alternate memory copy. LOG(WARNING) << "Unsupported memory transfer direction for copy-done: " - << copy_done_instruction->ToString(); + << copy_like_done_instruction->ToString(); return 0.0; } - std::list& same_direction_queue = + std::list& same_direction_queue = direction == MemoryTransferDirection::kDefaultToAlternate ? outstanding_read_default_queue_ : outstanding_write_default_queue_; - std::list& opposite_direction_queue = + std::list& opposite_direction_queue = direction == MemoryTransferDirection::kDefaultToAlternate ? outstanding_write_default_queue_ : outstanding_read_default_queue_; - if (absl::c_find_if( - same_direction_queue, [&](const OutstandingAsyncCopy& async_copy) { - return async_copy.copy_start_inst == copy_start_instruction; - }) == same_direction_queue.end()) { + if (absl::c_find_if(same_direction_queue, + [&](const OutstandingAsyncCopyLike& async_copy_like) { + return async_copy_like.copy_like_start_inst == + copy_like_start_instruction; + }) == same_direction_queue.end()) { // The copy has already finished; thus, the copy-done takes no time. return 0.0; } @@ -186,7 +203,7 @@ float RuntimeSimulator::SimulateAsyncCopyDone( // Each iteration of the while loop simulates transferring a number of // bytes from each queue that is equal to the smaller of the two elements // at the front of each queue. If that causes us to finish a copy in the - // same_direction_queue, and that copy is the copy_done_instruction, we + // same_direction_queue, and that copy is the copy_like_done_instruction, we // break the loop. float elapsed_time = 0.0; const HloInstruction* retired_instruction_in_same_direction_queue = nullptr; @@ -211,7 +228,7 @@ float RuntimeSimulator::SimulateAsyncCopyDone( retired_instruction_in_same_direction_queue = RemoveBytesFromQueueIfNotEmpty(same_direction_queue, bytes_to_process); } while (retired_instruction_in_same_direction_queue != - copy_start_instruction); + copy_like_start_instruction); return elapsed_time; }; @@ -227,22 +244,22 @@ float RuntimeSimulator::SimulateComputeInstruction( *instruction, operands_in_alternate_memory, outputs_in_alternate_memory); - // Execute the outstanding async copy in the idle time. - ProcessAsyncCopiesInIdleTime(default_memory_idle_time); + // Execute the outstanding async copy likes in the idle time. + ProcessAsyncCopyLikesInIdleTime(default_memory_idle_time); float inst_elapsed = cost_analysis_->GetInstructionElapsedInAlternateMemory( *instruction, operands_in_alternate_memory, outputs_in_alternate_memory); return inst_elapsed; } -void RuntimeSimulator::ProcessAsyncCopiesInIdleTime(float time) { +void RuntimeSimulator::ProcessAsyncCopyLikesInIdleTime(float time) { if (time <= 0.0) { return; } float remaining_simulation_time = time; // This loop simulates the execution of the front memory requests in the // read and/or write queues. The loop terminates when the remaining time is - // exhausted or there are no more outstanding async copies. + // exhausted or there are no more outstanding async copy likes. while ((!outstanding_read_default_queue_.empty() || !outstanding_write_default_queue_.empty()) && remaining_simulation_time > 0.0) { @@ -300,34 +317,33 @@ float RuntimeSimulator::SimulateElapsedTime( } if (instruction->parent()->IsAsyncComputation()) { // We assume the overhead of async computations can be hidden perfectly. - // We plan to integrate the async copy overhead analysis later - // (b/351913186). continue; } - if (instruction->opcode() == HloOpcode::kCopyStart) { + if (IsAsyncCopyLikeStart(instruction)) { // Try to categorize the async copy instruction into // read-from-default and write-to-default queues. MemoryTransferDirection direction = - GetAsyncCopyDirection(instruction, alternate_memory_space_); - const Shape& transfer_shape = instruction->operand(0)->shape(); + GetAsyncCopyLikeDirection(instruction, alternate_memory_space_); + const Shape& transfer_shape = + (instruction->opcode() == HloOpcode::kCopyStart) + ? instruction->operand(0)->shape() + : ShapeUtil::GetSubshape(instruction->shape(), + /*index=*/{1}); float transfer_bytes = static_cast( cost_analysis_->base_costs().GetShapeSize(transfer_shape)); if (direction == MemoryTransferDirection::kDefaultToAlternate) { outstanding_read_default_queue_.push_back( - OutstandingAsyncCopy{instruction, transfer_bytes}); + OutstandingAsyncCopyLike{instruction, transfer_bytes}); } else if (direction == MemoryTransferDirection::kAlternateToDefault) { outstanding_write_default_queue_.push_back( - OutstandingAsyncCopy{instruction, transfer_bytes}); + OutstandingAsyncCopyLike{instruction, transfer_bytes}); } else { // The copy does not involve default memory. } - } else if (instruction->opcode() == HloOpcode::kCopyDone) { - inst_elapsed = SimulateAsyncCopyDone(instruction); + } else if (IsAsyncCopyLikeDone(instruction)) { + inst_elapsed = SimulateAsyncCopyLikeDone(instruction); } else { // This branch is for the compute instructions. - // TODO(b/351913186): Plan to add another branch to handle async - // copy instructions caused by slicing. - absl::Span outputs_in_alternate_memory; auto output_it = outputs_in_alternate_memory_map_.find(instruction); if (output_it != outputs_in_alternate_memory_map_.end()) { diff --git a/xla/service/memory_space_assignment/simulator.h b/xla/service/memory_space_assignment/simulator.h index 729b220760f85e..906322b259a275 100644 --- a/xla/service/memory_space_assignment/simulator.h +++ b/xla/service/memory_space_assignment/simulator.h @@ -39,17 +39,19 @@ enum class MemoryTransferDirection { }; // REQUIRES: -// * async_copy must be an async copy-start instruction. -MemoryTransferDirection GetAsyncCopyDirection(const HloInstruction* async_copy, - int64_t alternate_memory_space); +// * async_copy_like_start must be an async copy-start or slice-start +// instruction. +MemoryTransferDirection GetAsyncCopyLikeDirection( + const HloInstruction* async_copy_like_start, + int64_t alternate_memory_space); -// This struct is used to track the outstanding async copy instructions and +// This struct is used to track the outstanding async copy like instructions and // the remaining bytes required to be accessed. -struct OutstandingAsyncCopy { - const HloInstruction* copy_start_inst; +struct OutstandingAsyncCopyLike { + const HloInstruction* copy_like_start_inst; float remaining_bytes_to_transfer; - bool operator==(const OutstandingAsyncCopy& other) const { - return copy_start_inst == other.copy_start_inst && + bool operator==(const OutstandingAsyncCopyLike& other) const { + return copy_like_start_inst == other.copy_like_start_inst && remaining_bytes_to_transfer == other.remaining_bytes_to_transfer; } }; @@ -66,8 +68,9 @@ class RuntimeSimulator { // testing purpose. explicit RuntimeSimulator( CostAnalysis* cost_analysis, int64_t alternate_memory_space, - const std::list& outstanding_read_default_queue, - const std::list& outstanding_write_default_queue) + const std::list& outstanding_read_default_queue, + const std::list& + outstanding_write_default_queue) : cost_analysis_(cost_analysis), alternate_memory_space_(alternate_memory_space), outstanding_read_default_queue_(outstanding_read_default_queue), @@ -77,19 +80,19 @@ class RuntimeSimulator { // This function provides a basic estimate without considering the overhead of // async copies. - float SimulateElapsedTimeWithoutAsyncCopies( + float SimulateElapsedTimeWithoutAsyncCopyLikes( const HloLiveRange& hlo_live_range, const AllocationSequence& allocations); // Returns the time to simulate the hlo_live_range, when we account for the - // waiting time for async copies to finish. + // waiting time for async copy like instructions to finish. // - // To simulate the overhead of async copies, we need to maintain two queues to - // track the outstanding memory access requests that read/write the default - // memory. When we simulate compute, we use any time there is spare bandwidth - // to simulate async memory accesses to default memory. If we get to an async - // copy done, we must wait until it finishes (potentially waiting for copies - // issued before it to finish. + // To simulate the overhead of async copy like instructions, we need to + // maintain two queues to track the outstanding memory access requests that + // read/write the default memory. When we simulate compute, we use any time + // there is spare bandwidth to simulate async memory accesses to default + // memory. If we get to an async copy like done, we must wait until it + // finishes (potentially waiting for copies issued before it to finish. float SimulateElapsedTime(const HloModule* hlo_module, const AllocationSequence& allocations); @@ -97,8 +100,8 @@ class RuntimeSimulator { // time for executing a copy-done instruction. It returns the // elapsed time (in seconds) for executing the copy-done instruction. // - // This function also updates the passed in queues as we complete async copies - // during the simulation. + // This function also updates the passed in queues as we complete async copy + // like instructions during the simulation. // // We simulate the shared bandwidth for default-alternate memory access. // For example, if the copy-done instruction is a default-write memory @@ -106,11 +109,13 @@ class RuntimeSimulator { // outstanding_read_default_queue, then we use half of the bandwidth to // process both requests in parallel. Otherwise, we use the full bandwidth to // process the default-write request. - float SimulateAsyncCopyDone(const HloInstruction* copy_done_instruction); + float SimulateAsyncCopyLikeDone( + const HloInstruction* copy_like_done_instruction); - const std::list& GetOutstandingReadDefaultQueue() const; + const std::list& GetOutstandingReadDefaultQueue() + const; - const std::list& GetOutstandingWriteDefaultQueue() + const std::list& GetOutstandingWriteDefaultQueue() const; // This is an auxiliary function for simulating the execution @@ -140,17 +145,18 @@ class RuntimeSimulator { // process), the function returns the instruction and pop it from the queue. // Otherwise, it returns nullptr. const HloInstruction* RemoveBytesFromQueueIfNotEmpty( - std::list& async_copy_queue, float processed_bytes); + std::list& async_copy_like_queue, + float processed_bytes); // This is an auxiliary function which simulates the process of draining // the memory access queues in a given amount of time (seconds). If both // outstanding_*_default_queues are non-empty, they share bandwidth. If one of // the queues is empty and the other is not, it gets the full bandwdith. - void ProcessAsyncCopiesInIdleTime(float time); + void ProcessAsyncCopyLikesInIdleTime(float time); int64_t alternate_memory_space_; - std::list outstanding_read_default_queue_; - std::list outstanding_write_default_queue_; + std::list outstanding_read_default_queue_; + std::list outstanding_write_default_queue_; absl::flat_hash_map> outputs_in_alternate_memory_map_; absl::flat_hash_map #include -#include #include #include +#include #include #include +#include "absl/container/flat_hash_map.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/utils/hlo_live_range.h" +#include "xla/service/heap_simulator/heap_simulator.h" #include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_cost_analysis.h" +#include "xla/service/hlo_value.h" #include "xla/service/memory_space_assignment/allocation.h" #include "xla/service/memory_space_assignment/cost_analysis.h" #include "xla/shape.h" @@ -61,6 +64,25 @@ class MemorySpaceAssignmentSimulatorTest : public HloTestBase { protected: absl::Status Initialize(absl::string_view hlo_string) { TF_ASSIGN_OR_RETURN(module_, ParseAndReturnVerifiedModule(hlo_string)); + for (HloInstruction* inst : module_->entry_computation()->instructions()) { + instruction_map_[inst->name()] = inst; + // Construct an allocation for the instruction if it is in the alternate + // memory. + if (inst->shape().has_layout() && + inst->shape().layout().memory_space() == kAlternateMemorySpace) { + std::unique_ptr allocation = + std::make_unique( + HloPosition{inst, {}}, + memory_space_assignment::MemorySpace::kAlternate, + HeapSimulator::Chunk::FromOffsetSize(-1, -1), + /*start_time=*/0, + /*end_time=*/1, /*is_scoped_allocation=*/false); + for (HloInstruction* user : inst->users()) { + allocation->AddUse(HloUse{user, 0}); + } + allocations_.push_back(std::move(allocation)); + } + } HloCostAnalysis::Options tpu_device_options; tpu_device_options.shape_size = ShapeSize; // Assume 1 FLOP per second for testing. @@ -88,6 +110,7 @@ class MemorySpaceAssignmentSimulatorTest : public HloTestBase { cost_analysis_.get(), kAlternateMemorySpace); return absl::OkStatus(); } + absl::flat_hash_map instruction_map_; std::unique_ptr hlo_cost_analysis_; std::unique_ptr hlo_cost_analysis_costs_; @@ -132,7 +155,7 @@ TEST_F(MemorySpaceAssignmentSimulatorTest, SingleLayerLoop) { // tuple(%constant.0): 8 * 1 // %greater: 9 * 42 // %loop_result: 8 * 42 - EXPECT_EQ(runtime_simulator_->SimulateElapsedTimeWithoutAsyncCopies( + EXPECT_EQ(runtime_simulator_->SimulateElapsedTimeWithoutAsyncCopyLikes( *hlo_live_range_, allocations_), 1226); EXPECT_EQ( @@ -188,7 +211,7 @@ TEST_F(MemorySpaceAssignmentSimulatorTest, NestedLayerLoop) { // Thus, the total overhead of the while_outer is 1226 * 27 + 12 * 27 + 8 * 1 // + 9 * 27 + 8 * 27 = 33893 - EXPECT_EQ(runtime_simulator_->SimulateElapsedTimeWithoutAsyncCopies( + EXPECT_EQ(runtime_simulator_->SimulateElapsedTimeWithoutAsyncCopyLikes( *hlo_live_range_, allocations_), 33893); EXPECT_EQ( @@ -211,9 +234,9 @@ TEST_F(MemorySpaceAssignmentSimulatorTest, SingleAsyncCopyOverhead) { // Since the HLO does not contain memory access, pass an empty allocation // sequence for test. memory_space_assignment::AllocationSequence allocations; - // The SimulateElapsedTimeWithoutAsyncCopies should not include the overhead - // of async copies. - EXPECT_EQ(runtime_simulator_->SimulateElapsedTimeWithoutAsyncCopies( + // The SimulateElapsedTimeWithoutAsyncCopyLikes should not include the + // overhead of async copies. + EXPECT_EQ(runtime_simulator_->SimulateElapsedTimeWithoutAsyncCopyLikes( *hlo_live_range_, allocations_), 0); // The expected elapsed time is 1024 * 2048 * 4 / 1 = 8388608. @@ -247,35 +270,85 @@ TEST_F(MemorySpaceAssignmentSimulatorTest, AsyncCopyWithComputationOverhead) { runtime_simulator_->SimulateElapsedTime(module_.get(), allocations_), 48); } -class SimulateAsyncCopyDoneTest : public MemorySpaceAssignmentSimulatorTest { +TEST_F(MemorySpaceAssignmentSimulatorTest, SingleAsyncSliceCopyOverhead) { + absl::string_view hlo_string = + R"(HloModule module, is_scheduled=true + ENTRY Entry { + param_0 = f32[3072,2048] parameter(0) + slice-start = ((f32[3072,2048]), f32[768,2048]{1,0:S(1)}, s32[]) slice-start(f32[3072,2048] param_0), slice={[1536:2304], [0:2048]} + ROOT slice-done = f32[768,2048]{1,0:T(8,128)S(1)} slice-done(((f32[3072,2048]), f32[768,2048]{1,0:S(1)}, s32[]) slice-start) + } + )"; + TF_ASSERT_OK(Initialize(hlo_string)); + + memory_space_assignment::AllocationSequence allocations; + // The expected elapsed time is 768 * 2048 * 4 / 1 = 6291456. + float expected_elapsed_time = 6291456; + + EXPECT_EQ( + runtime_simulator_->SimulateElapsedTime(module_.get(), allocations_), + expected_elapsed_time); +} + +TEST_F(MemorySpaceAssignmentSimulatorTest, + AsyncCopyAndAsyncSliceAndComputeOverhead) { + absl::string_view hlo_string = + R"(HloModule module, is_scheduled=true + ENTRY Entry { + param_0 = f32[2048] parameter(0) + param_1 = f32[64] parameter(1) + param_2 = f32[128] parameter(2) + slice-start = ((f32[2048]), f32[64]{0:S(1)}, s32[]) slice-start(f32[2048] param_0), slice={[0:64]} + copy-start = (f32[64]{0:S(1)}, f32[64], u32[]) copy-start(f32[64] param_1) + slice-done = f32[64]{0:S(1)} slice-done(((f32[2048]), f32[64]{0:S(1)}, s32[]) slice-start) + copy-done = f32[64]{0:S(1)} copy-done(copy-start) + copy-start-overlap = (f32[128]{0:S(1)}, f32[128], u32[]) copy-start(f32[128] param_2) + add = f32[64]{0:S(1)} add(slice-done, copy-done) + ROOT copy-done-overlap = f32[128]{0:S(1)} copy-done(copy-start-overlap) + } + )"; + TF_ASSERT_OK(Initialize(hlo_string)); + + // The overhead of each instruction is: + // slice-done: 64 * 4 / 1 = 256 sec (default memory access) + // copy-done: 64 * 4 /1 = 256 sec (default memory access) + // add: 3 * 64 * 4 / 2 = 384 sec (alternate memory access) + // Since add does not access default memory, we can use process 384 bytes in + // copy-start-overlap. + // copy-done-overlap: (128 * 4 - 384) / 1 = 128 sec (default memory access) + EXPECT_EQ( + runtime_simulator_->SimulateElapsedTime(module_.get(), allocations_), + 1024); +} + +class SimulateAsyncCopyLikeDoneTest + : public MemorySpaceAssignmentSimulatorTest { protected: absl::Status Initialize(absl::string_view hlo_string) { TF_RETURN_IF_ERROR( MemorySpaceAssignmentSimulatorTest::Initialize(hlo_string)); - for (const HloInstruction* inst : - module_->entry_computation()->instructions()) { - instruction_map_[inst->name()] = inst; - if (inst->name() == "copy-start.1") { - outstanding_read_default_queue_.push_back( - memory_space_assignment::OutstandingAsyncCopy{inst, 512}); - } else if (inst->name() == "copy-start.2") { - outstanding_write_default_queue_.push_back( - memory_space_assignment::OutstandingAsyncCopy{inst, 128}); - } + if (instruction_map_.contains("copy-start.1")) { + outstanding_read_default_queue_.push_back( + memory_space_assignment::OutstandingAsyncCopyLike{ + instruction_map_["copy-start.1"], 512}); + } + if (instruction_map_.contains("copy-start.2")) { + outstanding_write_default_queue_.push_back( + memory_space_assignment::OutstandingAsyncCopyLike{ + instruction_map_["copy-start.2"], 128}); } runtime_simulator_ = std::make_unique( cost_analysis_.get(), kAlternateMemorySpace, outstanding_read_default_queue_, outstanding_write_default_queue_); return absl::OkStatus(); } - std::map instruction_map_; - std::list + std::list outstanding_read_default_queue_; - std::list + std::list outstanding_write_default_queue_; }; -TEST_F(SimulateAsyncCopyDoneTest, AsyncCopyAlreadyCompleted) { +TEST_F(SimulateAsyncCopyLikeDoneTest, AsyncCopyAlreadyCompleted) { absl::string_view hlo_string = R"(HloModule module, is_scheduled=true ENTRY Entry { @@ -289,21 +362,21 @@ TEST_F(SimulateAsyncCopyDoneTest, AsyncCopyAlreadyCompleted) { const HloInstruction* copy_done_inst = instruction_map_["copy-done.1"]; // Process the copy-start.1 - runtime_simulator_->SimulateAsyncCopyDone(copy_done_inst); + runtime_simulator_->SimulateAsyncCopyLikeDone(copy_done_inst); // There should be no request in the read/write queues. EXPECT_THAT(runtime_simulator_->GetOutstandingReadDefaultQueue(), IsEmpty()); EXPECT_THAT(runtime_simulator_->GetOutstandingWriteDefaultQueue(), IsEmpty()); // The function should return 0 for requests that are already completed. float elapsed_time_for_completed_copy = - runtime_simulator_->SimulateAsyncCopyDone(copy_done_inst); + runtime_simulator_->SimulateAsyncCopyLikeDone(copy_done_inst); EXPECT_EQ(elapsed_time_for_completed_copy, 0); // There should be no request in the read/write queues. EXPECT_THAT(runtime_simulator_->GetOutstandingReadDefaultQueue(), IsEmpty()); EXPECT_THAT(runtime_simulator_->GetOutstandingWriteDefaultQueue(), IsEmpty()); } -TEST_F(SimulateAsyncCopyDoneTest, AsyncCopyFullBandwidth) { +TEST_F(SimulateAsyncCopyLikeDoneTest, AsyncCopyFullBandwidth) { absl::string_view hlo_string = R"(HloModule module, is_scheduled=true ENTRY Entry { @@ -318,7 +391,7 @@ TEST_F(SimulateAsyncCopyDoneTest, AsyncCopyFullBandwidth) { // The elapsed time for copy-done.1 is 128 * 4 / 1 = 512. float copy_done_elapsed_time = - runtime_simulator_->SimulateAsyncCopyDone(copy_done_inst); + runtime_simulator_->SimulateAsyncCopyLikeDone(copy_done_inst); EXPECT_EQ(copy_done_elapsed_time, 512); // There should be no request in the read/write queues. @@ -326,7 +399,7 @@ TEST_F(SimulateAsyncCopyDoneTest, AsyncCopyFullBandwidth) { EXPECT_THAT(runtime_simulator_->GetOutstandingWriteDefaultQueue(), IsEmpty()); } -TEST_F(SimulateAsyncCopyDoneTest, AsyncCopySharedBandwidth) { +TEST_F(SimulateAsyncCopyLikeDoneTest, AsyncCopySharedBandwidth) { absl::string_view hlo_string = R"(HloModule module, is_scheduled=true ENTRY Entry { @@ -348,19 +421,20 @@ TEST_F(SimulateAsyncCopyDoneTest, AsyncCopySharedBandwidth) { // only use half bandwidth to access default memory. Thus, the elapsed time is // 32 * 4 / 0.5 = 256 float copy_done_2_elapsed_time = - runtime_simulator_->SimulateAsyncCopyDone(copy_done_2_inst); + runtime_simulator_->SimulateAsyncCopyLikeDone(copy_done_2_inst); EXPECT_EQ(copy_done_2_elapsed_time, 256); // The only write request (copy-start.2) should be completed. EXPECT_THAT(runtime_simulator_->GetOutstandingWriteDefaultQueue(), IsEmpty()); // The read request has (128-32)*4 bytes left to process. - EXPECT_THAT(runtime_simulator_->GetOutstandingReadDefaultQueue(), - ElementsAreArray({memory_space_assignment::OutstandingAsyncCopy{ - copy_start_1_inst, 384}})); + EXPECT_THAT( + runtime_simulator_->GetOutstandingReadDefaultQueue(), + ElementsAreArray({memory_space_assignment::OutstandingAsyncCopyLike{ + copy_start_1_inst, 384}})); } -TEST_F(SimulateAsyncCopyDoneTest, AsyncCopyTransferPartialProcess) { +TEST_F(SimulateAsyncCopyLikeDoneTest, AsyncCopyTransferPartialProcess) { absl::string_view hlo_string = R"(HloModule module, is_scheduled=true ENTRY Entry { @@ -381,7 +455,7 @@ TEST_F(SimulateAsyncCopyDoneTest, AsyncCopyTransferPartialProcess) { // Execute copy-done.2. float copy_done_2_elapsed_time = - runtime_simulator_->SimulateAsyncCopyDone(copy_done_2_inst); + runtime_simulator_->SimulateAsyncCopyLikeDone(copy_done_2_inst); // For copy-done.2, it requires to transfer 32*4 bytes // default-write request. At the same time, there is a 128*4 bytes // default-read request in the queue for copy-start.1. So the @@ -389,14 +463,15 @@ TEST_F(SimulateAsyncCopyDoneTest, AsyncCopyTransferPartialProcess) { EXPECT_EQ(copy_done_2_elapsed_time, 256); // In parallel with copy-done.2, copy-start.1 is also being processed. // So the remaining bytes should be 128*4 - 32*4 = 384. - EXPECT_THAT(runtime_simulator_->GetOutstandingReadDefaultQueue(), - ElementsAreArray({memory_space_assignment::OutstandingAsyncCopy{ - copy_start_1_inst, 384}})); + EXPECT_THAT( + runtime_simulator_->GetOutstandingReadDefaultQueue(), + ElementsAreArray({memory_space_assignment::OutstandingAsyncCopyLike{ + copy_start_1_inst, 384}})); EXPECT_THAT(runtime_simulator_->GetOutstandingWriteDefaultQueue(), IsEmpty()); // Execute copy-done.1. float copy_done_1_elapsed_time = - runtime_simulator_->SimulateAsyncCopyDone(copy_done_1_inst); + runtime_simulator_->SimulateAsyncCopyLikeDone(copy_done_1_inst); // The copy-done.1 is the only request in the read-queue, and there is no // request in the write-queue. Thus, it can use the full bandwidth. The // elapsed time is 384 / 1 = 384. @@ -406,7 +481,7 @@ TEST_F(SimulateAsyncCopyDoneTest, AsyncCopyTransferPartialProcess) { EXPECT_THAT(runtime_simulator_->GetOutstandingWriteDefaultQueue(), IsEmpty()); } -TEST_F(SimulateAsyncCopyDoneTest, +TEST_F(SimulateAsyncCopyLikeDoneTest, SimulateComputeInstructionWithSingleAsyncCopy) { absl::string_view hlo_string = R"(HloModule module, is_scheduled=true @@ -431,14 +506,15 @@ TEST_F(SimulateAsyncCopyDoneTest, // requires 32 and 256 secs respectively. Thus, it is default memory access // dominated, which does not have idle time to process the async copy. EXPECT_EQ(compute_elapsed_time, 256); - EXPECT_THAT(runtime_simulator_->GetOutstandingReadDefaultQueue(), - ElementsAreArray({memory_space_assignment::OutstandingAsyncCopy{ - copy_start_1_inst, 512}})); + EXPECT_THAT( + runtime_simulator_->GetOutstandingReadDefaultQueue(), + ElementsAreArray({memory_space_assignment::OutstandingAsyncCopyLike{ + copy_start_1_inst, 512}})); EXPECT_THAT(runtime_simulator_->GetOutstandingWriteDefaultQueue(), IsEmpty()); } -TEST_F(SimulateAsyncCopyDoneTest, +TEST_F(SimulateAsyncCopyLikeDoneTest, SimulateComputeInstructionWithSharedBandwidth) { absl::string_view hlo_string = R"(HloModule module, is_scheduled=true @@ -469,16 +545,19 @@ TEST_F(SimulateAsyncCopyDoneTest, // 64 secs for alternate memory access + 128 secs for default memory access EXPECT_EQ(compute_elapsed_time, 192); - EXPECT_THAT(runtime_simulator_->GetOutstandingReadDefaultQueue(), - ElementsAreArray({memory_space_assignment::OutstandingAsyncCopy{ - copy_start_1_inst, 480}})); + EXPECT_THAT( + runtime_simulator_->GetOutstandingReadDefaultQueue(), + ElementsAreArray({memory_space_assignment::OutstandingAsyncCopyLike{ + copy_start_1_inst, 480}})); - EXPECT_THAT(runtime_simulator_->GetOutstandingWriteDefaultQueue(), - ElementsAreArray({memory_space_assignment::OutstandingAsyncCopy{ - copy_start_2_inst, 96}})); + EXPECT_THAT( + runtime_simulator_->GetOutstandingWriteDefaultQueue(), + ElementsAreArray({memory_space_assignment::OutstandingAsyncCopyLike{ + copy_start_2_inst, 96}})); } -TEST_F(SimulateAsyncCopyDoneTest, SimulateComputeInstructionWithFullBandwidth) { +TEST_F(SimulateAsyncCopyLikeDoneTest, + SimulateComputeInstructionWithFullBandwidth) { absl::string_view hlo_string = R"(HloModule module, is_scheduled=true ENTRY Entry { @@ -504,13 +583,15 @@ TEST_F(SimulateAsyncCopyDoneTest, SimulateComputeInstructionWithFullBandwidth) { // 64 secs for alternate memory access + 128 secs for default memory access EXPECT_EQ(compute_elapsed_time, 192); - EXPECT_THAT(runtime_simulator_->GetOutstandingReadDefaultQueue(), - ElementsAreArray({memory_space_assignment::OutstandingAsyncCopy{ - copy_start_1_inst, 448}})); + EXPECT_THAT( + runtime_simulator_->GetOutstandingReadDefaultQueue(), + ElementsAreArray({memory_space_assignment::OutstandingAsyncCopyLike{ + copy_start_1_inst, 448}})); EXPECT_THAT(runtime_simulator_->GetOutstandingWriteDefaultQueue(), IsEmpty()); } -TEST_F(SimulateAsyncCopyDoneTest, SimulateComputeInstructionWithEmptyQueues) { +TEST_F(SimulateAsyncCopyLikeDoneTest, + SimulateComputeInstructionWithEmptyQueues) { absl::string_view hlo_string = R"(HloModule module, is_scheduled=true ENTRY Entry { From 25b6510062cffd3a859070ab71a30c5c9bd779fe Mon Sep 17 00:00:00 2001 From: Subhankar Shah Date: Thu, 25 Jul 2024 17:40:23 -0700 Subject: [PATCH 178/376] [XLA:MSA] Add comments to MsaAlgorithm class indicating its relationship with HeapSimulator class and how buffer_intervals_ are populated. PiperOrigin-RevId: 656159730 --- xla/service/memory_space_assignment/algorithm.cc | 4 ++++ xla/service/memory_space_assignment/algorithm.h | 7 +++++++ 2 files changed, 11 insertions(+) diff --git a/xla/service/memory_space_assignment/algorithm.cc b/xla/service/memory_space_assignment/algorithm.cc index d18a69bc5635dd..4611453fef48ac 100644 --- a/xla/service/memory_space_assignment/algorithm.cc +++ b/xla/service/memory_space_assignment/algorithm.cc @@ -1527,6 +1527,10 @@ void MsaAlgorithm::CreateAllocationValuesForJointProcessedIntervals( } absl::StatusOr> MsaAlgorithm::Finish() { + // Note: Memory Space Assignment creates a HeapSimulator and passes an + // MsaAlgorithm object to it. buffer_intervals_ is populated by calling the + // Alloc(), Free() and ShareWith() methods on the MsaAlgorithm object in + // HeapSimulator. if (options_.autotuning_config.has_value()) { CHECK_EQ((*options_.autotuning_config).size(), buffer_intervals_.size()); } diff --git a/xla/service/memory_space_assignment/algorithm.h b/xla/service/memory_space_assignment/algorithm.h index d8225e65b3d7af..5e2073bcc183ec 100644 --- a/xla/service/memory_space_assignment/algorithm.h +++ b/xla/service/memory_space_assignment/algorithm.h @@ -372,6 +372,13 @@ class AsynchronousCopyResource { // This class inherits from GlobalDecreasingSizeBestFitHeap with a notion of // maximum size. +// +// Note: Memory space assignment (MSA) creates an MsaAlgorithm object and passes +// it to the HeapSimulator. The HeapSimulator calls Alloc(), Free() and +// ShareWith() on the MsaAlgorithm object to create buffer intervals (populate +// buffer_intervals_), these methods are inherited from +// GlobalDecreasingSizeBestFitHeap. The HeapSimulator finally calls the Finish() +// method which is overridden in this class. class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { public: using HloPositionOrUse = std::variant; From e54c144e5e67d57d397d89d9dd300e6f9222ef05 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 25 Jul 2024 17:56:29 -0700 Subject: [PATCH 179/376] Remove unnecessary parentheses to fix build breakage. PiperOrigin-RevId: 656164228 --- xla/stream_executor/cuda/cuda_driver.cc | 138 ++++++++++++------------ 1 file changed, 69 insertions(+), 69 deletions(-) diff --git a/xla/stream_executor/cuda/cuda_driver.cc b/xla/stream_executor/cuda/cuda_driver.cc index 763408e4c35e8b..866c1ff7131462 100644 --- a/xla/stream_executor/cuda/cuda_driver.cc +++ b/xla/stream_executor/cuda/cuda_driver.cc @@ -503,8 +503,8 @@ absl::Status GpuDriver::GraphInstantiate(CUgraphExec* exec, CUgraph graph, return cuda::ToStatus(cuGraphInstantiate(exec, graph, cu_flags), "Failed to instantiate CUDA graph"); #else - return (cuda::ToStatus(cuGraphInstantiate(exec, graph, nullptr, nullptr, 0), - "Failed to instantiate CUDA graph"); + return cuda::ToStatus(cuGraphInstantiate(exec, graph, nullptr, nullptr, 0), + "Failed to instantiate CUDA graph"); #endif // CUDA_VERSION >= 12000 } @@ -1221,75 +1221,75 @@ absl::Status GpuDriver::LoadPtx(GpuContext* context, const char* ptx_contents, CUmodule* module) { absl::Notification notification; absl::Status ret = absl::OkStatus(); - GetDriverExecutor()->Schedule([context, ptx_contents, module, &ret, - ¬ification]() { - ScopedActivateContext activation(context); - void* ptx_data = const_cast(ptx_contents); - static const unsigned int kLogBufferBytesLimit = 1024; - unsigned int error_log_buffer_bytes = kLogBufferBytesLimit; - unsigned int info_log_buffer_bytes = kLogBufferBytesLimit; - absl::InlinedVector error_log_buffer(error_log_buffer_bytes); - absl::InlinedVector info_log_buffer(info_log_buffer_bytes); - bool log_verbose = true; - CUjit_option options[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, - CU_JIT_ERROR_LOG_BUFFER, - CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, - CU_JIT_INFO_LOG_BUFFER, CU_JIT_LOG_VERBOSE}; - // Note that the driver API wants the contents of this values to be stored - // in an array of void*s, so we coerce them accordingly. - void* option_values[] = { - absl::bit_cast(uintptr_t(error_log_buffer_bytes)), - absl::bit_cast(error_log_buffer.data()), - absl::bit_cast(uintptr_t(info_log_buffer_bytes)), - absl::bit_cast(info_log_buffer.data()), - absl::bit_cast(uintptr_t(log_verbose))}; - CHECK(TF_ARRAYSIZE(options) == TF_ARRAYSIZE(option_values)); - - absl::Status status; - { - // TODO(leary) Need to see if NVIDIA can expunge the leakiness in their - // module loading: see http://b/13248943 - absl::LeakCheckDisabler disabler; - status = cuda::ToStatus(cuModuleLoadDataEx( - module, ptx_data, TF_ARRAYSIZE(options), options, option_values)); - } + GetDriverExecutor()->Schedule( + [context, ptx_contents, module, &ret, ¬ification]() { + ScopedActivateContext activation(context); + void* ptx_data = const_cast(ptx_contents); + static const unsigned int kLogBufferBytesLimit = 1024; + unsigned int error_log_buffer_bytes = kLogBufferBytesLimit; + unsigned int info_log_buffer_bytes = kLogBufferBytesLimit; + absl::InlinedVector error_log_buffer(error_log_buffer_bytes); + absl::InlinedVector info_log_buffer(info_log_buffer_bytes); + bool log_verbose = true; + CUjit_option options[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, + CU_JIT_ERROR_LOG_BUFFER, + CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, + CU_JIT_INFO_LOG_BUFFER, CU_JIT_LOG_VERBOSE}; + // Note that the driver API wants the contents of this values to be + // stored in an array of void*s, so we coerce them accordingly. + void* option_values[] = { + absl::bit_cast(uintptr_t(error_log_buffer_bytes)), + absl::bit_cast(error_log_buffer.data()), + absl::bit_cast(uintptr_t(info_log_buffer_bytes)), + absl::bit_cast(info_log_buffer.data()), + absl::bit_cast(uintptr_t(log_verbose))}; + CHECK(TF_ARRAYSIZE(options) == TF_ARRAYSIZE(option_values)); + + absl::Status status; + { + // TODO(leary) Need to see if NVIDIA can expunge the leakiness in + // their module loading: see http://b/13248943 + absl::LeakCheckDisabler disabler; + status = cuda::ToStatus(cuModuleLoadDataEx( + module, ptx_data, TF_ARRAYSIZE(options), options, option_values)); + } - // The PTX JIT mutates the values in the option values array to reflect the - // size of the logs it output; now that we've made the call, read the values - // back out. - error_log_buffer_bytes = reinterpret_cast(option_values[0]); - info_log_buffer_bytes = reinterpret_cast(option_values[2]); - CHECK_LE(error_log_buffer_bytes, kLogBufferBytesLimit); - CHECK_LE(info_log_buffer_bytes, kLogBufferBytesLimit); - - if (!status.ok()) { - LOG(ERROR) << "failed to load PTX text as a module: " << status; - // As a precaution for null termination of the API-provided value, ensure - // that at least the last byte is null. - error_log_buffer[error_log_buffer_bytes ? error_log_buffer_bytes - 1 - : 0] = '\0'; - LOG(ERROR) << "error log buffer (" << error_log_buffer_bytes - << " bytes): " << error_log_buffer.data(); - if (absl::StrContains(error_log_buffer.data(), - "Register allocation failed")) { - ret = absl::ResourceExhaustedError( - absl::StrFormat("Failed to load PTX text as a module (register " - "allocation failed): %s", - status.ToString())); - } else { - ret = status; - } - notification.Notify(); - return; - } + // The PTX JIT mutates the values in the option values array to reflect + // the size of the logs it output; now that we've made the call, read + // the values back out. + error_log_buffer_bytes = reinterpret_cast(option_values[0]); + info_log_buffer_bytes = reinterpret_cast(option_values[2]); + CHECK_LE(error_log_buffer_bytes, kLogBufferBytesLimit); + CHECK_LE(info_log_buffer_bytes, kLogBufferBytesLimit); + + if (!status.ok()) { + LOG(ERROR) << "failed to load PTX text as a module: " << status; + // As a precaution for null termination of the API-provided value, + // ensure that at least the last byte is null. + error_log_buffer[error_log_buffer_bytes ? error_log_buffer_bytes - 1 + : 0] = '\0'; + LOG(ERROR) << "error log buffer (" << error_log_buffer_bytes + << " bytes): " << error_log_buffer.data(); + if (absl::StrContains(error_log_buffer.data(), + "Register allocation failed")) { + ret = absl::ResourceExhaustedError( + absl::StrFormat("Failed to load PTX text as a module (register " + "allocation failed): %s", + status.ToString())); + } else { + ret = status; + } + notification.Notify(); + return; + } - VLOG(3) << "PTX compilation info log (" << info_log_buffer_bytes - << " bytes): " << info_log_buffer.data(); - VLOG(3) << "PTX compilation error log (" << error_log_buffer_bytes - << " bytes): " << error_log_buffer.data(); - CHECK(module != nullptr); - notification.Notify(); - }); + VLOG(3) << "PTX compilation info log (" << info_log_buffer_bytes + << " bytes): " << info_log_buffer.data(); + VLOG(3) << "PTX compilation error log (" << error_log_buffer_bytes + << " bytes): " << error_log_buffer.data(); + CHECK(module != nullptr); + notification.Notify(); + }); notification.WaitForNotification(); return ret; From f4c38381892367ced101d58c580cc65d5d651eef Mon Sep 17 00:00:00 2001 From: Gregory Pataky Date: Thu, 25 Jul 2024 18:19:19 -0700 Subject: [PATCH 180/376] Add IsSubnormal utility function for exhaustive tests PiperOrigin-RevId: 656172525 --- .../exhaustive/exhaustive_op_test_utils.cc | 14 ++++++++++ .../exhaustive/exhaustive_op_test_utils.h | 27 +++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/xla/tests/exhaustive/exhaustive_op_test_utils.cc b/xla/tests/exhaustive/exhaustive_op_test_utils.cc index 0ab2e679006ec4..4f606a11dc0220 100644 --- a/xla/tests/exhaustive/exhaustive_op_test_utils.cc +++ b/xla/tests/exhaustive/exhaustive_op_test_utils.cc @@ -39,6 +39,20 @@ limitations under the License. namespace xla { namespace exhaustive_op_test { +bool IsSubnormalReal(xla::complex64 value) { return IsSubnormal(value.real()); } + +bool IsSubnormalReal(xla::complex128 value) { + return IsSubnormal(value.real()); +} + +bool IsSubnormalImaginary(xla::complex64 value) { + return IsSubnormal(value.imag()); +} + +bool IsSubnormalImaginary(xla::complex128 value) { + return IsSubnormal(value.imag()); +} + // For f64, f32, f16, and bf16, we need 17, 9, 5, and 4 decimal places of // precision to be guaranteed that we're printing the full number. // diff --git a/xla/tests/exhaustive/exhaustive_op_test_utils.h b/xla/tests/exhaustive/exhaustive_op_test_utils.h index 2a2cf913be4461..80c69703dfb96c 100644 --- a/xla/tests/exhaustive/exhaustive_op_test_utils.h +++ b/xla/tests/exhaustive/exhaustive_op_test_utils.h @@ -55,6 +55,33 @@ limitations under the License. namespace xla { namespace exhaustive_op_test { +// Determines if the real component of the complex number is subnormal. +// +// See also IsSubnormal to check if either component is subnormal. +bool IsSubnormalReal(xla::complex64); +bool IsSubnormalReal(xla::complex128); + +// Determines if the imaginary component of the complex number is subnormal. +// +// See also IsSubnormal to check if either component is subnormal. +bool IsSubnormalImaginary(xla::complex64); +bool IsSubnormalImaginary(xla::complex128); + +// Determines if the NativeT is subnormal. +// +// For complex numbers, this will return true if either real or imaginary +// component is subnormal. See IsSubnormalReal and IsSubnormalImaginary if you +// only care about one component. +template +bool IsSubnormal(NativeT value) { + if constexpr (std::is_same_v || + std::is_same_v) { + return IsSubnormalReal(value) || IsSubnormalImaginary(value); + } else { + return std::fpclassify(value) == FP_SUBNORMAL; + } +} + struct ErrorSpec { double abs_err = 0; double rel_err = 0; From 18c2fe7dd1ffa7416f54472280461e6dc8e26b36 Mon Sep 17 00:00:00 2001 From: Farzin Houshmand Date: Thu, 25 Jul 2024 19:09:49 -0700 Subject: [PATCH 181/376] [XLA:UNSTACKER] Changes to unstacker to improve performance. For DSFusionPattern and NestedDSFusionPattern: 1) Bring out the dynamic-slice inside the unstacking fusion computation. 2) Convert dynamic-slice instructions to slices (this is done to allow MSA to convert then to async-slice later) 3) Keep the bitcast fusion right next to the user inside the loop PiperOrigin-RevId: 656188199 --- xla/service/hlo_unstacker.cc | 231 ++++++++++++++++++++++-------- xla/service/hlo_unstacker_test.cc | 31 ++++ 2 files changed, 203 insertions(+), 59 deletions(-) diff --git a/xla/service/hlo_unstacker.cc b/xla/service/hlo_unstacker.cc index 9ee40f71fba79c..cb969c56357845 100644 --- a/xla/service/hlo_unstacker.cc +++ b/xla/service/hlo_unstacker.cc @@ -51,9 +51,28 @@ limitations under the License. namespace xla { namespace { +// TODO: b/352400145 - Unify the patterns, handlers and their type into a class +// or struct. +enum class PatternType { + DSFusionPattern, + NestedDSFusionPattern, + Other, +}; + +static std::string PatternTypeToString(PatternType pattern_type) { + switch (pattern_type) { + case PatternType::DSFusionPattern: + return "DSFusionPattern"; + case PatternType::NestedDSFusionPattern: + return "NestedDSFusionPattern"; + case PatternType::Other: + return "Other"; + } +} + // Holds the information about custom unstacking patterns. struct PatternInfo { - std::string name; + PatternType type; std::vector unstacked_instrs; const HloInstruction* instr; Shape unstacked_shape; @@ -61,12 +80,12 @@ struct PatternInfo { std::string ToString() const { if (unstacking_computation == nullptr) { - return absl::StrCat("name: \n\t", name, "\n", "instr: \n\t", - instr->name(), "\n", "shape: \n\t", + return absl::StrCat("type: \n\t", PatternTypeToString(type), "\n", + "instr: \n\t", instr->name(), "\n", "shape: \n\t", unstacked_shape.ToString(true)); } else { - return absl::StrCat("name: \n\t", name, "\n", "instr: \n\t", - instr->name(), "\n", "shape: \n\t", + return absl::StrCat("type: \n\t", PatternTypeToString(type), "\n", + "instr: \n\t", instr->name(), "\n", "shape: \n\t", unstacked_shape.ToString(true), "\n", "comp: \n", unstacking_computation->name()); } @@ -136,6 +155,7 @@ class UnstackerTransformer { continue; } PatternInfo& pattern_info = stacked_user.value(); + pattern_type_ = pattern_info.type; VLOG(3) << "PatternInfo:" << "\n" << pattern_info.ToString(); if (pattern_info.unstacking_computation != nullptr && @@ -192,7 +212,7 @@ class UnstackerTransformer { return unstacking_computation_; } - std::vector>& + std::vector>& GetLoopChanges() { return loop_changes_; } @@ -211,11 +231,14 @@ class UnstackerTransformer { } void AddLoopChange( - std::function loop_change) { + std::function loop_change) { loop_changes_.push_back(loop_change); } + PatternType GetPatternType() const { return pattern_type_; } + private: + PatternType pattern_type_; const UnstackerMetadata& metadata_; // This pointer is populated if the unstacker finds unstackable loop input. std::unique_ptr unstacked_shape_ = nullptr; @@ -226,7 +249,7 @@ class UnstackerTransformer { // A vector of lambdas that describe necessary changes to the shape of the // loops to unstack. The lambdas accept the pointer to the new unstacked // shape. - std::vector> loop_changes_; + std::vector> loop_changes_; // a list of lambdas that captures all the changes to the hlo graph needed for // unstacking. std::vector> body_changes_; @@ -375,11 +398,12 @@ bool CanPropagateGteShapeChangesInComputation( // This function is responsible for: // 1. Hoisting the unstacking computation outside the while_instr. // 2. Replacing the input of the while_instr with the new unstacked version. -void UnstackWhileInput(HloComputation* unstacking_computation, - HloInstruction* while_instr, const Shape* new_shape, - int64_t index) { +void UnstackWhileInput(const UnstackerTransformer& unstacker, + HloInstruction* while_instr, int64_t index) { VLOG(3) << "Unstacking while input: " << while_instr->name() << " at " << index; + const Shape* new_shape = unstacker.GetUnstackedShape(); + HloComputation* unstacking_computation = unstacker.GetUnstackingComputation(); const Shape& slice_shape = new_shape->tuple_shapes(0); HloInstruction* old_while_input = while_instr->while_init()->mutable_operand(index); @@ -406,17 +430,53 @@ void UnstackWhileInput(HloComputation* unstacking_computation, // Hoist the unstacking computation outside the while_instr and create a // tuple of slices. for (int64_t i = 0; i < new_shape->tuple_shapes_size(); ++i) { - std::vector operands = { - old_while_input, - while_instr->AddInstruction(MakeScalarConstantWithShape( - unstacking_computation->parameter_instruction(1)->shape(), i))}; - HloInstruction* slice = - while_instr->AddInstruction(HloInstruction::CreateFusion( - slice_shape, HloInstruction::FusionKind::kLoop, operands, - while_instr->GetModule()->AddEmbeddedComputation( - unstacking_computation->Clone()), - "hoisted")); - slices.push_back(slice); + HloInstruction* root_instr = unstacking_computation->root_instruction(); + // TODO: b/352400145 - After unifying patterns and handlers, instead of + // using the pattern type to determine the unstacked input, we should use + // the pattern object to call the appropriate method. + // + // For DSFusionPattern and NestedDSFusionPattern, we rewrite the + // dynamic-slice as a slice instruction in the hope that these slices are + // later prefetched using async-slice by MSA. For other patterns, we + // resort to the original unstacking computation until we find benefit in + // doing otherwise. + if (unstacker.GetPatternType() == PatternType::DSFusionPattern || + unstacker.GetPatternType() == PatternType::NestedDSFusionPattern) { + HloInstruction* dynamic_slice = root_instr->mutable_operand(0); + std::vector new_start_indices; + new_start_indices.reserve(dynamic_slice->shape().rank()); + std::vector new_limit_indices; + new_limit_indices.reserve(dynamic_slice->shape().rank()); + std::vector new_strides; + new_strides.reserve(dynamic_slice->shape().rank()); + new_start_indices.push_back(i); + new_limit_indices.push_back(i + 1); + new_strides.push_back(1); + for (int64_t j = 1; j < dynamic_slice->shape().rank(); ++j) { + new_start_indices.push_back(0); + new_limit_indices.push_back( + dynamic_slice->mutable_operand(0)->shape().dimensions(j)); + new_strides.push_back(1); + } + HloInstruction* slice = + while_instr->AddInstruction(HloInstruction::CreateSlice( + dynamic_slice->shape(), old_while_input, new_start_indices, + new_limit_indices, new_strides)); + + slices.push_back(slice); + } else { + std::vector operands = { + old_while_input, + while_instr->AddInstruction(MakeScalarConstantWithShape( + unstacking_computation->parameter_instruction(1)->shape(), i))}; + HloInstruction* slice = + while_instr->AddInstruction(HloInstruction::CreateFusion( + slice_shape, HloInstruction::FusionKind::kLoop, operands, + while_instr->GetModule()->AddEmbeddedComputation( + unstacking_computation->Clone()), + "hoisted")); + slices.push_back(slice); + } } } HloInstruction* new_operand_element = @@ -474,28 +534,27 @@ bool CanUnstackWhileOperand(const HloInstruction* while_instr, } } - auto loop_change = [=](HloInstruction* loop, const Shape* new_shape, - HloComputation* unstacking_computation, - int64_t idx) mutable { + auto loop_change = [=](const UnstackerTransformer& unstacker, + HloInstruction* loop, int64_t idx) mutable { Shape old_shape = ShapeUtil::MakeStaticShape( loop->while_body()->parameter_instruction(0)->shape()); - ShapeUtil::UpdateTupleShape(*new_shape, idx, &old_shape); + ShapeUtil::UpdateTupleShape(*unstacker.GetUnstackedShape(), idx, + &old_shape); loop->while_body()->ReplaceParameter( 0, HloInstruction::CreateParameter(0, old_shape, "unstacked")); loop->while_condition()->ReplaceParameter( 0, HloInstruction::CreateParameter(0, old_shape, "unstacked")); - CHECK_NE(unstacking_computation, nullptr); - UnstackWhileInput(unstacking_computation, loop, new_shape, idx); + CHECK_NE(unstacker.GetUnstackingComputation(), nullptr); + UnstackWhileInput(unstacker, loop, idx); // Update the input and output shape of the loop. *loop->mutable_shape() = old_shape; }; - auto loop_change_wrapper = [&loop_change, while_instr, index]( - const Shape* new_shape, - HloComputation* unstacking_computation) { + auto loop_change_wrapper = [&loop_change, while_instr, + index](const UnstackerTransformer& unstacker) { HloInstruction* mutable_loop = const_cast(while_instr); - loop_change(mutable_loop, new_shape, unstacking_computation, index); + loop_change(unstacker, mutable_loop, index); }; unstacker.AddLoopChange(loop_change_wrapper); return true; @@ -567,8 +626,7 @@ bool UnstackWhileOperandAtIndex( } // Apply the changes to the shape of the loop body and condition computations. for (auto& loop_change : unstacker.GetLoopChanges()) { - loop_change(unstacker.GetUnstackedShape(), - unstacker.GetUnstackingComputation()); + loop_change(unstacker); } for (const HloInstruction* instr : unstacker.GetUnstackedInstructions()) { unstacked_instructions.push_back(instr); @@ -649,9 +707,10 @@ std::optional GetDSFusionPattern(const UnstackerMetadata& metadata, match::Bitcast(match::Op(&bitcast_operand)))) { if (bitcast_operand == shape_covering_instr) { PatternInfo pattern_info; - pattern_info.name = "DSFusionPattern"; + pattern_info.type = PatternType::DSFusionPattern; pattern_info.instr = instr; - const Shape& slice_shape = instr->shape(); + // const Shape& slice_shape = instr->shape(); + const Shape& slice_shape = shape_covering_instr->shape(); const int64_t num_layers = instr->operand(0)->shape().dimensions(0); pattern_info.unstacked_shape = MakeUnstackedShapeFromSlice(slice_shape, num_layers); @@ -674,8 +733,18 @@ absl::Status UnstackDSFusionPattern( HloInstruction* new_operand = parent_loop->AddInstruction(HloInstruction::CreateCustomCall( slice_shape, {stacked, offset}, "DynamicGte")); + + HloInstruction* bitcast = mutable_dynamic_slicing_fusion->AddInstruction( + HloInstruction::CreateBitcast(mutable_dynamic_slicing_fusion->shape(), + new_operand)); + HloInstruction* bitcast_fusion = + mutable_dynamic_slicing_fusion->AddInstruction( + HloInstruction::CreateFusion(mutable_dynamic_slicing_fusion->shape(), + HloInstruction::FusionKind::kLoop, + bitcast)); + return mutable_dynamic_slicing_fusion->ReplaceAllUsesWithDifferentShape( - new_operand); + bitcast_fusion); } // This function recognizes fusions with the following pattern: @@ -703,7 +772,7 @@ std::optional GetDUSFusionPattern( if (shape_covering_instr->parent()->root_instruction() == shape_covering_instr) { PatternInfo pattern_info; - pattern_info.name = "DUSFusionPattern"; + pattern_info.type = PatternType::Other; pattern_info.instr = instr; pattern_info.unstacked_shape = MakeUnstackedShapeFromSlice( instr->operand(2)->shape(), instr->operand(0)->shape().dimensions(0)); @@ -767,7 +836,7 @@ std::optional GetDUSFusionWithPadPattern( const HloInstruction* pad_instr = shape_covering_instr->operand(1)->operand(0); PatternInfo pattern_info; - pattern_info.name = "DUSFusionWithPadPattern"; + pattern_info.type = PatternType::Other; pattern_info.instr = instr; pattern_info.unstacked_shape = MakeUnstackedShapeFromSlice( pad_instr->shape(), @@ -851,7 +920,7 @@ std::optional GetDSFusionWithAddPattern( if (add_operand == shape_covering_instr) { const int64_t num_layers = instr->operand(0)->shape().dimensions(0); PatternInfo pattern_info; - pattern_info.name = "DUSFusionWithAddPattern"; + pattern_info.type = PatternType::Other; pattern_info.instr = instr; pattern_info.unstacked_shape = MakeUnstackedShapeFromSlice(instr->shape(), num_layers); @@ -925,7 +994,7 @@ absl::Status UnstackDSFusionWithAddPattern( // (We assume that the stacked parameter is always the first operand and // the slicing offset is the second operand.) // 4. The fusion user contains a shape-covering dynamic-slice instruction. -std::optional GetNestedDUSFusionPattern( +std::optional GetNestedDSFusionPattern( const UnstackerMetadata& metadata, const HloInstruction* instr, int64_t stacked_operand_idx) { if (instr->opcode() != HloOpcode::kFusion) { @@ -975,10 +1044,10 @@ std::optional GetNestedDUSFusionPattern( const int64_t num_layers = inner_fusion_user->operand(0)->shape().dimensions(0); PatternInfo pattern_info; - pattern_info.name = "NestedDUSFusionPattern"; + pattern_info.type = PatternType::NestedDSFusionPattern; pattern_info.instr = inner_fusion_user; pattern_info.unstacked_shape = - MakeUnstackedShapeFromSlice(inner_fusion_user->shape(), num_layers); + MakeUnstackedShapeFromSlice(inner_fusion_instr->shape(), num_layers); pattern_info.unstacking_computation = inner_fusion_user->fused_instructions_computation(); pattern_info.unstacked_instrs.push_back(inner_fusion_user); @@ -989,8 +1058,8 @@ std::optional GetNestedDUSFusionPattern( } // The function below captures all the changes necessary to hlo graph for it's -// corresponding (IsNestedDynamicSlicingFusion) pattern to unstack. -absl::Status UnstackNestedDUSFusionPattern( +// corresponding (GetNestedDSFusionPattern) pattern to unstack. +absl::Status UnstackNestedDSFusionPattern( HloInstruction* mutable_dynamic_slicing_fusion, const Shape& slice_shape) { // We are sure that this lambda is called with a nested fusion. HloInstruction* parent_fusion = @@ -1018,9 +1087,16 @@ absl::Status UnstackNestedDUSFusionPattern( stacked_param_number, HloInstruction::CreateParameter(stacked_param_number, slice_shape, "sliced")); + HloInstruction* bitcast = mutable_dynamic_slicing_fusion->AddInstruction( + HloInstruction::CreateBitcast(mutable_dynamic_slicing_fusion->shape(), + sliced_param)); + HloInstruction* bitcast_fusion = + mutable_dynamic_slicing_fusion->AddInstruction( + HloInstruction::CreateFusion(mutable_dynamic_slicing_fusion->shape(), + HloInstruction::FusionKind::kLoop, + bitcast)); TF_RETURN_IF_ERROR( - mutable_dynamic_slicing_fusion->ReplaceAllUsesWith(sliced_param)); - + mutable_dynamic_slicing_fusion->ReplaceAllUsesWith(bitcast_fusion)); // Create the custom-call to dynamically get the tuple element given the // loop iteration number. We rely on WhileLoopUnroller to rewrite this as // a get-tuple-element hlo once the iteration number is known and loop @@ -1051,19 +1127,35 @@ std::optional GetDSAndDUSPattern(const UnstackerMetadata& metadata, return std::nullopt; } - std::optional ds_pattern_info = - GetDSFusionPattern(metadata, instr, instr->operand_index(stacked)); - - if (!ds_pattern_info.has_value()) { + HloInstruction* shape_covering_ds_instr = + GetMostMajorShapeCoveringDynamicIndexInFusion( + metadata, instr, HloOpcode::kDynamicSlice, 2, stacked_operand_idx); + if (shape_covering_ds_instr == nullptr) { + return std::nullopt; + } + HloInstruction* bitcast_operand = nullptr; + if (!Match(instr->fused_instructions_computation()->root_instruction(), + match::Bitcast(match::Op(&bitcast_operand)))) { + return std::nullopt; + } + if (bitcast_operand != shape_covering_ds_instr) { return std::nullopt; } if (!GetDUSFusionPattern(metadata, stacked->users()[1], stacked->users()[1]->operand_index(stacked))) { return std::nullopt; } - ds_pattern_info->name = "DSAndDUSPattern"; - ds_pattern_info->unstacked_instrs.push_back(stacked->users()[1]); - return ds_pattern_info; + PatternInfo pattern_info; + pattern_info.type = PatternType::Other; + pattern_info.instr = instr; + const Shape& slice_shape = instr->shape(); + const int64_t num_layers = instr->operand(0)->shape().dimensions(0); + pattern_info.unstacked_shape = + MakeUnstackedShapeFromSlice(slice_shape, num_layers); + pattern_info.unstacking_computation = instr->fused_instructions_computation(); + pattern_info.unstacked_instrs.push_back(instr); + pattern_info.unstacked_instrs.push_back(stacked->users()[1]); + return pattern_info; } absl::Status UnstackDSAndDUSPattern(HloInstruction* mutable_dynamic_slice, @@ -1073,8 +1165,16 @@ absl::Status UnstackDSAndDUSPattern(HloInstruction* mutable_dynamic_slice, HloComputation* parent = stacked_gte->parent(); ShapeUtil::UpdateTupleShape(stacked_gte->shape(), stacked_gte_index, parent->root_instruction()->mutable_shape()); + + HloComputation* parent_loop = mutable_dynamic_slice->parent(); + HloInstruction* stacked = mutable_dynamic_slice->mutable_operand(0); + HloInstruction* offset = mutable_dynamic_slice->mutable_operand(1); + HloInstruction* new_operand = + parent_loop->AddInstruction(HloInstruction::CreateCustomCall( + slice_shape, {stacked, offset}, "DynamicGte")); TF_RETURN_IF_ERROR( - UnstackDSFusionPattern(mutable_dynamic_slice, slice_shape)); + mutable_dynamic_slice->ReplaceAllUsesWithDifferentShape(new_operand)); + HloInstruction* mutable_dynamic_update_slice = stacked_gte->users()[1]; TF_RETURN_IF_ERROR( UnstackDUSFusionPattern(mutable_dynamic_update_slice, slice_shape)); @@ -1108,7 +1208,7 @@ std::optional GetReduceFusionPattern( match::Add(match::Parameter(), match::Parameter()))) { if (reduce_operand == shape_covering_instr) { PatternInfo pattern_info; - pattern_info.name = "ReduceFusion"; + pattern_info.type = PatternType::Other; pattern_info.instr = instr; const Shape& slice_shape = instr->shape(); const int64_t num_layers = instr->operand(0)->shape().dimensions(0); @@ -1124,6 +1224,19 @@ std::optional GetReduceFusionPattern( return std::nullopt; } +absl::Status UnstackReduceFusionPattern(HloInstruction* mutable_reduce_fusion, + const Shape& slice_shape) { + HloComputation* parent_loop = mutable_reduce_fusion->parent(); + + HloInstruction* stacked = mutable_reduce_fusion->mutable_operand(0); + HloInstruction* offset = mutable_reduce_fusion->mutable_operand(1); + + HloInstruction* new_operand = + parent_loop->AddInstruction(HloInstruction::CreateCustomCall( + slice_shape, {stacked, offset}, "DynamicGte")); + return mutable_reduce_fusion->ReplaceAllUsesWithDifferentShape(new_operand); +} + }; // namespace // The entry point of the unstacking algorithm. Given a module, it creates the @@ -1151,9 +1264,9 @@ absl::StatusOr HloUnstacker::Run( metadata.custom_handlers.push_back( std::make_pair(GetDSFusionWithAddPattern, UnstackDSFusionWithAddPattern)); metadata.custom_handlers.push_back( - std::make_pair(GetReduceFusionPattern, UnstackDSFusionPattern)); + std::make_pair(GetReduceFusionPattern, UnstackReduceFusionPattern)); metadata.custom_handlers.push_back( - std::make_pair(GetNestedDUSFusionPattern, UnstackNestedDUSFusionPattern)); + std::make_pair(GetNestedDSFusionPattern, UnstackNestedDSFusionPattern)); std::vector entry_loops; for (HloInstruction* instr : @@ -1204,7 +1317,7 @@ absl::StatusOr HloUnstacker::Run( TF_ASSIGN_OR_RETURN( bool unrolled, WhileLoopUnroller::Unroll(loop, /*unroll_factor=*/-1, - /*wrap_in_trivial_loop=*/true, + /*wrap_in_trivial_loop=*/false, /*force_unroll=*/true, /*prepare=*/false)); CHECK(unrolled); } diff --git a/xla/service/hlo_unstacker_test.cc b/xla/service/hlo_unstacker_test.cc index e39fe3c67a113a..d3d6eb06ba9754 100644 --- a/xla/service/hlo_unstacker_test.cc +++ b/xla/service/hlo_unstacker_test.cc @@ -15,13 +15,17 @@ limitations under the License. #include "xla/service/hlo_unstacker.h" +#include #include #include #include #include + #include #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" @@ -30,6 +34,17 @@ namespace { using UnstackerTest = HloTestBase; +int64_t GetSliceCountInEntry(HloModule* module) { + int64_t slice_instrs_count = 0; + for (HloInstruction* instr : + module->entry_computation()->MakeInstructionPostOrder()) { + if (instr->opcode() == HloOpcode::kSlice) { + slice_instrs_count++; + } + } + return slice_instrs_count; +} + TEST_F(UnstackerTest, UnstackLoopSingleFusionUser) { std::string hlo_string = R"( HloModule SimpleLoop @@ -74,6 +89,8 @@ TEST_F(UnstackerTest, UnstackLoopSingleFusionUser) { auto original = module->Clone(); TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get())); EXPECT_TRUE(unstacked); + // Check for the creation of slice instructions. + EXPECT_EQ(GetSliceCountInEntry(module.get()), 3); EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), std::nullopt)); } @@ -234,6 +251,8 @@ TEST_F(UnstackerTest, UnstackLoopSingleNestedFusionUser) { auto original = module->Clone(); TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get())); EXPECT_TRUE(unstacked); + // Check for the creation of slice instructions. + EXPECT_EQ(GetSliceCountInEntry(module.get()), 3); EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), std::nullopt, false)); } @@ -310,6 +329,9 @@ TEST_F(UnstackerTest, UnstackLoopSingleNestedFusionUserMultipleIndex) { auto original = module->Clone(); TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get())); EXPECT_TRUE(unstacked); + // Check for the creation of slice instructions. For each unstacked input, we + // create 4 slices, 8 in total. + EXPECT_EQ(GetSliceCountInEntry(module.get()), 8); EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), std::nullopt, false)); } @@ -366,6 +388,8 @@ TEST_F(UnstackerTest, UnstackLoopSingleNestedFusionUserDiffereOperandsOrder) { auto original = module->Clone(); TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get())); EXPECT_TRUE(unstacked); + // Check for the creation of slice instructions. + EXPECT_EQ(GetSliceCountInEntry(module.get()), 3); EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), std::nullopt, false)); } @@ -440,6 +464,8 @@ TEST_F(UnstackerTest, UnstackLoopMultipleNestedFusionUsersSameUnstackingComps) { auto original = module->Clone(); TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get())); EXPECT_TRUE(unstacked); + // Check for the creation of slice instructions. + EXPECT_EQ(GetSliceCountInEntry(module.get()), 3); EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), std::nullopt, false)); } @@ -633,6 +659,9 @@ TEST_F(UnstackerTest, UnstackMultipleLoops) { auto original = module->Clone(); TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get())); EXPECT_TRUE(unstacked); + // Check for the creation of slice instructions. For each loop there is one + // unstacked input that creates 4 slices, in total 8 slices for two loops. + EXPECT_EQ(GetSliceCountInEntry(module.get()), 8); EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), std::nullopt, false)); } @@ -710,6 +739,8 @@ TEST_F(UnstackerTest, UnstackNestedLoopSingleNestedFusionUser) { auto original = module->Clone(); TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get())); EXPECT_TRUE(unstacked); + // Check for the creation of slice instructions. + EXPECT_EQ(GetSliceCountInEntry(module.get()), 4); EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), std::nullopt, false)); } From 3046b927191b5f73a5577d2bcb9966cc8bd2a859 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 25 Jul 2024 19:11:36 -0700 Subject: [PATCH 182/376] [xla:cpu] Use DCHECK on a hot path in thunk_executor PiperOrigin-RevId: 656188976 --- xla/service/cpu/cpu_executable.cc | 4 ++-- xla/service/cpu/runtime/thunk_executor.cc | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/xla/service/cpu/cpu_executable.cc b/xla/service/cpu/cpu_executable.cc index ee843dc147eb0b..a37a40e9d19acd 100644 --- a/xla/service/cpu/cpu_executable.cc +++ b/xla/service/cpu/cpu_executable.cc @@ -82,7 +82,7 @@ FunctionRegistry::FunctionRegistry(SimpleOrcJIT* jit) : jit_(jit) {} absl::StatusOr FunctionRegistry::FindKernel( std::string_view name) { - VLOG(2) << "Find host kernel with a name " << name; + VLOG(3) << "Find host kernel with a name " << name; llvm::Expected sym = jit_->FindCompiledSymbol(std::string(name)); @@ -96,7 +96,7 @@ absl::StatusOr FunctionRegistry::FindKernel( absl::StatusOr FunctionRegistry::FindComparator( std::string_view name) { - VLOG(2) << "Find comparator with a name " << name; + VLOG(3) << "Find comparator with a name " << name; llvm::Expected sym = jit_->FindCompiledSymbol(std::string(name)); diff --git a/xla/service/cpu/runtime/thunk_executor.cc b/xla/service/cpu/runtime/thunk_executor.cc index 5879ff8f718488..4281442d5c4305 100644 --- a/xla/service/cpu/runtime/thunk_executor.cc +++ b/xla/service/cpu/runtime/thunk_executor.cc @@ -153,7 +153,8 @@ tsl::AsyncValueRef ThunkExecutor::Execute( // alive while thunk executor has pending tasks. auto execute_event = state->execute_event; execute_event.AndThen([state = std::move(state)] { - CHECK_EQ(state->pending_sink_nodes.load(std::memory_order_acquire), 0) + auto cnt = state->pending_sink_nodes.load(std::memory_order_acquire); + DCHECK_EQ(cnt, 0) << "All sink nodes must be completed before execute_event is marked " "available."; }); @@ -244,7 +245,7 @@ void ThunkExecutor::Execute(ExecuteState* state, ExecuteState::Node& node = state->nodes[id]; int64_t cnt = node.counter.load(std::memory_order_acquire); - CHECK_EQ(cnt, 0) << "Node counter must be 0"; // Crash Ok + DCHECK_EQ(cnt, 0) << "Node counter must be 0"; // Crash Ok // If we have multiple ready thunks, split the ready queue and offload // thunks processing to the task runner. From da7956817d1632e123f672687298c5f833b09785 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 25 Jul 2024 19:28:55 -0700 Subject: [PATCH 183/376] [xla:cpu] Don't forget to commit donation transaction on thunk execution error PiperOrigin-RevId: 656195503 --- xla/pjrt/cpu/cpu_client.cc | 25 ++++++++++++++----------- xla/pjrt/cpu/cpu_client_test.cc | 4 ++-- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/xla/pjrt/cpu/cpu_client.cc b/xla/pjrt/cpu/cpu_client.cc index 02b6524018a852..a832ab568d3408 100644 --- a/xla/pjrt/cpu/cpu_client.cc +++ b/xla/pjrt/cpu/cpu_client.cc @@ -1555,7 +1555,9 @@ absl::StatusOr TfrtCpuExecutable::ExecuteHelper( tsl::port::ScopedFlushDenormal flush; tsl::port::ScopedSetRound round(FE_TONEAREST); - XlaCustomCallStatus status; + // Execution status for XLA:CPU "classic" runtime or thunks. + XlaCustomCallStatus compute_function_status; + tsl::AsyncValueRef thunks_execute_event; // Immediately allocate memory and prepare for computation. buffer_alloc.Allocate(); @@ -1574,8 +1576,8 @@ absl::StatusOr TfrtCpuExecutable::ExecuteHelper( if (cpu_executable->has_compute_function()) { // Call jit-compiled function implementing XLA executable. cpu_executable->compute_function()(result_buffer, &run_options, nullptr, - buffer_pointers.data(), &status, - nullptr); + buffer_pointers.data(), + &compute_function_status, nullptr); } else if (cpu_executable->has_thunks()) { // Call interpreted thunk sequence implementing XLA executable. @@ -1611,14 +1613,11 @@ absl::StatusOr TfrtCpuExecutable::ExecuteHelper( &collective_params, &custom_call_execute_params}; - auto thunks_execute_event = - cpu_executable->thunks().Execute(execute_params); + thunks_execute_event = cpu_executable->thunks().Execute(execute_params); tsl::profiler::TraceMe trace( "ThunkExecutor::Execute (wait for completion)"); tsl::BlockUntilReady(thunks_execute_event); - if (thunks_execute_event.IsError()) - return thunks_execute_event.GetError(); } else { return Internal("CpuExecutable has no compute function or thunks."); @@ -1628,10 +1627,14 @@ absl::StatusOr TfrtCpuExecutable::ExecuteHelper( std::move(donation_transaction).Commit(); } - std::optional error_message = - xla::CustomCallStatusGetMessage(&status); - if (error_message) { - return Internal("Generated function failed: %s", *error_message); + // Forward errors (if any) after executing compute function or thunks. + if (cpu_executable->has_compute_function()) { + if (auto error_message = + xla::CustomCallStatusGetMessage(&compute_function_status)) { + return Internal("Generated function failed: %s", *error_message); + } + } else if (thunks_execute_event.IsError()) { + return thunks_execute_event.GetError(); } } else { diff --git a/xla/pjrt/cpu/cpu_client_test.cc b/xla/pjrt/cpu/cpu_client_test.cc index a66a2901d4cbab..641222e91ce21a 100644 --- a/xla/pjrt/cpu/cpu_client_test.cc +++ b/xla/pjrt/cpu/cpu_client_test.cc @@ -134,13 +134,13 @@ ENTRY DonationWithExecutionError() -> f32[2, 2] { auto result = pjrt_executable->Execute(/*argument_handles=*/{{buffer.get()}}, /*options=*/{}); ASSERT_FALSE(result.ok()); - EXPECT_THAT(result.status().message(), ::testing::HasSubstr("test error.")); + EXPECT_THAT(result.status().message(), HasSubstr("test error.")); result = pjrt_executable->Execute(/*argument_handles=*/{{buffer.get()}}, /*options=*/{}); ASSERT_FALSE(result.ok()); EXPECT_THAT(result.status().message(), - ::testing::HasSubstr("buffer has been deleted or donated.")); + HasSubstr("buffer has been deleted or donated.")); } TEST(TfrtCpuClientTest, HloSnapshot) { From 0099a46071a49dfd67a1951f3c4c203c3f2e7d25 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 25 Jul 2024 20:04:34 -0700 Subject: [PATCH 184/376] [xla:cpu] Add support for sorting 25 inputs PiperOrigin-RevId: 656208274 --- xla/service/cpu/runtime/sort_thunk.cc | 3 +++ xla/tests/sort_test.cc | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/xla/service/cpu/runtime/sort_thunk.cc b/xla/service/cpu/runtime/sort_thunk.cc index c8adf958f2a6e7..51d7c3ede0f624 100644 --- a/xla/service/cpu/runtime/sort_thunk.cc +++ b/xla/service/cpu/runtime/sort_thunk.cc @@ -462,6 +462,9 @@ static absl::Status SortInplace(absl::Span data, case 16: sort(std::integral_constant{}); break; + case 25: + sort(std::integral_constant{}); + break; default: return Internal("Unsupported number of sorted inputs: %d", data.size()); } diff --git a/xla/tests/sort_test.cc b/xla/tests/sort_test.cc index a6926c2e6e487e..d4e18c7891c23e 100644 --- a/xla/tests/sort_test.cc +++ b/xla/tests/sort_test.cc @@ -41,7 +41,7 @@ XLA_TEST_F(SortTest, SortDim0) { } )"; - EXPECT_TRUE(RunAndCompare(hlo_text_module, ErrorSpec{1e-5, 1e-5})); + EXPECT_TRUE(RunAndCompare(hlo_text_module, ErrorSpec{0.0, 0.0})); } XLA_TEST_F(SortTest, SortDim1) { @@ -60,7 +60,7 @@ XLA_TEST_F(SortTest, SortDim1) { } )"; - EXPECT_TRUE(RunAndCompare(hlo_text_module, ErrorSpec{1e-5, 1e-5})); + EXPECT_TRUE(RunAndCompare(hlo_text_module, ErrorSpec{0.0, 0.0})); } } // namespace From 34a657cb0830ae45dc96479f88ae8576c7ef7bc0 Mon Sep 17 00:00:00 2001 From: Farzin Houshmand Date: Thu, 25 Jul 2024 20:23:06 -0700 Subject: [PATCH 185/376] [XLA:UNSTACKER] Change existing unstacker patterns to support more cases. Previously, DSFusion and NestedDSFusion patterns would match only when the entire stacked operand was being read. With this change, we now support cases where the stacked operand is dynamically sliced where the dynamic-slice is effectively static at compile time. PiperOrigin-RevId: 656214194 --- xla/service/BUILD | 8 ++- xla/service/hlo_unstacker.cc | 78 +++++++++++++++++++------ xla/service/hlo_unstacker_test.cc | 60 +++++++++++++++++++ xla/service/while_loop_unroller.cc | 5 +- xla/service/while_loop_unroller.h | 4 +- xla/service/while_loop_unroller_test.cc | 6 +- 6 files changed, 133 insertions(+), 28 deletions(-) diff --git a/xla/service/BUILD b/xla/service/BUILD index ed7e8303f3b5f6..09676869ddab91 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -3517,16 +3517,15 @@ cc_library( srcs = ["hlo_unstacker.cc"], hdrs = ["hlo_unstacker.h"], deps = [ - ":hlo_alias_analysis", - ":hlo_buffer", + ":algebraic_simplifier", ":hlo_creation_utils", - ":hlo_dce", ":hlo_pass", ":pattern_matcher", ":tuple_util", ":while_loop_unroller", "//xla:shape_util", "//xla:util", + "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -3563,9 +3562,12 @@ cc_library( ":call_inliner", ":collective_ops_utils", ":flatten_call_graph", + ":hlo_alias_analysis", + ":hlo_buffer", ":hlo_creation_utils", ":hlo_cse", ":hlo_pass", + ":hlo_value", ":pattern_matcher", ":tuple_simplifier", ":while_loop_analysis", diff --git a/xla/service/hlo_unstacker.cc b/xla/service/hlo_unstacker.cc index cb969c56357845..024a41b0c48417 100644 --- a/xla/service/hlo_unstacker.cc +++ b/xla/service/hlo_unstacker.cc @@ -644,28 +644,67 @@ Shape MakeUnstackedShapeFromSlice(const Shape& slice_shape, int64_t layers) { } // Checks if the given instruction is a fusion with num_fusion_params -// parameters. If so, the function looks for the dynamic-index instruction -// within the fusion that covers the shape of the stacked operand at the given -// index. -HloInstruction* GetMostMajorShapeCoveringDynamicIndexInFusion( +// parameters inside an unrollable loop. If so, it returns the loop config. +std::optional IsFusionInsideUnrollableLoopWithNumParameter( const UnstackerMetadata& metadata, const HloInstruction* instr, - HloOpcode opcode, int64_t num_fusion_params, int64_t stacked_operand_idx) { + int64_t num_fusion_params) { if (instr->opcode() != HloOpcode::kFusion) { - return nullptr; + return std::nullopt; } if (instr->fused_parameters().size() != num_fusion_params) { VLOG(3) << "Fusion has different number of parameters"; - return nullptr; + return std::nullopt; } if (!metadata.unrollable_loop_bodies.contains(instr->parent())) { VLOG(5) << "Fusion not inside unrollable while body, " << instr->name() << " inside " << instr->parent()->name(); - return nullptr; + return std::nullopt; } + return metadata.unrollable_loop_bodies.at(instr->parent()); +} - WhileLoopConfig while_instr_config = - metadata.unrollable_loop_bodies.at(instr->parent()); +// Checks if the instruction is a fusion with num_fusion_params parameters +// inside an unrollable loop and within its fusion computation there is an +// effectively static dynamic-slice instruction on the most major dimension of +// the operand at the given stacked_operand_idx. If so, it returns the +// dynamic-slice instruction. +HloInstruction* GetMostMajorEffectivelyStaticDynamicSliceInFusion( + const UnstackerMetadata& metadata, const HloInstruction* instr, + int64_t num_fusion_params, int64_t stacked_operand_idx) { + std::optional while_instr_config = + IsFusionInsideUnrollableLoopWithNumParameter(metadata, instr, + num_fusion_params); + if (!while_instr_config.has_value()) { + return nullptr; + } + for (HloInstruction* fused_instr : + instr->fused_instructions_computation()->MakeInstructionPostOrder()) { + std::optional dynamic_index = + MatchEffectivelyStaticDynamicSliceInsideLoop( + fused_instr, + instr->fused_instructions_computation()->parameter_instruction( + stacked_operand_idx), + while_instr_config.value()); + if (dynamic_index.has_value() && dynamic_index.value() == 0) { + return fused_instr; + } + } + return nullptr; +} +// Checks if the instruction is a fusion with num_fusion_params parameters +// inside an unrollable loop and within its fusion computation looks for the +// dynamic-index instruction that covers the shape of the operand at the given +// index. +HloInstruction* GetMostMajorShapeCoveringDynamicIndexInFusion( + const UnstackerMetadata& metadata, const HloInstruction* instr, + HloOpcode opcode, int64_t num_fusion_params, int64_t stacked_operand_idx) { + std::optional while_instr_config = + IsFusionInsideUnrollableLoopWithNumParameter(metadata, instr, + num_fusion_params); + if (!while_instr_config.has_value()) { + return nullptr; + } for (HloInstruction* fused_instr : instr->fused_instructions_computation()->MakeInstructionPostOrder()) { if (fused_instr->opcode() != opcode) { @@ -676,7 +715,7 @@ HloInstruction* GetMostMajorShapeCoveringDynamicIndexInFusion( fused_instr, instr->fused_instructions_computation()->parameter_instruction( stacked_operand_idx), - opcode, while_instr_config); + opcode, while_instr_config.value()); if (dynamic_index.has_value() && dynamic_index.value() == 0) { return fused_instr; } @@ -685,20 +724,22 @@ HloInstruction* GetMostMajorShapeCoveringDynamicIndexInFusion( } // This function recognizes fusions with the following pattern: -// fusion(stacked, loop_iteration_var) +// fusion(stacked, f(loop_iteration_var)) // computation { // p0 = parameter(0) // p1 = parameter(1) // slice = dynamic_slice(p0, p1, zero, ...) // ROOT bitcast = bitcast(slice) // } +// where f is a function of loop_iteration_var. It indicates that the slicing +// offset is effectively static after unrolling. std::optional GetDSFusionPattern(const UnstackerMetadata& metadata, const HloInstruction* instr, int64_t stacked_operand_idx) { VLOG(3) << "Checking DSFusion"; HloInstruction* shape_covering_instr = - GetMostMajorShapeCoveringDynamicIndexInFusion( - metadata, instr, HloOpcode::kDynamicSlice, 2, stacked_operand_idx); + GetMostMajorEffectivelyStaticDynamicSliceInFusion(metadata, instr, 2, + stacked_operand_idx); if (shape_covering_instr == nullptr) { return std::nullopt; } @@ -1009,6 +1050,8 @@ std::optional GetNestedDSFusionPattern( WhileLoopConfig while_instr_config = metadata.unrollable_loop_bodies.at(instr->parent()); + VLOG(3) << "Checking NestedDSFusionPattern"; + HloInstruction* inner_fusion_user = nullptr; for (HloInstruction* fused_instr : instr->fused_instructions_computation()->MakeInstructionPostOrder()) { @@ -1035,11 +1078,11 @@ std::optional GetNestedDSFusionPattern( continue; } std::optional dynamic_index = - MatchShapeCoveringDynamicIndexInstruction( + MatchEffectivelyStaticDynamicSliceInsideLoop( inner_fusion_instr, inner_fusion_user->fused_instructions_computation() ->parameter_instruction(0), - HloOpcode::kDynamicSlice, while_instr_config); + while_instr_config); if (dynamic_index.has_value() && dynamic_index.value() == 0) { const int64_t num_layers = inner_fusion_user->operand(0)->shape().dimensions(0); @@ -1288,7 +1331,8 @@ absl::StatusOr HloUnstacker::Run( continue; } VLOG(3) << "Attempting to unstack " << loop->name() << " at " << i - << " = " << loop->while_init()->operand(i)->ToShortString(); + << " = " << loop->while_init()->operand(i)->shape().ToString(true) + << loop->while_init()->operand(i)->ToShortString(); unstacked |= UnstackWhileOperandAtIndex(metadata, loop, i, unstacked_instructions); VLOG(3) << "###################"; diff --git a/xla/service/hlo_unstacker_test.cc b/xla/service/hlo_unstacker_test.cc index d3d6eb06ba9754..84724550052dc1 100644 --- a/xla/service/hlo_unstacker_test.cc +++ b/xla/service/hlo_unstacker_test.cc @@ -257,6 +257,66 @@ TEST_F(UnstackerTest, UnstackLoopSingleNestedFusionUser) { std::nullopt, false)); } +// Instead of slicing the entire shape, this test slices only even elements from +// the first parameter. +TEST_F(UnstackerTest, UnstackLoopSingleNestedFusionUserDynamicIndex) { + std::string hlo_string = R"( + HloModule SimpleLoop + %fused_computation.slice (param_0.51117: s8[6,128,128], p1: s32[]) -> s8[128,128] { + %param_0.51117 = s8[6,128,128] parameter(0) + p1 = s32[] parameter(1) + %constant.85694 = s32[] constant(0) + %dynamic-slice.22040 = s8[1,128,128] dynamic-slice(s8[6,128,128] %param_0.51117, p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128} + ROOT %bitcast.31250 = s8[128,128] bitcast(s8[1,128,128] %dynamic-slice.22040) + } + + %fused_computation.inner (param_0.34523: bf16[8,128], param_1.30691: s8[6,128,128], p2: s32[]) -> bf16[8,128] { + %param_0.34523 = bf16[8,128] parameter(0) + %param_1.30691 = s8[6,128,128] parameter(1) + p2 = s32[] parameter(2) + %fusion.67830 = s8[128,128] fusion(s8[6,128,128] %param_1.30691, p2), kind=kLoop, calls=%fused_computation.slice + ROOT %convolution.3447 = bf16[8,128] convolution(bf16[8,128] %param_0.34523, s8[128,128] %fusion.67830), dim_labels=bf_io->bf + } + + %while.body (wide_param: (s32[], bf16[8,128], s8[6,128,128])) -> (s32[], bf16[8,128], s8[6,128,128]) { + wide_p = (s32[], bf16[8,128], s8[6,128,128]) parameter(0) + i = s32[] get-tuple-element(wide_p), index=0 + p0 = bf16[8,128] get-tuple-element(wide_p), index=1 + p1 = s8[6,128,128] get-tuple-element(wide_p), index=2 + one = s32[] constant(1) + inc = s32[] add(i, one) + two = s32[] constant(2) + mult = s32[] multiply(i, two) + fusion.conv = bf16[8,128] fusion(p0, p1, mult), kind=kOutput, calls=%fused_computation.inner + ROOT out = (s32[], bf16[8,128], s8[6,128,128]) tuple(inc, fusion.conv, p1) + } + + %while.cond (wide_param: (s32[], bf16[8,128], s8[6,128,128])) -> pred[] { + wide_p = (s32[], bf16[8,128], s8[6,128,128]) parameter(0) + i = s32[] get-tuple-element(wide_p), index=0 + %constant.12857 = s32[] constant(3) + ROOT %compare.1921 = pred[]{:T(512)} compare(s32[] i, s32[] %constant.12857), direction=LT + } + + ENTRY main { + p0 = s8[6,128,128] parameter(0) + p1 = bf16[8,128] parameter(1) + init = s32[] constant(0) + while.input = (s32[], bf16[8,128], s8[6,128,128]) tuple(init, p1, p0) + while.out = (s32[], bf16[8,128], s8[6,128,128]) while(while.input), condition=%while.cond , body=%while.body + while_use = s8[6,128,128] get-tuple-element(while.out), index=2 + ROOT out = bf16[8,128] get-tuple-element(while.out), index=1 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + auto original = module->Clone(); + TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get())); + EXPECT_TRUE(unstacked); + EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), + std::nullopt, false)); +} + TEST_F(UnstackerTest, UnstackLoopSingleNestedFusionUserMultipleIndex) { std::string hlo_string = R"( HloModule SimpleLoop diff --git a/xla/service/while_loop_unroller.cc b/xla/service/while_loop_unroller.cc index 534d2c604adb84..0e1c3288c468df 100644 --- a/xla/service/while_loop_unroller.cc +++ b/xla/service/while_loop_unroller.cc @@ -407,8 +407,11 @@ bool IsEffectivelyStatic(const HloInstruction* instr, } std::optional MatchEffectivelyStaticDynamicSliceInsideLoop( - const HloInstruction* instr, const HloInstruction* input, HloOpcode opcode, + const HloInstruction* instr, const HloInstruction* input, const WhileLoopConfig& config) { + if (instr->opcode() != HloOpcode::kDynamicSlice) { + return std::nullopt; + } int64_t start_indices_offset = 1; const HloInstruction* operand = instr->operand(0); if (operand != input) { diff --git a/xla/service/while_loop_unroller.h b/xla/service/while_loop_unroller.h index 1092dd791924cf..2e75b0648053a9 100644 --- a/xla/service/while_loop_unroller.h +++ b/xla/service/while_loop_unroller.h @@ -27,9 +27,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/literal_util.h" #include "xla/service/hlo_pass_interface.h" -#include "xla/service/pattern_matcher.h" namespace xla { @@ -61,7 +59,7 @@ std::optional MatchShapeCoveringDynamicIndexInstruction( // involves the iteration variable of the surrounding loop and some constants, // if we unroll the surrounding loop. If so, it returns the dynamic index. std::optional MatchEffectivelyStaticDynamicSliceInsideLoop( - const HloInstruction* instr, const HloInstruction* input, HloOpcode opcode, + const HloInstruction* instr, const HloInstruction* input, const WhileLoopConfig& config); // This pass unrolls while loops with the given unrolling factor. The value of diff --git a/xla/service/while_loop_unroller_test.cc b/xla/service/while_loop_unroller_test.cc index 7029172e05f2df..7758d1e845e082 100644 --- a/xla/service/while_loop_unroller_test.cc +++ b/xla/service/while_loop_unroller_test.cc @@ -1176,16 +1176,14 @@ TEST_F(WhileLoopUnrollerTest, IsEffectivelyStaticDynamicSlice) { comp->GetInstructionWithName("dynamic-slice.static"); if (static_slice != nullptr) { auto index = MatchEffectivelyStaticDynamicSliceInsideLoop( - static_slice, static_slice->operand(0), HloOpcode::kDynamicSlice, - *config); + static_slice, static_slice->operand(0), *config); EXPECT_TRUE(index.has_value()); } HloInstruction* dynamic_slice = comp->GetInstructionWithName("dynamic-slice.dynamic"); if (dynamic_slice != nullptr) { auto index = MatchEffectivelyStaticDynamicSliceInsideLoop( - dynamic_slice, dynamic_slice->operand(0), HloOpcode::kDynamicSlice, - *config); + dynamic_slice, dynamic_slice->operand(0), *config); EXPECT_FALSE(index.has_value()); } } From 04f2bfe797408c9efe742b89e2e4db6cf526ebb7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 25 Jul 2024 21:09:26 -0700 Subject: [PATCH 186/376] Add long polling as a new way to propagate error in coordination service. PiperOrigin-RevId: 656228596 --- .../tsl/protobuf/coordination_config.proto | 4 + .../tsl/protobuf/coordination_service.proto | 18 ++ .../distributed_runtime/coordination/BUILD | 9 +- .../coordination/coordination_client.h | 6 + .../coordination/coordination_service.cc | 243 +++++++++++++++--- .../coordination/coordination_service.h | 12 + .../coordination_service_agent.cc | 115 ++++++++- .../coordination_service_agent_test.cc | 33 +++ .../coordination_service_rpc_handler.cc | 14 + .../coordination_service_rpc_handler.h | 4 + .../coordination/coordination_service_test.cc | 233 +++++++++++++++++ .../coordination/grpc_coordination_client.cc | 13 + .../grpc_coordination_service_impl.cc | 1 + .../grpc_coordination_service_impl.h | 1 + 14 files changed, 662 insertions(+), 44 deletions(-) diff --git a/third_party/tsl/tsl/protobuf/coordination_config.proto b/third_party/tsl/tsl/protobuf/coordination_config.proto index 035a49e6f20e9c..23aff65eb67985 100644 --- a/third_party/tsl/tsl/protobuf/coordination_config.proto +++ b/third_party/tsl/tsl/protobuf/coordination_config.proto @@ -67,4 +67,8 @@ message CoordinationServiceConfig { // not specify any config. This field allows users to explicitly disable // coordination service under all situations. bool force_disable = 12; + + // Use long polling to get error from coordination service as the error + // propagation mechanism. + bool poll_for_error_from_service_at_startup = 13; } diff --git a/third_party/tsl/tsl/protobuf/coordination_service.proto b/third_party/tsl/tsl/protobuf/coordination_service.proto index 2f7cc804cae5cb..2405cb936d8472 100644 --- a/third_party/tsl/tsl/protobuf/coordination_service.proto +++ b/third_party/tsl/tsl/protobuf/coordination_service.proto @@ -84,6 +84,12 @@ message HeartbeatResponse { // broadcast error code and message to other tasks. } +message PollForErrorRequest { + CoordinatedTask source_task = 1; +} + +message PollForErrorResponse {} + // Request and response messages for waiting for all tasks. message WaitForAllTasksRequest { // Removed fields which used to specify the remote task. @@ -342,4 +348,16 @@ service CoordinationService { // Possible service errors: // - FailedPrecondition: Barrier has already been passed. rpc CancelBarrier(CancelBarrierRequest) returns (CancelBarrierResponse); + + // Polls the service for errors. + // + // This RPC is used by the coordination service agent to send long polling + // request to service for errors. The call will block until an error is + // reported by the service. + // + // Possible service errors: + // - Aborted: Service is shutting down. + rpc PollForError(PollForErrorRequest) returns (PollForErrorResponse) { + // [AUTOMATION]: Internal rpc option goes here. + } } diff --git a/xla/tsl/distributed_runtime/coordination/BUILD b/xla/tsl/distributed_runtime/coordination/BUILD index 3528a89a5103fa..d42d285d1d9a69 100644 --- a/xla/tsl/distributed_runtime/coordination/BUILD +++ b/xla/tsl/distributed_runtime/coordination/BUILD @@ -53,7 +53,6 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@tsl//tsl/platform:macros", "@tsl//tsl/platform:status", @@ -78,13 +77,13 @@ tsl_gpu_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/hash", "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@tsl//tsl/platform:env", - "@tsl//tsl/platform:macros", "@tsl//tsl/platform:random", "@tsl//tsl/platform:status", "@tsl//tsl/protobuf:coordination_config_proto_cc", @@ -145,6 +144,7 @@ tsl_gpu_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -171,12 +171,10 @@ tsl_cc_test( "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/time", "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:env_impl", - "@tsl//tsl/platform:protobuf", "@tsl//tsl/platform:status", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", @@ -201,7 +199,6 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", - "@tsl//tsl/platform:casts", "@tsl//tsl/platform:protobuf", "@tsl//tsl/platform:status", "@tsl//tsl/platform:thread_annotations", diff --git a/xla/tsl/distributed_runtime/coordination/coordination_client.h b/xla/tsl/distributed_runtime/coordination/coordination_client.h index 3dcb8623d81f5d..cea5ba4890d37b 100644 --- a/xla/tsl/distributed_runtime/coordination/coordination_client.h +++ b/xla/tsl/distributed_runtime/coordination/coordination_client.h @@ -40,6 +40,8 @@ using tensorflow::HeartbeatRequest; using tensorflow::HeartbeatResponse; using tensorflow::InsertKeyValueRequest; using tensorflow::InsertKeyValueResponse; +using tensorflow::PollForErrorRequest; +using tensorflow::PollForErrorResponse; using tensorflow::RegisterTaskRequest; using tensorflow::RegisterTaskResponse; using tensorflow::ReportErrorToServiceRequest; @@ -124,6 +126,10 @@ class CoordinationClient { virtual void CancelBarrierAsync(const CancelBarrierRequest* request, CancelBarrierResponse* response, StatusCallback done) = 0; + virtual void PollForErrorAsync(CallOptions* call_opts, + const PollForErrorRequest* request, + PollForErrorResponse* response, + StatusCallback done) = 0; }; // Simple wrapper class that can be used to retrieve CoordinationClients. diff --git a/xla/tsl/distributed_runtime/coordination/coordination_service.cc b/xla/tsl/distributed_runtime/coordination/coordination_service.cc index cfd672266d44ce..c70dde5e12d3b3 100644 --- a/xla/tsl/distributed_runtime/coordination/coordination_service.cc +++ b/xla/tsl/distributed_runtime/coordination/coordination_service.cc @@ -31,12 +31,16 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/hash/hash.h" +#include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/synchronization/notification.h" +#include "absl/time/clock.h" #include "absl/time/time.h" #include "xla/tsl/distributed_runtime/call_options.h" #include "xla/tsl/distributed_runtime/coordination/coordination_client.h" @@ -139,6 +143,8 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { StatusCallback done) override; absl::Status CancelBarrier(std::string_view barrier_id, const CoordinatedTask& task) override; + void PollForErrorAsync(const CoordinatedTask& task, + StatusCallback done) override; private: const DeviceInfo& ListClusterDevices() override @@ -184,6 +190,39 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { CoordinatedTaskEqual>& tasks_at_barrier, int64_t cluster_size); bool isRecoverableJob(std::string_view task_name) const; + // Sends responses to error polling requests when an error is encountered. + void SendErrorPollingResponse(const absl::Status& error); + // Responds to error polling or stops the service when an error is + // encountered. Should only be called when there is no service to client + // connection. Returns true if the service stops, otherwise returns false. + bool SendErrorPollingResponseOrStopService(const absl::Status& error); + // Returns whether the clients are polling for error from the service. If the + // clients are not polling for error from the service, the service should stop + // when there is an error. Otherwise, the service should not stop. + bool IsClientPollingForError() const; + + class ErrorPollingState { + public: + // Returns whether the error polling requests have been responded. + bool Responded() const { return responded_; } + // Sets the error and executes the status callbacks. + void SetError(const absl::Status& error); + // Gets the error that is propagated to the agents. + const absl::Status& GetError() const { return error_; } + // Returns true if the task has sent request to poll for error from the + // service. + bool IsTaskPolling(absl::string_view task_name) const { + return polling_task_names_.contains(task_name); + } + // Adds a task to the error polling state. + void AddTask(const CoordinatedTask& task, StatusCallback&& done); + + private: + bool responded_ = false; + absl::Status error_ = absl::OkStatus(); + std::vector done_callbacks_; + absl::flat_hash_set polling_task_names_; + }; class TaskState { public: @@ -246,6 +285,10 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { // silently if configured. This is useful when we know that a task can // immediately resume work upon re-connecting to the service. bool allow_new_incarnation_to_reconnect_ = false; + // Whether the agents are polling for error from the service. It will be set + // to true when the service sees the first error polling request. Once set to + // true, the value will never change back to false, so no mutex is needed. + bool client_polling_for_error_ = false; std::function post_aggregate_device_fn_; @@ -277,11 +320,32 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { absl::flat_hash_set recoverable_jobs_; + ErrorPollingState error_polling_state_ ABSL_GUARDED_BY(state_mu_); + CoordinationServiceStandaloneImpl(const CoordinationServiceStandaloneImpl&) = delete; void operator=(const CoordinationServiceStandaloneImpl&) = delete; }; +void CoordinationServiceStandaloneImpl::ErrorPollingState::SetError( + const absl::Status& error) { + if (responded_) return; + responded_ = true; + error_ = error; + for (auto& done_cb : done_callbacks_) { + done_cb(error_); + } + done_callbacks_.clear(); +} + +void CoordinationServiceStandaloneImpl::ErrorPollingState::AddTask( + const CoordinatedTask& task, StatusCallback&& done) { + // Do not allow to insert a task if the service has already responded. + if (Responded()) return; + polling_task_names_.insert(GetTaskName(task)); + done_callbacks_.emplace_back(done); +} + void CoordinationServiceStandaloneImpl::TaskState::SetConnected( uint64_t task_incarnation) { state_ = CoordinatedTaskState::TASKSTATE_CONNECTED; @@ -414,11 +478,9 @@ void CoordinationServiceStandaloneImpl::StartCheckStaleness() { status = MakeCoordinationError(absl::UnavailableError( absl::StrCat("Task ", task_name, " heartbeat timeout. This indicates that the " - "remote task " - "has failed, got preempted, or crashed " - "unexpectedly. Check " - "the task logs for an earlier error to debug " - "further."))); + "remote task has failed, got preempted, or " + "crashed unexpectedly. Check the task logs " + "for an earlier error to debug further."))); SetTaskError(task_name, status); } } @@ -426,23 +488,23 @@ void CoordinationServiceStandaloneImpl::StartCheckStaleness() { // Propagate heartbeat timeout errors to other connected tasks. if (!stale_task_names.empty()) { if (!has_service_to_client_connection) { - // Error cannot be propagated since there is no service-to-client - // connection, so shut down service instead. Note: we cannot - // destroy the thread within its own function. However, this - // thread will be destroyed once the function returns. - LOG(ERROR) - << "Stopping coordination service as the following tasks are " - "unhealthy (stopped sending heartbeats):\n" - << absl::StrJoin(stale_task_names, "\n") - << "\nCheck the task logs for an earlier error to debug " - "further."; - Stop(/*shut_staleness_thread=*/false); - return; - } - for (const auto& stale_task_name : stale_task_names) { - PropagateError(GetTaskFromName(stale_task_name)); + absl::Status heartbeat_timeout_error = + MakeCoordinationError(absl::UnavailableError(absl::StrCat( + "The following tasks are unhealthy (stopped sending " + "heartbeats):\n", + absl::StrJoin(stale_task_names, "\n"), + "\nCheck the task logs for an earlier error to debug " + "further."))); + if (SendErrorPollingResponseOrStopService( + heartbeat_timeout_error)) { + return; + } + } else { + for (const auto& stale_task_name : stale_task_names) { + PropagateError(GetTaskFromName(stale_task_name)); + } + stale_task_names.clear(); } - stale_task_names.clear(); } // Barrier timeout check. @@ -479,15 +541,15 @@ void CoordinationServiceStandaloneImpl::StartCheckStaleness() { } if (!has_service_to_client_connection && expired_barriers.contains(shutdown_barrier_id_)) { - // Error cannot be propagated since there is no service-to-client - // connection, so shut down service instead. Note: we cannot - // destroy the thread within its own function. However, this - // thread will be destroyed once the function returns. - LOG(ERROR) - << "Stopping coordination service as shutdown barrier " - "timed out. Check the task logs for an earlier error."; - Stop(/*shut_staleness_thread=*/false); + // Error cannot be propagated through service-to-client connection. + // Note: we cannot destroy the thread within its own function. + // However, this thread will be destroyed once the function returns. + SendErrorPollingResponseOrStopService( + MakeCoordinationError(absl::DeadlineExceededError( + "Shutdown barrier timed out. Check the task logs for an " + "earlier error."))); } + // Reset this for the next barrier check. expired_barriers.clear(); } @@ -529,6 +591,12 @@ void CoordinationServiceStandaloneImpl::Stop(bool shut_staleness_thread) { // the state is used in `PassBarrier`. cluster_state_.clear(); } + // Cancel all pending PollForErrorAsync() calls. + if (IsClientPollingForError()) { + SendErrorPollingResponse( + absl::CancelledError("Coordination service is shutting down. " + "Cancelling PollForErrorAsync()")); + } // Destroy thread outside of the mutex. if (shut_staleness_thread) { check_staleness_thread_.reset(); @@ -666,6 +734,7 @@ void CoordinationServiceStandaloneImpl::WaitForAllTasks( void CoordinationServiceStandaloneImpl::ShutdownTaskAsync( const CoordinatedTask& task, StatusCallback done) { + VLOG(3) << "Task " << GetTaskName(task) << " invoked ShutdownTaskAsync()"; if (shutdown_barrier_timeout_ > absl::ZeroDuration()) { // Impose shutdown barrier so that all tasks can disconnect together. BarrierAsync(shutdown_barrier_id_, shutdown_barrier_timeout_, task, {}, @@ -813,6 +882,8 @@ absl::Status CoordinationServiceStandaloneImpl::RecordHeartbeat( "Task with task_name=", task_name, " must be registered before sending heartbeat messages"))); } + VLOG(10) << "Record heartbeat from task: " << task_name + << "at incarnation: " << incarnation << "at " << absl::Now(); s = cluster_state_[task_name]->RecordHeartbeat(incarnation); } @@ -862,6 +933,7 @@ void CoordinationServiceStandaloneImpl::ReportServiceErrorToTaskAsync( void CoordinationServiceStandaloneImpl::PropagateError( const CoordinatedTask& source_task, bool is_reported_by_task) { + VLOG(3) << "PropagateError() from " << GetTaskName(source_task); // If the error task is recoverable, do not propagate the error to other // connected tasks. if (isRecoverableJob(source_task.job_name())) return; @@ -898,15 +970,13 @@ void CoordinationServiceStandaloneImpl::PropagateError( continue; } - // Don't propagate error if there is no service-to-client connection. + // If there is no service-to-client connection, use error polling or stop + // the service. if (client_cache_ == nullptr) { - LOG(ERROR) - << "Stopping coordination service as there is no " - "service-to-client connection, but we encountered an error: " - << error; - Stop(/*shut_staleness_thread=*/false); + SendErrorPollingResponseOrStopService(error); return; } + CoordinationClient* client = client_cache_->GetClient(std::string(task)); auto response = std::make_shared(); auto n = std::make_shared(); @@ -1075,6 +1145,50 @@ void CoordinationServiceStandaloneImpl::SetTaskError(std::string_view task_name, << " has been set to ERROR in coordination service: " << error; } +void CoordinationServiceStandaloneImpl::PollForErrorAsync( + const CoordinatedTask& task, StatusCallback done) { + const std::string task_name = GetTaskName(task); + VLOG(3) << "Task " << task_name << " invoked PollForErrorAsync()."; + + absl::MutexLock l(&state_mu_); + if (ServiceHasStopped()) { + done(MakeCoordinationError(absl::InternalError( + "PollForError requested after coordination service has shut down."))); + return; + } + + if (client_cache_ != nullptr) { + done(MakeCoordinationError( + absl::InternalError("Should not use error polling from service when " + "there is service to client connection."))); + return; + } + + client_polling_for_error_ = true; + + if (!cluster_state_.contains(task_name)) { + done(MakeCoordinationError(absl::InvalidArgumentError( + absl::StrCat("Unexpected task (", task_name, + ") that is not in the cluster polling for errors.")))); + return; + } + + if (cluster_state_[task_name]->GetState() != + CoordinatedTaskState::TASKSTATE_CONNECTED) { + done(MakeCoordinationError(absl::InvalidArgumentError( + absl::StrCat("Task (", task_name, + ") that has not been registered polling for errors.")))); + return; + } + + if (error_polling_state_.Responded()) { + done(error_polling_state_.GetError()); + return; + } + + error_polling_state_.AddTask(task, std::move(done)); +} + void CoordinationServiceStandaloneImpl::BarrierAsync( std::string_view barrier_id, absl::Duration timeout, const CoordinatedTask& task, @@ -1325,8 +1439,40 @@ void CoordinationServiceStandaloneImpl::PassBarrier(std::string_view barrier_id, barrier->done_callbacks.clear(); } -bool CoordinationServiceStandaloneImpl::ValidateTaskArgs( +void CoordinationServiceStandaloneImpl::SendErrorPollingResponse( + const absl::Status& error) { + CHECK(IsClientPollingForError()) + << "`SendErrorPollingResponse` should only be called after agents poll " + "errors from the service."; + { + absl::MutexLock l(&state_mu_); + if (error_polling_state_.Responded()) { + return; + } + } + LOG(ERROR) << "An error is encountered. Sending the error as a response to " + "all error polling requests: " + << error; + std::vector missing_tasks; + { + absl::MutexLock l(&state_mu_); + missing_tasks.reserve(cluster_state_.size()); + for (const auto& [task_name, task_state] : cluster_state_) { + if (!error_polling_state_.IsTaskPolling(task_name)) { + missing_tasks.push_back(task_name); + } + } + error_polling_state_.SetError(error); + } + if (!missing_tasks.empty()) { + LOG(ERROR) << absl::StrFormat( + "The following %d tasks in the cluster has not sent request to poll " + "for error. Error will not be propagated to these tasks: %s", + missing_tasks.size(), absl::StrJoin(missing_tasks, ",")); + } +} +bool CoordinationServiceStandaloneImpl::ValidateTaskArgs( const std::vector& tasks_args, const absl::flat_hash_map& tasks_at_barrier, @@ -1385,6 +1531,31 @@ bool CoordinationServiceStandaloneImpl::isRecoverableJob( return recoverable_jobs_.find(task_name) != recoverable_jobs_.end(); } +bool CoordinationServiceStandaloneImpl::SendErrorPollingResponseOrStopService( + const absl::Status& error) { + CHECK(!error.ok()) << "SendErrorPollingResponseOrStopService called with OK " + "status. Should always return an error."; + // Should be called only when there is no service-to-client connection. + assert(client_cache_ == nullptr); + if (IsClientPollingForError()) { + LOG(ERROR) + << "Use error polling to propagate the following error to all tasks: " + << error; + SendErrorPollingResponse(error); + return false; + } + + LOG(ERROR) << "Stopping coordination service as there is no " + "service-to-client connection, but we encountered an error: " + << error; + Stop(/*shut_staleness_thread=*/false); + return true; +} + +bool CoordinationServiceStandaloneImpl::IsClientPollingForError() const { + return client_polling_for_error_; +} + // Register standalone coordination service implementation. REGISTER_COORDINATION_SERVICE("standalone", EnableCoordinationService); diff --git a/xla/tsl/distributed_runtime/coordination/coordination_service.h b/xla/tsl/distributed_runtime/coordination/coordination_service.h index 63bea88dbd5eaf..0c1fe046dd6b43 100644 --- a/xla/tsl/distributed_runtime/coordination/coordination_service.h +++ b/xla/tsl/distributed_runtime/coordination/coordination_service.h @@ -237,6 +237,18 @@ class CoordinationServiceInterface { virtual absl::Status CancelBarrier( std::string_view barrier_id, const tensorflow::CoordinatedTask& task) = 0; + // Gets error from the coordination service. Block until the service + // returns an error or the task/service is shutdown. This should never be used + // when there is service to client connection (i.e. `CoordinationClientCache` + // is passed in during construction). + // + // The first call to this function will trigger the error polling mode in the + // coordination service, so once an error occurs after the first call, the + // service will use the error polling mode to propagate the error to all + // connected tasks instead of simply shutting down. + virtual void PollForErrorAsync(const tensorflow::CoordinatedTask& task, + StatusCallback done) = 0; + private: friend class CoordinationServiceRpcHandler; friend class CoordinationServiceTest_ListClusterDevices_TfDevice_Test; diff --git a/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc b/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc index 2681333105991f..8bcf451987dc81 100644 --- a/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc +++ b/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc @@ -30,9 +30,11 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/functional/bind_front.h" +#include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "absl/synchronization/mutex.h" #include "absl/synchronization/notification.h" @@ -67,6 +69,7 @@ constexpr absl::Duration kDefaultClusterRegisterTimeout = absl::Hours(1); constexpr absl::Duration kDefaultHeartbeatTimeout = absl::Seconds(10); constexpr absl::Duration kDefaultShutdownTimeout = absl::Seconds(10); constexpr char kHeartbeatThread[] = "CoordinationServiceHeartbeatLoop"; +constexpr char kErrorPollingThread[] = "CoordinationServiceErrorPolling"; class CoordinationServiceAgentImpl : public CoordinationServiceAgent { public: @@ -143,6 +146,17 @@ class CoordinationServiceAgentImpl : public CoordinationServiceAgent { absl::Status ShutdownInternal(); // Starts sending heartbeats to the coordination service. void StartSendingHeartbeats(); + // Use long polling to get error from the coordination service. This function + // will block until an error is received or the agent is shutdown or reset. + absl::Status PollForError(); + std::shared_ptr PollForErrorAsync(StatusCallback done); + + // Starts polling for error from the coordination service. + void StartPollingForError(); + // Cancels the error polling request and stops the error polling thread. + void StopErrorPolling(); + // Resets the cancellation manager for error polling. + void ResetCancellationManager(); Env* env_ = nullptr; // Not owned. const uint64_t incarnation_id_ = random::New64(); @@ -166,9 +180,12 @@ class CoordinationServiceAgentImpl : public CoordinationServiceAgent { absl::CondVar heartbeat_thread_cv_; bool shutting_down_ TF_GUARDED_BY(heartbeat_thread_shutdown_mu_) = false; std::unique_ptr heartbeat_thread_; + std::unique_ptr error_polling_thread_; // Must outlive coordination client which may need to access it within // GetKeyValueAsync() callbacks. CancellationManager cancellation_manager_; + std::unique_ptr error_polling_cancellation_manager_ = + std::make_unique(); std::unique_ptr leader_client_; CoordinationServiceAgentImpl(const CoordinationServiceAgentImpl&) = delete; @@ -239,6 +256,16 @@ void CoordinationServiceAgentImpl::StopHeartbeat() { heartbeat_thread_ = nullptr; } +void CoordinationServiceAgentImpl::StopErrorPolling() { + // Cancel pending error polling RPC call. + error_polling_cancellation_manager_->StartCancel(); + error_polling_thread_ = nullptr; +} + +void CoordinationServiceAgentImpl::ResetCancellationManager() { + error_polling_cancellation_manager_ = std::make_unique(); +} + absl::Status CoordinationServiceAgentImpl::Connect() { VLOG(3) << "Agent has started trying to Connect()."; { @@ -315,6 +342,13 @@ absl::Status CoordinationServiceAgentImpl::Connect() { ThreadOptions(), kHeartbeatThread, absl::bind_front(&CoordinationServiceAgentImpl::StartSendingHeartbeats, this))); + if (configs_.poll_for_error_from_service_at_startup()) { + // Start a thread to poll for error from the coordination service. + error_polling_thread_.reset(env_->StartThread( + ThreadOptions(), kErrorPollingThread, + absl::bind_front(&CoordinationServiceAgentImpl::StartPollingForError, + this))); + } return absl::OkStatus(); } @@ -377,6 +411,80 @@ void CoordinationServiceAgentImpl::StartSendingHeartbeats() { } } +void CoordinationServiceAgentImpl::StartPollingForError() { + LOG(INFO) << "Polling error from coordination service. This thread " + "will run until an error is encountered or the agent is " + "shutdown."; + absl::Status status = PollForError(); + CHECK(!status.ok()) << "PollForError returned OK status. Should " + "always return an error."; + if (absl::IsCancelled(status)) { + LOG(INFO) << "Stop polling error from coordination service because " + "the service or the agent is shutting down." + << status; + return; + } + LOG(INFO) << "Error returned from coordination service after polling: " + << status; + + SetError(status); +} + +absl::Status CoordinationServiceAgentImpl::PollForError() { + absl::Status status = absl::OkStatus(); + absl::Notification n; + PollForErrorAsync([&](absl::Status s) { + status = s; + n.Notify(); + }); + n.WaitForNotification(); + CHECK(!status.ok()) + << "PollForError returned OK status. Should always return an error."; + LOG(ERROR) + << "PollForError returned with status (this can be an error from this or " + "another task): " + << status; + return status; +} + +std::shared_ptr CoordinationServiceAgentImpl::PollForErrorAsync( + StatusCallback done) { + auto call_opts = std::make_shared(); + + absl::Status agent_running_status = + ValidateRunningAgent(/*allow_disconnected=*/true); + if (!agent_running_status.ok()) { + done(agent_running_status); + return call_opts; + } + auto request = std::make_shared(); + auto response = std::make_shared(); + *request->mutable_source_task() = task_; + VLOG(3) << "PollForErrorRequest: " << request->DebugString(); + + const CancellationToken token = + error_polling_cancellation_manager_->get_cancellation_token(); + const bool already_cancelled = + !error_polling_cancellation_manager_->RegisterCallback( + token, [call_opts]() { call_opts->StartCancel(); }); + if (already_cancelled) { + done(absl::CancelledError("PollForErrorAsync() was cancelled.")); + return call_opts; + } + + leader_client_->PollForErrorAsync( + call_opts.get(), request.get(), response.get(), + [call_opts, request, response, done = std::move(done), + &cm = error_polling_cancellation_manager_, + token](const absl::Status& s) { + // RPC call has completed (no longer needs to be cancelled if agent is + // destroyed). + cm->TryDeregisterCallback(token); + done(s); + }); + return call_opts; +} + absl::Status CoordinationServiceAgentImpl::WaitForAllTasks( const DeviceInfo& local_devices) { absl::Status agent_running_status = ValidateRunningAgent(); @@ -529,13 +637,14 @@ absl::Status CoordinationServiceAgentImpl::ShutdownInternal() { // Tear down agent. StopHeartbeat(); + StopErrorPolling(); { absl::MutexLock l(&state_mu_); if (state_ == CoordinatedTaskState::TASKSTATE_ERROR) { const std::string status_message = absl::StrCat( "Shutdown() was called while coordination agent is in error state, " - "implying that distributed execution failed. Note: agent will still " - "shutdown anyway. Agent status: ", + "implying that distributed execution failed. Note: agent will " + "still shutdown anyway. Agent status: ", status_.ToString(), "\nThis is usually caused by an earlier error during execution. " "Check the logs (this task or the leader) for an earlier error to " @@ -581,6 +690,8 @@ absl::Status CoordinationServiceAgentImpl::Reset() { // Reset agent state. StopHeartbeat(); + StopErrorPolling(); + ResetCancellationManager(); { absl::MutexLock l(&state_mu_); state_ = CoordinatedTaskState::TASKSTATE_DISCONNECTED; diff --git a/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc b/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc index db0cadf874bacc..6348054527fdb8 100644 --- a/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc +++ b/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc @@ -133,6 +133,10 @@ class TestCoordinationClient : public CoordinationClient { (CallOptions*, const HeartbeatRequest*, HeartbeatResponse*, StatusCallback), (override)); + MOCK_METHOD(void, PollForErrorAsync, + (CallOptions * call_opts, const PollForErrorRequest*, + PollForErrorResponse*, StatusCallback), + (override)); #define UNIMPLEMENTED(method) \ void method##Async(const method##Request* request, \ @@ -421,6 +425,35 @@ TEST_F(CoordinationServiceAgentTest, ConnectAfterResetError) { TF_EXPECT_OK(agent_->Connect()); } +TEST_F(CoordinationServiceAgentTest, ConnectAfterReset_WithErrorPolling) { + // Connect coordination agent and set it to error. + PollForErrorResponse mocked_response; + EXPECT_CALL(*GetClient(), PollForErrorAsync(_, _, _, _)) + .WillOnce(DoAll(SetArgPointee<2>(mocked_response), + InvokeArgument<3>(absl::UnavailableError("Test Error.")))) + .WillOnce(DoAll(SetArgPointee<2>(mocked_response), + InvokeArgument<3>(absl::InternalError("Test Error.")))); + + CoordinationServiceConfig config; + config.set_poll_for_error_from_service_at_startup(true); + InitializeAgent(config); + // The agent will be in ERROR state after the first call to Connect() because + // the error polling thread will be created and will immediately return an + // error. + TF_ASSERT_OK(agent_->Connect()); + // Wait a bit for the error polling thread to start. + absl::SleepFor(absl::Seconds(2)); + ASSERT_TRUE(agent_->IsError()); + + TF_ASSERT_OK(agent_->Reset()); + // Agent should be able to reconnect to the service after resetting. The error + // polling thread will be recreated when the agent is connected again. + TF_EXPECT_OK(agent_->Connect()); + absl::SleepFor(absl::Seconds(2)); + // The agent should again be in ERROR state after Connect(). + EXPECT_TRUE(agent_->IsError()); +} + TEST_F(CoordinationServiceAgentTest, ResetCanBeRetried) { // Mock reset error failing for the first time. EXPECT_CALL(*GetClient(), ResetTaskAsync(_, _, _)) diff --git a/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.cc b/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.cc index 315947473cf51c..200db9df7ee232 100644 --- a/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.cc +++ b/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.cc @@ -301,4 +301,18 @@ void CoordinationServiceRpcHandler::CancelBarrierAsync( done(service_->CancelBarrier(request->barrier_id(), request->source_task())); } +void CoordinationServiceRpcHandler::PollForErrorAsync( + const tensorflow::PollForErrorRequest* request, + tensorflow::PollForErrorResponse* response, StatusCallback done) { + absl::ReaderMutexLock l(&mu_); + if (service_ == nullptr) { + done(MakeCoordinationError( + absl::InternalError("Coordination service is not enabled."))); + return; + } + service_->PollForErrorAsync( + request->source_task(), + [done = std::move(done)](const absl::Status& status) { done(status); }); +} + } // namespace tsl diff --git a/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.h b/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.h index 51d4f9f6901dc6..537a5d5be3a652 100644 --- a/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.h +++ b/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.h @@ -92,6 +92,10 @@ class CoordinationServiceRpcHandler { tensorflow::CancelBarrierResponse* response, StatusCallback done); + void PollForErrorAsync(const tensorflow::PollForErrorRequest* request, + tensorflow::PollForErrorResponse* response, + StatusCallback done); + private: absl::Mutex mu_; CoordinationServiceAgent* agent_ TF_GUARDED_BY(mu_) = nullptr; diff --git a/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc b/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc index 3fcc8cb3fd0a36..9d02ce3641d4eb 100644 --- a/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc +++ b/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc @@ -47,9 +47,12 @@ limitations under the License. namespace tsl { namespace { +using ::testing::Each; using ::testing::EqualsProto; +using ::testing::HasSubstr; using ::testing::IsEmpty; using ::testing::UnorderedElementsAre; +using ::testing::status::StatusIs; using tensorflow::CoordinatedJob; using tensorflow::CoordinatedTask; @@ -131,6 +134,7 @@ class TestCoordinationClient : public CoordinationClient { UNIMPLEMENTED_WITH_CALL_OPTS(GetKeyValue); UNIMPLEMENTED_WITH_CALL_OPTS(Heartbeat); UNIMPLEMENTED_WITH_CALL_OPTS(ShutdownTask); + UNIMPLEMENTED_WITH_CALL_OPTS(PollForError); #undef UNIMPLEMENTED_WITH_CALL_OPTS private: @@ -426,6 +430,8 @@ TEST(CoordinationServiceTest, EXPECT_TRUE(!status.message().empty()); } +// TODO(b/195990880): Remove this test once server-client connection is removed. +// This test passes only when there is a single task. TEST(CoordinationServiceTest, RegisterTask_AlreadyInError_Fails) { CoordinationServiceConfig config = GetCoordinationServiceConfig(/*num_tasks=*/1); @@ -464,6 +470,29 @@ TEST_F(CoordinateTwoTasksTest, TestTaskHeartbeatTimeout) { coord_service_->RecordHeartbeat(task_1_, incarnation_1_))); } +TEST_F(CoordinateTwoTasksTest, + ErrorPollingRequestsGotCancelledErrorUponServiceShutdown) { + EnableCoordinationService(/*has_service_to_client_connection=*/false); + ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); + ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); + std::vector statuses; + statuses.reserve(2); + + for (const CoordinatedTask& task : {task_0_, task_1_}) { + coord_service_->PollForErrorAsync( + task, [&](const absl::Status& status) { statuses.push_back(status); }); + } + + // No error polling requests are received before service shutdown. + EXPECT_EQ(statuses.size(), 0); + coord_service_.reset(); + + // The service shutdowns successfully and send the cancellation response to + // the error polling requests. + EXPECT_EQ(statuses.size(), 2); + EXPECT_THAT(statuses, Each(StatusIs(absl::StatusCode::kCancelled))); +} + TEST_F(CoordinateTwoTasksTest, HeartbeatTimeoutWithoutServerToClientConnection) { EnableCoordinationService(/*has_service_to_client_connection=*/false); @@ -482,6 +511,76 @@ TEST_F(CoordinateTwoTasksTest, coord_service_->RecordHeartbeat(task_1_, incarnation_1_))); } +TEST_F(CoordinateTwoTasksTest, + HeartbeatTimeoutErrorCanPropagateThroughErrorPolling) { + EnableCoordinationService(/*has_service_to_client_connection=*/false); + ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); + ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); + std::vector statuses; + statuses.reserve(2); + for (const CoordinatedTask& task : {task_0_, task_1_}) { + coord_service_->PollForErrorAsync( + task, [&](const absl::Status& status) { statuses.push_back(status); }); + } + + // No heartbeat for a while, leader consider the task as stale and propagate + // the error to the tasks. + Env::Default()->SleepForMicroseconds( + absl::ToInt64Microseconds(2 * kHeartbeatTimeout)); + + // The heartbeat error is propagated through error polling. + EXPECT_EQ(statuses.size(), 2); + EXPECT_THAT(statuses, Each(StatusIs(absl::StatusCode::kUnavailable))); +} + +TEST_F(CoordinateTwoTasksTest, + HeartbeatTimeoutErrorFromOneTaskCanPropagateThroughErrorPolling) { + EnableCoordinationService(/*has_service_to_client_connection=*/false); + ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); + ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); + std::vector statuses; + statuses.reserve(2); + + for (const CoordinatedTask& task : {task_0_, task_1_}) { + coord_service_->PollForErrorAsync( + task, [&](const absl::Status& status) { statuses.push_back(status); }); + } + + // Use a factor of 0.9 to avoid accidental timeout. + const int64_t sleeping_time = + absl::ToInt64Microseconds(0.9 * kHeartbeatTimeout); + // No heartbeat from task 1 for a while, so leader consider the task as stale + // and propagate the error to all tasks. + Env::Default()->SleepForMicroseconds(sleeping_time); + TF_EXPECT_OK(coord_service_->RecordHeartbeat(task_0_, incarnation_0_)); + Env::Default()->SleepForMicroseconds(sleeping_time); + TF_EXPECT_OK(coord_service_->RecordHeartbeat(task_0_, incarnation_0_)); + Env::Default()->SleepForMicroseconds(sleeping_time); + + // The heartbeat error is propagated through error polling. + EXPECT_EQ(statuses.size(), 2); + EXPECT_THAT(statuses, Each(StatusIs(absl::StatusCode::kUnavailable, + HasSubstr("task:")))); +} + +TEST_F(CoordinateTwoTasksTest, ReportedErrorCanPropagateThroughErrorPolling) { + EnableCoordinationService(/*has_service_to_client_connection=*/false); + ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); + ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); + std::vector statuses; + statuses.reserve(2); + for (const CoordinatedTask& task : {task_0_, task_1_}) { + coord_service_->PollForErrorAsync( + task, [&](const absl::Status& status) { statuses.push_back(status); }); + } + + ASSERT_OK(coord_service_->ReportTaskError(task_1_, + absl::InternalError("test_error"))); + // The reported error is propagated through error polling. + EXPECT_EQ(statuses.size(), 2); + EXPECT_THAT(statuses, Each(StatusIs(absl::StatusCode::kInternal))); +} + TEST_F(CoordinateTwoTasksTest, TestTaskRestart) { EnableCoordinationService(); ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); @@ -1493,6 +1592,37 @@ TEST_F(CoordinateTwoTasksTest, BarrierFailsIfServiceHasStopped) { EXPECT_TRUE(absl::IsInternal(barrier_status)) << barrier_status; } +TEST_F(CoordinateTwoTasksTest, BarrierFailsAfterErrorPollingResponse) { + EnableCoordinationService(/*has_service_to_client_connection=*/false); + ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); + ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); + std::vector statuses; + statuses.reserve(2); + for (const CoordinatedTask& task : {task_0_, task_1_}) { + coord_service_->PollForErrorAsync( + task, [&](const absl::Status& status) { statuses.push_back(status); }); + } + // No heartbeat for a while, leader consider the task as stale. The error will + // be propagated through error polling. + Env::Default()->SleepForMicroseconds( + absl::ToInt64Microseconds(2 * kHeartbeatTimeout)); + + EXPECT_EQ(statuses.size(), 2); + EXPECT_THAT(statuses, Each(StatusIs(absl::StatusCode::kUnavailable))); + + absl::Notification n0; + absl::Status barrier_status; + // Barrier should fail when called after the error is propagated. + coord_service_->BarrierAsync("barrier_id", absl::Seconds(5), task_0_, + /*participating_tasks=*/{}, [&](absl::Status s) { + barrier_status = s; + n0.Notify(); + }); + + n0.WaitForNotification(); + EXPECT_TRUE(absl::IsInternal(barrier_status)) << barrier_status; +} + TEST_F(CoordinateTwoTasksTest, BarrierWithSubsetFailsIfServiceHasStopped) { EnableCoordinationService(/*has_service_to_client_connection=*/false); ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); @@ -1619,4 +1749,107 @@ TEST_F(CoordinateTwoTasksTest, UnavailableTaskCanReconnect) { TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_new_)); } + +TEST_F(CoordinateTwoTasksTest, + DoNotAllowPollForErrorIfHasServiceToClientConnection) { + EnableCoordinationService(/*has_service_to_client_connection=*/true); + ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); + ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); + std::vector statuses; + statuses.reserve(2); + + for (const CoordinatedTask& task : {task_0_, task_1_}) { + coord_service_->PollForErrorAsync( + task, [&](const absl::Status& status) { statuses.push_back(status); }); + } + + // The error polling requests will get immediate error because there is + // service to client connection. + EXPECT_EQ(statuses.size(), 2); + EXPECT_THAT(statuses, Each(StatusIs(absl::StatusCode::kInternal))); +} + +TEST_F(CoordinateTwoTasksTest, DoNotAllowPollForErrorIfNotInCluster) { + EnableCoordinationService(/*has_service_to_client_connection=*/false); + CoordinatedTask task_not_in_cluster; + absl::Status s; + + coord_service_->PollForErrorAsync( + task_not_in_cluster, [&](const absl::Status& status) { s = status; }); + + EXPECT_THAT(s, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("not in the cluster"))); +} + +TEST_F(CoordinateTwoTasksTest, DoNotAllowPollForErrorIfTaskNotRegistered) { + EnableCoordinationService(/*has_service_to_client_connection=*/false); + absl::Status s; + + coord_service_->PollForErrorAsync( + task_0_, [&](const absl::Status& status) { s = status; }); + + EXPECT_THAT(s, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("has not been registered"))); +} + +TEST_F(CoordinateTwoTasksTest, DoNotAllowPollForErrorIfServiceHasStopped) { + EnableCoordinationService(/*has_service_to_client_connection=*/false); + ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); + ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); + // No heartbeat for a while, leader consider the task as stale. + // As no error propagation is available, service stops. + Env::Default()->SleepForMicroseconds( + absl::ToInt64Microseconds(2 * kHeartbeatTimeout)); + + absl::Status s; + coord_service_->PollForErrorAsync( + task_0_, [&](const absl::Status& status) { s = status; }); + + EXPECT_THAT(s, StatusIs(absl::StatusCode::kInternal, + HasSubstr("service has shut down"))); +} + +TEST_F(CoordinateTwoTasksTest, + CanPropagateTaskRegistrationErrorThroughErrorPolling) { + EnableCoordinationService(/*has_service_to_client_connection=*/false); + ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); + ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); + absl::Status s0; + // Start polling for error from `task_0_`. + coord_service_->PollForErrorAsync( + task_0_, [&](const absl::Status& status) { s0 = status; }); + + // Let registration of `task_1_` fail due to incarnation mismatch. + ASSERT_THAT(coord_service_->RegisterTask(task_1_, incarnation_0_), + StatusIs(absl::StatusCode::kAborted)); + + // The first error polling request will get the error propagated from the + // registration failure. + EXPECT_THAT(s0, StatusIs(absl::StatusCode::kAborted)); +} + +TEST_F(CoordinateTwoTasksTest, LatePollingTaskCanGetError) { + EnableCoordinationService(/*has_service_to_client_connection=*/false); + ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); + ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); + std::vector statuses; + statuses.reserve(2); + coord_service_->PollForErrorAsync( + task_0_, [&](const absl::Status& status) { statuses.push_back(status); }); + + // Fail `task_0_` with an error because `task_1_` polls for error. + ASSERT_OK(coord_service_->ReportTaskError( + task_0_, absl::FailedPreconditionError("test_error_from_task_0"))); + + // Poll for error from `task_1_` after the error has been propagated to other + // tasks. + coord_service_->PollForErrorAsync( + task_1_, [&](const absl::Status& status) { statuses.push_back(status); }); + + // Make sure the error is propagated to both tasks. + EXPECT_EQ(statuses.size(), 2); + EXPECT_THAT(statuses, Each(StatusIs(absl::StatusCode::kFailedPrecondition, + HasSubstr("test_error_from_task_0")))); +} + } // namespace tsl diff --git a/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.cc b/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.cc index 0e583f5cad8bf8..4afe13f2c7960d 100644 --- a/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.cc +++ b/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.cc @@ -57,6 +57,8 @@ using tensorflow::HeartbeatRequest; using tensorflow::HeartbeatResponse; using tensorflow::InsertKeyValueRequest; using tensorflow::InsertKeyValueResponse; +using tensorflow::PollForErrorRequest; +using tensorflow::PollForErrorResponse; using tensorflow::RegisterTaskRequest; using tensorflow::RegisterTaskResponse; using tensorflow::ReportErrorToServiceRequest; @@ -269,6 +271,17 @@ class GrpcCoordinationClient : public CoordinationClient { &target_); } + void PollForErrorAsync(CallOptions* call_opts, + const PollForErrorRequest* request, + PollForErrorResponse* response, + StatusCallback done) override { + new RPCState( + &stub_, cq_, "/tensorflow.CoordinationService/PollForError", *request, + response, std::move(done), call_opts, + /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true, + &target_); + } + private: ::grpc::GenericStub stub_; ::grpc::CompletionQueue* cq_; diff --git a/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.cc b/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.cc index 72160cf8d1b0a3..d3187c291b2d92 100644 --- a/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.cc +++ b/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.cc @@ -57,6 +57,7 @@ void GrpcCoordinationServiceImpl::HandleRPCsLoop() { ENQUEUE_REQUEST(DeleteKeyValue); ENQUEUE_REQUEST(Barrier); ENQUEUE_REQUEST(CancelBarrier); + ENQUEUE_REQUEST(PollForError); #undef ENQUEUE_REQUEST void* tag; // Matches the operation started against this cq_. diff --git a/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h b/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h index 6551006d399b11..4b6c74e0b870af 100644 --- a/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h +++ b/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h @@ -98,6 +98,7 @@ class GrpcCoordinationServiceImpl : public AsyncServiceInterface { HANDLER(DeleteKeyValue); HANDLER(Barrier); HANDLER(CancelBarrier); + HANDLER(PollForError); #undef HANDLER thread::ThreadPool& compute_pool_; From 6c5d5686515f4ea2e505b9f7dd85cbde78f51b4f Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Thu, 25 Jul 2024 23:51:39 -0700 Subject: [PATCH 187/376] [XLA:GPU][NFC] Force `triton_support_test.cc` to run on GPU. This is a temporary measure to get OSS coverage while the test fails on CPU. The alternative is completely disabling it in OSS, which is not ideal. PiperOrigin-RevId: 656271136 --- build_tools/build.py | 10 ++-------- xla/service/gpu/fusions/triton/BUILD | 3 +++ 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/build_tools/build.py b/build_tools/build.py index 848fbb407bef75..770ba19a9a885c 100755 --- a/build_tools/build.py +++ b/build_tools/build.py @@ -247,10 +247,7 @@ def nvidia_gpu_build_with_compute_capability( repo="openxla/xla", docker_image=_DEFAULT_IMAGE, configs=("warnings", "nonccl", "rbe_linux_cpu"), - target_patterns=_XLA_DEFAULT_TARGET_PATTERNS - + ( - "-//xla/service/gpu/fusions/triton:triton_support_test", - ), + target_patterns=_XLA_DEFAULT_TARGET_PATTERNS, build_tag_filters=cpu_x86_tag_filter, test_tag_filters=cpu_x86_tag_filter, options=_DEFAULT_BAZEL_OPTIONS, @@ -268,10 +265,7 @@ def nvidia_gpu_build_with_compute_capability( repo="openxla/xla", docker_image=_ARM64_JAX_MULTI_PYTHON_IMAGE, configs=("warnings", "rbe_cross_compile_linux_arm64_xla", "nonccl"), - target_patterns=_XLA_DEFAULT_TARGET_PATTERNS - + ( - "-//xla/service/gpu/fusions/triton:triton_support_test", - ), + target_patterns=_XLA_DEFAULT_TARGET_PATTERNS, options={**_DEFAULT_BAZEL_OPTIONS, "build_tests_only": True}, build_tag_filters=cpu_arm_tag_filter, test_tag_filters=cpu_arm_tag_filter, diff --git a/xla/service/gpu/fusions/triton/BUILD b/xla/service/gpu/fusions/triton/BUILD index e18844d5c3b074..4fd95ea732c9ca 100644 --- a/xla/service/gpu/fusions/triton/BUILD +++ b/xla/service/gpu/fusions/triton/BUILD @@ -417,6 +417,9 @@ xla_cc_test( name = "triton_support_test", srcs = ["triton_support_test.cc"], shard_count = 20, + # TODO(b/353912594): this test does not need to run on GPU, but it is broken on CPU in OSS. + # Force it to run on GPU temporarily in order to get important OSS coverage. + tags = ["gpu"], deps = [ ":triton_fusion_emitter", ":triton_support", From 4b64c853147c5bd28e08a8a6c4510b189e689a35 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Fri, 26 Jul 2024 02:00:12 -0700 Subject: [PATCH 188/376] PR #15328: [GPU][NFC] Remove duplicate validation of cuDNN graphs. Imported from GitHub PR https://github.com/openxla/xla/pull/15328 The other time it happens for these graphs in CudnnGraph::Prepare(): https://github.com/openxla/xla/blob/45dca1a0a1d87f3d3c93fa4175e1df971acddb10/xla/stream_executor/cuda/cuda_dnn.cc#L8359 Copybara import of the project: -- 2b979036bfa52c05a0ce541ab8a89e1f2e6834ee by Ilia Sergachev : [GPU][NFC] Remove duplicate validation of cuDNN graphs. The other time it happens for these graphs in CudnnGraph::Prepare(). Merging this change closes #15328 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/15328 from openxla:remove_extra_validation 2b979036bfa52c05a0ce541ab8a89e1f2e6834ee PiperOrigin-RevId: 656309054 --- xla/service/gpu/cudnn_fusion_compiler.cc | 4 ---- 1 file changed, 4 deletions(-) diff --git a/xla/service/gpu/cudnn_fusion_compiler.cc b/xla/service/gpu/cudnn_fusion_compiler.cc index f9ae751ef6949b..18067960f80211 100644 --- a/xla/service/gpu/cudnn_fusion_compiler.cc +++ b/xla/service/gpu/cudnn_fusion_compiler.cc @@ -584,10 +584,6 @@ absl::StatusOr> HloFusionToCuDnnGraph( absl::StrCat("cudnn_fusion_", fusion.name(), ".json"), /*contents=*/dump.dump(1)); } - if (cudnn_frontend::error_t result = graph.validate(); result.is_bad()) { - VLOG(3) << result.get_message(); - return std::nullopt; - } return se::gpu::CudnnGraph(std::move(graph)); } From 27d191733c3fe1ca313837da440b2fd7de604b02 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Fri, 26 Jul 2024 04:06:31 -0700 Subject: [PATCH 189/376] Conditionally use Shardy in XLA CPU pipeline. PiperOrigin-RevId: 656343554 --- xla/service/cpu/BUILD | 1 + xla/service/cpu/cpu_compiler.cc | 14 +++++++++----- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/xla/service/cpu/BUILD b/xla/service/cpu/BUILD index eef125f44cfa5f..cf93548c61563d 100644 --- a/xla/service/cpu/BUILD +++ b/xla/service/cpu/BUILD @@ -332,6 +332,7 @@ cc_library( "//xla/service/llvm_ir:llvm_command_line_options", "//xla/service/llvm_ir:llvm_util", "//xla/service/spmd:stateful_rng_spmd_partitioner", + "//xla/service/spmd/shardy:shardy_xla_pass", "//xla/stream_executor", "//xla/stream_executor/host:host_platform_id", "//xla/translate/hlo_to_mhlo:hlo_to_mlir_hlo", diff --git a/xla/service/cpu/cpu_compiler.cc b/xla/service/cpu/cpu_compiler.cc index e19b90d7c8f96e..73201a657bc3f8 100644 --- a/xla/service/cpu/cpu_compiler.cc +++ b/xla/service/cpu/cpu_compiler.cc @@ -168,6 +168,7 @@ limitations under the License. #include "xla/service/sharding_remover.h" #include "xla/service/slow_operation_alarm.h" #include "xla/service/sort_simplifier.h" +#include "xla/service/spmd/shardy/shardy_xla_pass.h" #include "xla/service/spmd/stateful_rng_spmd_partitioner.h" #include "xla/service/stochastic_convert_decomposer.h" #include "xla/service/sub_byte_normalization.h" @@ -446,11 +447,14 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn( spmd_pipeline.AddPass(); spmd_pipeline.AddPass(); spmd_pipeline.AddPass(); - - spmd_pipeline.AddPass( - /*is_spmd=*/true, /*propagate_metadata=*/false, - module->config().allow_spmd_sharding_propagation_to_output(), - module->config().allow_spmd_sharding_propagation_to_parameters()); + if (module->config().debug_options().xla_use_shardy()) { + spmd_pipeline.AddPass(); + } else { + spmd_pipeline.AddPass( + /*is_spmd=*/true, /*propagate_metadata=*/false, + module->config().allow_spmd_sharding_propagation_to_output(), + module->config().allow_spmd_sharding_propagation_to_parameters()); + } spmd_pipeline.AddPass( num_partitions, module->config().replica_count()); TF_RETURN_IF_ERROR(spmd_pipeline.Run(module).status()); From 723d125e66d5bb01e970be5280c8036ececce7b7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 26 Jul 2024 04:36:58 -0700 Subject: [PATCH 190/376] Integrate LLVM at llvm/llvm-project@51d4980a133d Updates LLVM usage to match [51d4980a133d](https://github.com/llvm/llvm-project/commit/51d4980a133d) PiperOrigin-RevId: 656351435 --- third_party/llvm/generated.patch | 40 ++++++++++--------- third_party/llvm/workspace.bzl | 4 +- third_party/shardy/workspace.bzl | 4 +- .../tsl/third_party/llvm/generated.patch | 40 ++++++++++--------- .../tsl/third_party/llvm/workspace.bzl | 4 +- 5 files changed, 48 insertions(+), 44 deletions(-) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 4eda7b241d21bc..506f5632703a41 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,21 +1,23 @@ Auto generated patch. Do not edit or delete it, even if empty. -diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel ---- a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel -+++ b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel -@@ -675,6 +675,7 @@ - deps = [ - ":__support_common", - ":__support_cpp_type_traits", -+ ":__support_fputil_dyadic_float", - ":__support_fputil_fenv_impl", - ":__support_fputil_fp_bits", - ":__support_macros_optimization", -@@ -1089,7 +1090,7 @@ - ":__support_macros_optimization", - ":__support_osutil_syscall", - ":types_pid_t", -- ] -+ ], - ) +diff -ruN --strip-trailing-cr a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h ++++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +@@ -9,6 +9,7 @@ + #define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_ - libc_support_library( + #include "mlir/Conversion/LLVMCommon/Pattern.h" ++#include "mlir/Dialect/Arith/IR/Arith.h" + #include "mlir/Dialect/GPU/IR/GPUDialect.h" + #include "mlir/Dialect/LLVMIR/LLVMDialect.h" + #include "mlir/IR/Builders.h" +diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel ++++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +@@ -5744,6 +5744,7 @@ + "lib/Conversion/GPUCommon/OpToFuncCallLowering.h", + ], + deps = [ ++ ":ArithDialect", + ":GPUDialect", + ":IR", + ":LLVMCommonConversion", diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index a6b1b06abe37c5..45f9cf544dc10c 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "58fb51492d9669525662fa269295d85537968569" - LLVM_SHA256 = "f6cac3f3f562a7bd3a36a828df2960a1ebc2cd6237f4cb95a66f1bd16e918ef9" + LLVM_COMMIT = "51d4980a133db12888207698e39c469cb7055cac" + LLVM_SHA256 = "ee34426de8adf8408a610d0072e82b50bad0adac2c009f1f20072d626c0b876e" tf_http_archive( name = name, diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index 01e4a8bf6970a8..3c78b846facc61 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "f7ba97a90be022a20dc0f970998bc0855f152314" - SHARDY_SHA256 = "6dcf7672c93ed22fa676ab8d33e4d5b64eff6cee4668098f0937fb57cb8f1320" + SHARDY_COMMIT = "effc9ac0716b25861f7deaea91aafaa93515a1aa" + SHARDY_SHA256 = "cce9c625b2ce107c2ab19e811059bf1d3da0160fdbe418778658a8f19fef211a" tf_http_archive( name = "shardy", diff --git a/third_party/tsl/third_party/llvm/generated.patch b/third_party/tsl/third_party/llvm/generated.patch index 4eda7b241d21bc..506f5632703a41 100644 --- a/third_party/tsl/third_party/llvm/generated.patch +++ b/third_party/tsl/third_party/llvm/generated.patch @@ -1,21 +1,23 @@ Auto generated patch. Do not edit or delete it, even if empty. -diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel ---- a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel -+++ b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel -@@ -675,6 +675,7 @@ - deps = [ - ":__support_common", - ":__support_cpp_type_traits", -+ ":__support_fputil_dyadic_float", - ":__support_fputil_fenv_impl", - ":__support_fputil_fp_bits", - ":__support_macros_optimization", -@@ -1089,7 +1090,7 @@ - ":__support_macros_optimization", - ":__support_osutil_syscall", - ":types_pid_t", -- ] -+ ], - ) +diff -ruN --strip-trailing-cr a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h ++++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +@@ -9,6 +9,7 @@ + #define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_ - libc_support_library( + #include "mlir/Conversion/LLVMCommon/Pattern.h" ++#include "mlir/Dialect/Arith/IR/Arith.h" + #include "mlir/Dialect/GPU/IR/GPUDialect.h" + #include "mlir/Dialect/LLVMIR/LLVMDialect.h" + #include "mlir/IR/Builders.h" +diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel ++++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +@@ -5744,6 +5744,7 @@ + "lib/Conversion/GPUCommon/OpToFuncCallLowering.h", + ], + deps = [ ++ ":ArithDialect", + ":GPUDialect", + ":IR", + ":LLVMCommonConversion", diff --git a/third_party/tsl/third_party/llvm/workspace.bzl b/third_party/tsl/third_party/llvm/workspace.bzl index a6b1b06abe37c5..45f9cf544dc10c 100644 --- a/third_party/tsl/third_party/llvm/workspace.bzl +++ b/third_party/tsl/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "58fb51492d9669525662fa269295d85537968569" - LLVM_SHA256 = "f6cac3f3f562a7bd3a36a828df2960a1ebc2cd6237f4cb95a66f1bd16e918ef9" + LLVM_COMMIT = "51d4980a133db12888207698e39c469cb7055cac" + LLVM_SHA256 = "ee34426de8adf8408a610d0072e82b50bad0adac2c009f1f20072d626c0b876e" tf_http_archive( name = name, From 0a90c5fecd7028e96da448d7ca71deb1ca6f9c38 Mon Sep 17 00:00:00 2001 From: Goran Flegar Date: Fri, 26 Jul 2024 04:46:04 -0700 Subject: [PATCH 191/376] Add missing patch after latest LLVM integrate PiperOrigin-RevId: 656353550 --- .../triton/llvm_integration/cl656020169.patch | 12 ++++++++++++ third_party/triton/llvm_integration/series.bzl | 1 + 2 files changed, 13 insertions(+) create mode 100644 third_party/triton/llvm_integration/cl656020169.patch diff --git a/third_party/triton/llvm_integration/cl656020169.patch b/third_party/triton/llvm_integration/cl656020169.patch new file mode 100644 index 00000000000000..7586a90b14ccf6 --- /dev/null +++ b/third_party/triton/llvm_integration/cl656020169.patch @@ -0,0 +1,12 @@ +diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp +--- a/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp ++++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp +@@ -117,7 +117,7 @@ private: + auto operands = callOp.getOperands(); + auto result = callOp.getResult(); + +- LLVM::LLVMFunctionType calleeType = callOp.getCalleeType().value(); ++ LLVM::LLVMFunctionType calleeType = callOp.getVarCalleeType().value(); + Type returnType = calleeType.getReturnType(); + + auto loc = callOp.getLoc(); diff --git a/third_party/triton/llvm_integration/series.bzl b/third_party/triton/llvm_integration/series.bzl index 656b9c894904d8..9d0e1204ba527f 100644 --- a/third_party/triton/llvm_integration/series.bzl +++ b/third_party/triton/llvm_integration/series.bzl @@ -8,5 +8,6 @@ LLVM nor MLIR integrator, please do not add any patches to this list. """ llvm_patch_list = [ + "//third_party/triton/llvm_integration:cl656020169.patch", # Add new patches just above this line ] From 4962386416b8da348f7a86651b15bdcde726edfd Mon Sep 17 00:00:00 2001 From: Tori Baker Date: Fri, 26 Jul 2024 05:21:48 -0700 Subject: [PATCH 192/376] IndexingMapAttr: use aliases to print it always at the top & with a new line. PiperOrigin-RevId: 656362264 --- .../gpu/fusions/mlir/ir/xla_gpu_attrs.cc | 2 +- .../gpu/fusions/mlir/ir/xla_gpu_ops.cc | 14 ++++++- .../fusions/mlir/tests/indexing_map_attr.mlir | 39 ++++++++++++++----- 3 files changed, 44 insertions(+), 11 deletions(-) diff --git a/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc b/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc index 2f51b2572831f7..ad31f42c64bc84 100644 --- a/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc +++ b/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc @@ -154,7 +154,7 @@ mlir::Attribute IndexingMapAttr::parse(mlir::AsmParser& parser, mlir::Type) { } void IndexingMapAttr::print(mlir::AsmPrinter& printer) const { - printer << "<"; + printer << "<\n"; printer.printStrippedAttrOrType(getMap()); printer << "\ndomain:\n"; PrintDimVars(printer, getDimVars()); diff --git a/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc b/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc index c7a0575a0ef087..2c4029c5c20d9a 100644 --- a/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc +++ b/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc @@ -148,6 +148,18 @@ struct XlaGpuInlinerInterface : public mlir::DialectInlinerInterface { } }; +struct XlaGpuOpAsmDialectInterface : public mlir::OpAsmDialectInterface { + using OpAsmDialectInterface::OpAsmDialectInterface; + AliasResult getAlias(mlir::Attribute attr, + mlir::raw_ostream& os) const final { + if (llvm::isa(attr)) { + os << "indexing_map"; + return AliasResult::FinalAlias; + } + return AliasResult::NoAlias; + } +}; + } // namespace void XlaGpuDialect::initialize() { @@ -161,7 +173,7 @@ void XlaGpuDialect::initialize() { #include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc.inc" >(); #undef GET_ATTRDEF_LIST - addInterfaces(); + addInterfaces(); } LogicalResult PureCallOp::verifySymbolUses( diff --git a/xla/service/gpu/fusions/mlir/tests/indexing_map_attr.mlir b/xla/service/gpu/fusions/mlir/tests/indexing_map_attr.mlir index f6228b07aab50f..3ea853dc8d0d19 100644 --- a/xla/service/gpu/fusions/mlir/tests/indexing_map_attr.mlir +++ b/xla/service/gpu/fusions/mlir/tests/indexing_map_attr.mlir @@ -1,6 +1,7 @@ -// RUN: mlir_fusions_opt %s -split-input-file | mlir_fusions_opt | FileCheck %s +// RUN: mlir_fusions_opt %s -split-input-file | mlir_fusions_opt -split-input-file | FileCheck %s -// CHECK: #xla_gpu.indexing_map<(d0, d1, d2)[s0] -> (d0) +// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< +// CHECK-NEXT: (d0, d1, d2)[s0] -> (d0) // CHECK-NEXT: domain: // CHECK-NEXT: d0 in [1, 2] // CHECK-NEXT: d1 in [5, 8] @@ -20,10 +21,13 @@ > func.func private @indexing_map_attr(tensor<32xf64, #map>) +// CHECK-LABEL: @indexing_map_attr +// CHECK: tensor<32xf64, #[[$INDEX_MAP]]> // ----- -// CHECK: #xla_gpu.indexing_map<(d0, d1)[s0, s1, s2] -> (d0 + s0, d1 + s1, d1 + s2) +// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< +// CHECK-NEXT: (d0, d1)[s0, s1, s2] -> (d0 + s0, d1 + s1, d1 + s2) // CHECK-NEXT: domain: // CHECK-NEXT: d0 in [1, 2] // CHECK-NEXT: d1 in [5, 8] @@ -46,10 +50,13 @@ func.func private @indexing_map_attr(tensor<32xf64, #map>) d1 + s1 + s2 in [1, 32] > func.func private @more_range_vars(tensor<32xf64, #map>) +// CHECK-LABEL: @more_range_vars +// CHECK: tensor<32xf64, #[[$INDEX_MAP]]> // ----- -// CHECK: #xla_gpu.indexing_map<(d0)[s0] -> (d0) +// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< +// CHECK-NEXT: (d0)[s0] -> (d0) // CHECK-NEXT: domain: // CHECK-NEXT: d0 in [0, 100] // CHECK-NEXT: s0 in [-3, -1] @@ -60,10 +67,13 @@ func.func private @more_range_vars(tensor<32xf64, #map>) s0 in [-3, -1] > func.func private @indexing_map_small(tensor<100xf64, #map>) +// CHECK-LABEL: @indexing_map_small +// CHECK: tensor<100xf64, #[[$INDEX_MAP]]> // ----- -// CHECK: #xla_gpu.indexing_map<(d0, d1, d2)[s0] -> (d0) +// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< +// CHECK-NEXT: (d0, d1, d2)[s0] -> (d0) // CHECK-NEXT: domain: // CHECK-NEXT: d0 in [1, 2] // CHECK-NEXT: d1 in [5, 8] @@ -78,10 +88,13 @@ func.func private @indexing_map_small(tensor<100xf64, #map>) s0 in [0, 32] > func.func private @no_constraints(tensor<32xf64, #map>) +// CHECK-LABEL: @no_constraints +// CHECK: tensor<32xf64, #[[$INDEX_MAP]]> // ----- -// CHECK: #xla_gpu.indexing_map<()[s0] -> (s0) +// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< +// CHECK-NEXT: ()[s0] -> (s0) // CHECK-NEXT: domain: // CHECK-NEXT: s0 in [3, 5] // CHECK-NEXT: s0 mod 2 in [0, 1] @@ -92,10 +105,13 @@ func.func private @no_constraints(tensor<32xf64, #map>) s0 mod 2 in [0, 1] > func.func private @no_dimensions(tensor<100xf64, #map>) +// CHECK-LABEL: @no_dimensions +// CHECK: tensor<100xf64, #[[$INDEX_MAP]]> // ----- -// CHECK: #xla_gpu.indexing_map<(d0) -> (d0) +// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< +// CHECK-NEXT: (d0) -> (d0) // CHECK-NEXT: domain: // CHECK-NEXT: d0 in [3, 5] // CHECK-NEXT: d0 mod 2 in [0, 1] @@ -106,13 +122,18 @@ func.func private @no_dimensions(tensor<100xf64, #map>) d0 mod 2 in [0, 1] > func.func private @no_symbols(tensor<100xf64, #map>) +// CHECK-LABEL: @no_symbols +// CHECK: tensor<100xf64, #[[$INDEX_MAP]]> // ----- -// CHECK: #xla_gpu.indexing_map<() -> () +// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< +// CHECK-NEXT: () -> () // CHECK-NEXT: domain: // CHECK-NEXT: > #map = #xla_gpu.indexing_map<() -> () domain: > -func.func private @empty(tensor<100xf64, #map>) \ No newline at end of file +func.func private @empty(tensor<100xf64, #map>) +// CHECK-LABEL: @empty +// CHECK: tensor<100xf64, #[[$INDEX_MAP]]> \ No newline at end of file From 6f6af02ba768918e4efecfe6f32ff7ce3a3b5ea7 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Fri, 26 Jul 2024 06:18:41 -0700 Subject: [PATCH 193/376] [XLA:GPU][MLIR-based emitters] Move XlaGpuDialect to xla_gpu_dialect.cc. PiperOrigin-RevId: 656376390 --- xla/service/gpu/fusions/mlir/ir/BUILD | 1 + .../gpu/fusions/mlir/ir/xla_gpu_dialect.cc | 129 ++++++++++++++++++ .../gpu/fusions/mlir/ir/xla_gpu_ops.cc | 101 -------------- xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h | 2 +- 4 files changed, 131 insertions(+), 102 deletions(-) create mode 100644 xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.cc diff --git a/xla/service/gpu/fusions/mlir/ir/BUILD b/xla/service/gpu/fusions/mlir/ir/BUILD index ba19b9b81b5a4e..d618413f13e817 100644 --- a/xla/service/gpu/fusions/mlir/ir/BUILD +++ b/xla/service/gpu/fusions/mlir/ir/BUILD @@ -83,6 +83,7 @@ cc_library( name = "xla_gpu", srcs = [ "xla_gpu_attrs.cc", + "xla_gpu_dialect.cc", "xla_gpu_ops.cc", ], hdrs = [ diff --git a/xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.cc b/xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.cc new file mode 100644 index 00000000000000..3dc60c91f40779 --- /dev/null +++ b/xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.cc @@ -0,0 +1,129 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "llvm/ADT/TypeSwitch.h" +#include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep +#include "mlir/IR/OpImplementation.h" // IWYU pragma: keep +#include "mlir/Transforms/InliningUtils.h" +#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h" +#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#define GET_ATTRDEF_CLASSES +#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc.inc" +#undef GET_ATTRDEF_CLASSES + +namespace xla { +namespace gpu { +namespace { + +struct XlaGpuInlinerInterface : public mlir::DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + // Returns true if the given operation 'callable', that implements the + // 'CallableOpInterface', can be inlined into the position given call + // operation 'call', that is registered to the current dialect and implements + // the `CallOpInterface`. 'wouldBeCloned' is set to true if the region of the + // given 'callable' is set to be cloned during the inlining process, or false + // if the region is set to be moved in-place (i.e. no duplicates would be + // created). + bool isLegalToInline(mlir::Operation* call, mlir::Operation* callable, + bool wouldBeCloned) const final { + if (!wouldBeCloned) { + // If no duplicate would be created, 'call' is likely the only caller of + // 'callable'. + return true; + } + // Otherwise, inline only if the called function is small. We could + // theoretically also inline if there is no other caller in the function + // that contains the callee that has a call path to the callable, but that + // is more expensive to check. + auto func_op = mlir::dyn_cast(callable); + if (!func_op) { + return false; + } + auto region = func_op.getCallableRegion(); + if (!region) { + return false; + } + + // If callee and caller call the same third function, inline. We have no + // guarantee that the indices are the same, but there is a good chance they + // are (or if the callee gets inlined as well, there will be CSE + // opportunities). + // This is duct tape to work around the limitations of our partitioner. + // Ideally, the partitioner would be aware of the actual indexing and create + // the partitions based on it (i.e., the case where the indices are the same + // would never happen). + llvm::SmallDenseSet callee_calls; + for (auto call : region->getOps()) { + callee_calls.insert(call.getCallee()); + } + for (auto call : call->getParentRegion()->getOps()) { + if (callee_calls.contains(call.getCallee())) { + return true; + } + } + + constexpr int kMaxOperationsToInline = 8; + int num_ops = 0; + region->front().walk([&](mlir::Operation* op) { ++num_ops; }); + + // Don't inline functions that are called more than once and contain more + // than one call themselves. + return num_ops <= kMaxOperationsToInline; + } + // Returns true if the given operation 'op', that is registered to this + // dialect, can be inlined into the given region, false otherwise. + // 'wouldBeCloned' is set to true if the given 'op' is set to be cloned + // during the inlining process, or false if the operation is set to be moved + // in-place(i.e. no duplicates would be created). 'valueMapping' contains any + // remapped values from within the 'src' region. This can be used to examine + // what values may potentially replace the operands to 'op'. + bool isLegalToInline(mlir::Operation* op, mlir::Region* dest, + bool wouldBeCloned, + mlir::IRMapping& valueMapping) const final { + // We allow any op from the xla_gpu dialect to be inlined. + return true; + } +}; + +struct XlaGpuOpAsmDialectInterface : public mlir::OpAsmDialectInterface { + using OpAsmDialectInterface::OpAsmDialectInterface; + AliasResult getAlias(mlir::Attribute attr, + mlir::raw_ostream& os) const final { + if (llvm::isa(attr)) { + os << "indexing_map"; + return AliasResult::FinalAlias; + } + return AliasResult::NoAlias; + } +}; + +} // namespace + +void XlaGpuDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc.inc" +#undef GET_OP_LIST + >(); + addAttributes< +#define GET_ATTRDEF_LIST +#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc.inc" + >(); +#undef GET_ATTRDEF_LIST + addInterfaces(); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc b/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc index 2c4029c5c20d9a..dfa4d056a80bda 100644 --- a/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc +++ b/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc @@ -44,14 +44,9 @@ limitations under the License. #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/InliningUtils.h" #include "xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.cc.inc" #include "xla/service/gpu/model/indexing_map.h" -#define GET_ATTRDEF_CLASSES -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc.inc" -#undef GET_ATTRDEF_CLASSES - namespace xla { namespace gpu { namespace { @@ -78,104 +73,8 @@ using mlir::ValueRange; namespace arith = mlir::arith; -struct XlaGpuInlinerInterface : public mlir::DialectInlinerInterface { - using DialectInlinerInterface::DialectInlinerInterface; - // Returns true if the given operation 'callable', that implements the - // 'CallableOpInterface', can be inlined into the position given call - // operation 'call', that is registered to the current dialect and implements - // the `CallOpInterface`. 'wouldBeCloned' is set to true if the region of the - // given 'callable' is set to be cloned during the inlining process, or false - // if the region is set to be moved in-place (i.e. no duplicates would be - // created). - bool isLegalToInline(mlir::Operation* call, mlir::Operation* callable, - bool wouldBeCloned) const final { - if (!wouldBeCloned) { - // If no duplicate would be created, 'call' is likely the only caller of - // 'callable'. - return true; - } - // Otherwise, inline only if the called function is small. We could - // theoretically also inline if there is no other caller in the function - // that contains the callee that has a call path to the callable, but that - // is more expensive to check. - auto func_op = mlir::dyn_cast(callable); - if (!func_op) { - return false; - } - auto region = func_op.getCallableRegion(); - if (!region) { - return false; - } - - // If callee and caller call the same third function, inline. We have no - // guarantee that the indices are the same, but there is a good chance they - // are (or if the callee gets inlined as well, there will be CSE - // opportunities). - // This is duct tape to work around the limitations of our partitioner. - // Ideally, the partitioner would be aware of the actual indexing and create - // the partitions based on it (i.e., the case where the indices are the same - // would never happen). - llvm::SmallDenseSet callee_calls; - for (auto call : region->getOps()) { - callee_calls.insert(call.getCallee()); - } - for (auto call : call->getParentRegion()->getOps()) { - if (callee_calls.contains(call.getCallee())) { - return true; - } - } - - constexpr int kMaxOperationsToInline = 8; - int num_ops = 0; - region->front().walk([&](mlir::Operation* op) { ++num_ops; }); - - // Don't inline functions that are called more than once and contain more - // than one call themselves. - return num_ops <= kMaxOperationsToInline; - } - // Returns true if the given operation 'op', that is registered to this - // dialect, can be inlined into the given region, false otherwise. - // 'wouldBeCloned' is set to true if the given 'op' is set to be cloned - // during the inlining process, or false if the operation is set to be moved - // in-place(i.e. no duplicates would be created). 'valueMapping' contains any - // remapped values from within the 'src' region. This can be used to examine - // what values may potentially replace the operands to 'op'. - bool isLegalToInline(mlir::Operation* op, mlir::Region* dest, - bool wouldBeCloned, - mlir::IRMapping& valueMapping) const final { - // We allow any op from the xla_gpu dialect to be inlined. - return true; - } -}; - -struct XlaGpuOpAsmDialectInterface : public mlir::OpAsmDialectInterface { - using OpAsmDialectInterface::OpAsmDialectInterface; - AliasResult getAlias(mlir::Attribute attr, - mlir::raw_ostream& os) const final { - if (llvm::isa(attr)) { - os << "indexing_map"; - return AliasResult::FinalAlias; - } - return AliasResult::NoAlias; - } -}; - } // namespace -void XlaGpuDialect::initialize() { - addOperations< -#define GET_OP_LIST -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc.inc" -#undef GET_OP_LIST - >(); - addAttributes< -#define GET_ATTRDEF_LIST -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc.inc" - >(); -#undef GET_ATTRDEF_LIST - addInterfaces(); -} - LogicalResult PureCallOp::verifySymbolUses( mlir::SymbolTableCollection& symbolTable) { auto callee = getCalleeAttr(); diff --git a/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h b/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h index 02604c1ea99db7..f43786f4fde0ac 100644 --- a/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h +++ b/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h @@ -28,8 +28,8 @@ limitations under the License. #include "mlir/Interfaces/SideEffectInterfaces.h" // IWYU pragma: keep #include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h" // IWYU pragma: keep -#define GET_OP_CLASSES #include "xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.h.inc" +#define GET_OP_CLASSES #include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h.inc" #undef GET_OP_CLASSES #define GET_ATTRDEF_CLASSES From 7cad716923f6b8588512cc1a2882ae2e5ee2fe9d Mon Sep 17 00:00:00 2001 From: Dirk Hornung Date: Fri, 26 Jul 2024 06:31:02 -0700 Subject: [PATCH 194/376] [XLA:GPU] Add Custom Kernel Fusion Autotuner HLO pass. PiperOrigin-RevId: 656379223 --- xla/service/gpu/BUILD | 92 ++++++++ .../gpu/custom_kernel_fusion_autotuner.cc | 220 ++++++++++++++++++ .../gpu/custom_kernel_fusion_autotuner.h | 53 +++++ .../custom_kernel_fusion_autotuner_test.cc | 112 +++++++++ xla/service/gpu/gpu_compiler.cc | 2 + 5 files changed, 479 insertions(+) create mode 100644 xla/service/gpu/custom_kernel_fusion_autotuner.cc create mode 100644 xla/service/gpu/custom_kernel_fusion_autotuner.h create mode 100644 xla/service/gpu/custom_kernel_fusion_autotuner_test.cc diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 28ca8469086ce3..1ff6ab879769b6 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -2905,6 +2905,97 @@ xla_cc_test( ], ) +cc_library( + name = "custom_kernel_fusion_autotuner", + srcs = if_cuda_is_configured(["custom_kernel_fusion_autotuner.cc"]), + hdrs = if_cuda_is_configured(["custom_kernel_fusion_autotuner.h"]), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + deps = if_cuda_is_configured([ + ":autotuner_compile_util", + ":autotuner_util", + ":backend_configs_cc", + ":buffer_comparator", + ":gemm_rewriter", + ":gpu_float_support", + ":gpu_fusible", + ":instruction_fusion", + ":ir_emission_utils", + ":matmul_utils", + ":split_k_gemm_rewriter", + "//xla/service/gpu/kernels:custom_kernel", + "//xla/service/gpu/kernels:custom_kernel_fusion", + ":stream_executor_util", + ":cudnn_fusion_compiler", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@local_config_cuda//cuda:cuda_headers", + "//xla:autotuning_proto_cc", + "//xla:shape_util", + "//xla:status_macros", + "//xla/tools:hlo_decomposer_lib", + "//xla:statusor", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla:xla_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:algorithm_util", + "//xla/service:dump", + "//xla/service:executable", + "//xla/service:float_normalization", + "//xla/service:hlo_module_config", + "//xla/service:hlo_pass", + "//xla/service:shaped_buffer", + "//xla/stream_executor:device_description", + "//xla/stream_executor:device_memory", + "//xla/stream_executor", + "//xla/stream_executor/gpu:redzone_allocator", + "@tsl//tsl/lib/core:bits", + "@tsl//tsl/platform:blocking_counter", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:protobuf", + "@tsl//tsl/platform:status", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/profiler/lib:scoped_annotation", + "//xla/tsl/util/proto:proto_utils", + "//xla/service/gpu:hlo_traversal", + ]) + [ + "//xla/stream_executor:stream_executor_memory_allocator", + "@com_google_absl//absl/status", + "@tsl//tsl/platform:path", + ], +) + +xla_test( + name = "custom_kernel_fusion_autotuner_test", + srcs = if_cuda_is_configured(["custom_kernel_fusion_autotuner_test.cc"]), + backends = [ + "gpu", + ], + deps = [ + ":autotuner_util", + ":custom_kernel_fusion_autotuner", + "//xla:xla_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass_pipeline", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_googletest//:gtest", + "@tsl//tsl/platform:path", + "@tsl//tsl/platform:test", + ], +) + cc_library( name = "custom_kernel_fusion_rewriter", srcs = ["custom_kernel_fusion_rewriter.cc"], @@ -3075,6 +3166,7 @@ cc_library( ":compile_module_to_llvm_ir", ":conv_layout_normalization", ":copy_fusion", + ":custom_kernel_fusion_autotuner", ":custom_kernel_fusion_rewriter", ":dot_dimension_sorter", ":dot_operand_converter", diff --git a/xla/service/gpu/custom_kernel_fusion_autotuner.cc b/xla/service/gpu/custom_kernel_fusion_autotuner.cc new file mode 100644 index 00000000000000..d5114bc6e6edbc --- /dev/null +++ b/xla/service/gpu/custom_kernel_fusion_autotuner.cc @@ -0,0 +1,220 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/custom_kernel_fusion_autotuner.h" + +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/executable.h" +#include "xla/service/gpu/autotuner_compile_util.h" +#include "xla/service/gpu/autotuner_util.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/kernels/custom_kernel.h" +#include "xla/service/gpu/kernels/custom_kernel_fusion.h" +#include "xla/service/shaped_buffer.h" +#include "xla/status_macros.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/gpu/redzone_allocator.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor_memory_allocator.h" +#include "xla/tools/hlo_decomposer.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +namespace { +absl::StatusOr> ExtractFusionModule( + HloInstruction* fusion_instruction, int64_t kernel_index) { + std::unique_ptr hlo_module = + ExtractInstructionIntoNewModule(*fusion_instruction); + + HloInstruction* instruction = + hlo_module->entry_computation()->root_instruction(); + GpuBackendConfig gpu_config = + instruction->backend_config().value(); + gpu_config.mutable_fusion_backend_config() + ->mutable_custom_fusion_config() + ->set_kernel_index(kernel_index); + TF_RETURN_IF_ERROR(instruction->set_backend_config(gpu_config)); + + return hlo_module; +} + +absl::StatusOr>> ProfileKernels( + std::vector& kernels, HloInstruction* fusion_instruction, + AutotunerCompileUtil& compile_util, const AutotuneConfig& autotune_config, + const DebugOptions& debug_options) { + se::StreamExecutor* stream_exec = autotune_config.GetExecutor(); + std::vector> results; + for (int i = 0; i < kernels.size(); ++i) { + TF_ASSIGN_OR_RETURN(absl::StatusOr> executable, + compile_util.Compile([&](const DebugOptions& opt) { + return ExtractFusionModule(fusion_instruction, i); + })); + + se::DeviceMemoryAllocator* allocator = autotune_config.GetAllocator(); + std::unique_ptr owned_allocator; + if (allocator == nullptr) { + owned_allocator = + std::make_unique(stream_exec); + allocator = owned_allocator.get(); + } + TF_ASSIGN_OR_RETURN(se::Stream* const stream, autotune_config.GetStream()); + + TF_ASSIGN_OR_RETURN(auto rz_buffers, + RedzoneBuffers::FromInstruction( + *fusion_instruction, autotune_config, debug_options, + RedzoneBuffers::kAllInputs)); + + std::optional reference_buffer; + std::optional profiling_output; + TF_ASSIGN_OR_RETURN(profiling_output, compile_util.ProfileExecutable( + executable->get(), stream, + rz_buffers.input_buffers(), + rz_buffers.input_shapes())); + results.push_back({i, profiling_output->duration}); + } + return results; +} + +absl::StatusOr FindFastestKernel( + const std::vector>& results) { + auto iter = absl::c_min_element( + results, [](const std::tuple& lhs, + const std::tuple& rhs) { + return std::get<1>(lhs) < std::get<1>(rhs); + }); + if (iter == results.end()) { + return absl::InternalError("Failed to find fastest kernel."); + } + return std::get<0>(*iter); +} + +absl::Status UpdateFusionInstructionKernelIndex( + HloInstruction* fusion_instruction, int kernel_index) { + GpuBackendConfig gpu_config = + fusion_instruction->backend_config().value(); + gpu_config.mutable_fusion_backend_config() + ->mutable_custom_fusion_config() + ->set_kernel_index(kernel_index); + TF_RETURN_IF_ERROR(fusion_instruction->set_backend_config(gpu_config)); + + return absl::OkStatus(); +} + +absl::StatusOr> LoadKernels( + const HloInstruction* fusion_instruction, + const AutotuneConfig& autotune_config) { + auto config = fusion_instruction->backend_config() + ->fusion_backend_config() + .custom_fusion_config(); + auto* registry = CustomKernelFusionRegistry::Default(); + auto* custom_kernel_fusion = registry->Lookup(config.name()); + + // If custom fusion is not found it means that some of the build targets might + // not be statically linked into the binary. + if (custom_kernel_fusion == nullptr) { + return absl::InternalError( + absl::StrCat("Custom kernel fusion ", config.name(), + " not found in a default registry.")); + } + + se::StreamExecutor* stream_exec = autotune_config.GetExecutor(); + if (!stream_exec->SynchronizeAllActivity()) { + return Internal("Failed to synchronize GPU for autotuning."); + } + se::DeviceDescription device_description = + stream_exec->GetDeviceDescription(); + + // Load custom kernels that can implement a fusion computation. + TF_ASSIGN_OR_RETURN( + std::vector kernels, + custom_kernel_fusion->LoadKernels( + device_description, + fusion_instruction->fused_instructions_computation())); + + return kernels; +} + +absl::StatusOr AutotuneCustomKernelFusion( + HloInstruction* fusion_instruction, const AutotuneConfig& autotune_config, + AutotunerCompileUtil& compile_util, const DebugOptions& debug_options) { + int previous_kernel_index = + fusion_instruction->backend_config() + ->fusion_backend_config() + .custom_fusion_config() + .kernel_index(); + + TF_ASSIGN_OR_RETURN(std::vector kernels, + LoadKernels(fusion_instruction, autotune_config)); + + std::vector> results; + TF_ASSIGN_OR_RETURN(results, + ProfileKernels(kernels, fusion_instruction, compile_util, + autotune_config, debug_options)); + + TF_ASSIGN_OR_RETURN(int fastest_kernel_index, FindFastestKernel(results)); + + TF_RETURN_IF_ERROR(UpdateFusionInstructionKernelIndex(fusion_instruction, + fastest_kernel_index)); + + return previous_kernel_index != fastest_kernel_index; +} +} // namespace + +absl::StatusOr CustomKernelFusionAutotuner::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + const DebugOptions& debug_options = module->config().debug_options(); + TF_ASSIGN_OR_RETURN(std::optional compile_util, + AutotunerCompileUtil::Create(config_, debug_options)); + TF_RET_CHECK(compile_util.has_value()); + + bool hlo_changed = false; + for (const HloComputation* computation : module->computations()) { + if (computation->IsFusionComputation()) { + TF_ASSIGN_OR_RETURN( + bool instruction_changed, + AutotuneCustomKernelFusion(computation->FusionInstruction(), config_, + compile_util.value(), debug_options)); + if (instruction_changed) { + hlo_changed = true; + } + } + } + + return hlo_changed; +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/custom_kernel_fusion_autotuner.h b/xla/service/gpu/custom_kernel_fusion_autotuner.h new file mode 100644 index 00000000000000..f6cd0c0fa5b6d1 --- /dev/null +++ b/xla/service/gpu/custom_kernel_fusion_autotuner.h @@ -0,0 +1,53 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_CUSTOM_KERNEL_FUSION_AUTOTUNER_H_ +#define XLA_SERVICE_GPU_CUSTOM_KERNEL_FUSION_AUTOTUNER_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/autotuning.pb.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/autotuner_util.h" +#include "xla/service/hlo_pass_interface.h" +#include "xla/xla.pb.h" + +namespace xla { +namespace gpu { + +// Find best custom kernel for custom kernel fusions. +class CustomKernelFusionAutotuner : public HloModulePass { + public: + explicit CustomKernelFusionAutotuner(const AutotuneConfig& config) + : config_(config) {} + + absl::string_view name() const override { + return "custom_kernel-fusion-autotuner"; + } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + const AutotuneConfig config_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_CUSTOM_KERNEL_FUSION_AUTOTUNER_H_ diff --git a/xla/service/gpu/custom_kernel_fusion_autotuner_test.cc b/xla/service/gpu/custom_kernel_fusion_autotuner_test.cc new file mode 100644 index 00000000000000..aa6c1d2ffa46c3 --- /dev/null +++ b/xla/service/gpu/custom_kernel_fusion_autotuner_test.cc @@ -0,0 +1,112 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/custom_kernel_fusion_autotuner.h" + +#include +#include +#include + +#include +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/autotuner_util.h" +#include "xla/service/hlo_pass_pipeline.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/xla.pb.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +class CustomKernelFusionAutotunerTest : public HloTestBase { + public: + CustomKernelFusionAutotunerTest() + : HloTestBase(/*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/true) {} + + void SetUp() override { HloTestBase::SetUp(); } + + void TearDown() override { HloTestBase::TearDown(); } +}; + +TEST_F(CustomKernelFusionAutotunerTest, + CustomKernelFusionAutotunerPassSucceeds) { + const std::string hlo_string = R"( + HloModule extracted + + cutlass_gemm { + p0 = f32[15,19]{1,0} parameter(0) + p1 = f32[19,17]{1,0} parameter(1) + ROOT r = f32[15, 17]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + + ENTRY region_198.14436 { + p.0 = f32[15,19]{1,0} parameter(0) + p.1 = f32[19,17]{1,0} parameter(1) + ROOT cutlass_gemm = f32[15,17]{1,0} fusion(p.0, p.1), kind=kCustom, calls=cutlass_gemm, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"cutlass_gemm","kernel_index":0}},"force_earliest_schedule":false} + } + )"; + std::unique_ptr hlo_module = + ParseAndReturnVerifiedModule(hlo_string).value(); + + HloPassPipeline pipeline("custom_kernel_fusion_autotuner"); + DebugOptions debug_options; + AutotuneConfig autotune_config = + AutotuneConfig{DeviceConfig{backend().default_stream_executor(), + backend().memory_allocator()}, + debug_options}; + pipeline.AddPass(autotune_config); + ASSERT_TRUE(pipeline.Run(hlo_module.get()).ok()); +} + +TEST_F(CustomKernelFusionAutotunerTest, + CustomKernelFusionAutotunerPassUpdatesUpdatesKernelIndex) { + const std::string hlo_string = R"( + HloModule extracted + + cutlass_gemm { + p0 = f32[15,19]{1,0} parameter(0) + p1 = f32[19,17]{1,0} parameter(1) + ROOT r = f32[15, 17]{1,0} dot(p0, p1), lhs_contracting_dims={1}, + rhs_contracting_dims={0} + } + + ENTRY region_198.14436 { + p.0 = f32[15,19]{1,0} parameter(0) + p.1 = f32[19,17]{1,0} parameter(1) + ROOT cutlass_gemm = f32[15,17]{1,0} fusion(p.0, p.1), kind=kCustom, + calls=cutlass_gemm, + backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"cutlass_gemm","kernel_index":-1}},"force_earliest_schedule":false} + } + )"; + + HloPassPipeline pipeline("custom_kernel_fusion_autotuner"); + DebugOptions debug_options; + AutotuneConfig autotune_config = + AutotuneConfig{DeviceConfig{backend().default_stream_executor(), + backend().memory_allocator()}, + debug_options}; + pipeline.AddPass(autotune_config); + + std::string expected = R"( + CHECK: "kernel_index":0 + )"; + RunAndFilecheckHloRewrite(hlo_string, std::move(pipeline), expected); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index 2304593b29af46..676ed086bd163f 100644 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -118,6 +118,7 @@ limitations under the License. #include "xla/service/gpu/command_buffer_scheduling.h" #include "xla/service/gpu/compile_module_to_llvm_ir.h" #include "xla/service/gpu/conv_layout_normalization.h" +#include "xla/service/gpu/custom_kernel_fusion_autotuner.h" #include "xla/service/gpu/custom_kernel_fusion_rewriter.h" #include "xla/service/gpu/dot_dimension_sorter.h" #include "xla/service/gpu/dot_operand_converter.h" @@ -1375,6 +1376,7 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( if (debug_options.xla_gpu_enable_custom_fusions()) { pipeline.AddPass( &gpu_target_config.device_description); + pipeline.AddPass(autotune_config); } // Rewrite GEMMs into custom calls. From 460d22ef872b6e8d5dede04b7afb8590c98a7fa6 Mon Sep 17 00:00:00 2001 From: "Hoeseong (Hayden) Kim" Date: Fri, 26 Jul 2024 06:46:46 -0700 Subject: [PATCH 195/376] Add proto pyclif build rules PiperOrigin-RevId: 656383355 --- xla/BUILD | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/xla/BUILD b/xla/BUILD index 3423d021354682..59fad44977cffc 100644 --- a/xla/BUILD +++ b/xla/BUILD @@ -2,6 +2,7 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library") load( "@tsl//tsl/platform:build_config.bzl", "tf_proto_library", + # copybara:uncomment "tf_pyclif_proto_library", ) load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") load( @@ -1225,6 +1226,14 @@ xla_py_proto_library( ], ) +# copybara:uncomment_begin(google-only) +# tf_pyclif_proto_library( +# name = "autotune_results_pyclif", +# proto_lib = ":autotune_results_proto", +# visibility = ["//visibility:public"], +# ) +# copybara:uncomment_end + tf_proto_library( name = "autotuning_proto", srcs = ["autotuning.proto"], From ed996ff5e1909bcc610242bc53177ac2a9c67ed9 Mon Sep 17 00:00:00 2001 From: Greg Olechwierowicz Date: Fri, 26 Jul 2024 07:36:25 -0700 Subject: [PATCH 196/376] [XLA:GPU] Annotate instructions with their scheduling names. PiperOrigin-RevId: 656396145 --- xla/hlo/ir/hlo_instruction.cc | 3 +- xla/service/gpu/BUILD | 32 +++++- xla/service/gpu/gpu_hlo_schedule.cc | 4 +- xla/service/gpu/pipelined_p2p_rewriter.cc | 2 + .../gpu/pipelined_p2p_rewriter_test.cc | 52 +++++----- .../gpu/scheduling_instruction_annotator.cc | 67 +++++++++++++ .../gpu/scheduling_instruction_annotator.h | 44 +++++++++ .../scheduling_instruction_annotator_test.cc | 97 +++++++++++++++++++ 8 files changed, 271 insertions(+), 30 deletions(-) create mode 100644 xla/service/gpu/scheduling_instruction_annotator.cc create mode 100644 xla/service/gpu/scheduling_instruction_annotator.h create mode 100644 xla/service/gpu/scheduling_instruction_annotator_test.cc diff --git a/xla/hlo/ir/hlo_instruction.cc b/xla/hlo/ir/hlo_instruction.cc index 7ea85bf6c836f2..6ccbfcdb63a1a5 100644 --- a/xla/hlo/ir/hlo_instruction.cc +++ b/xla/hlo/ir/hlo_instruction.cc @@ -3605,7 +3605,8 @@ void HloInstruction::PrintWithCanonicalNameMap( if (options.print_metadata() && (!metadata_->op_type().empty() || !metadata_->op_name().empty() || - !metadata_->source_file().empty())) { + !metadata_->source_file().empty() || + !metadata_->scheduling_name().empty())) { printer->Append(", metadata={"); printer->Append(xla::OpMetadataToString( *metadata_, options.print_metadata_only_op_name())); diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 1ff6ab879769b6..5eac3b82ebf1d3 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -4038,6 +4038,7 @@ cc_library( ":backend_configs_cc", ":gpu_latency_hiding_scheduler", ":gpu_schedule_postprocessing", + ":scheduling_instruction_annotator", "//xla:shape_util", "//xla:util", "//xla/hlo/ir:hlo", @@ -4059,7 +4060,6 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:path", @@ -6047,6 +6047,7 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:errors", @@ -6140,3 +6141,32 @@ xla_cc_test( "@tsl//tsl/platform:statusor", ], ) + +cc_library( + name = "scheduling_instruction_annotator", + srcs = ["scheduling_instruction_annotator.cc"], + hdrs = ["scheduling_instruction_annotator.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "scheduling_instruction_annotator_test", + srcs = ["scheduling_instruction_annotator_test.cc"], + deps = [ + ":scheduling_instruction_annotator", + "//xla/hlo/ir:hlo", + "//xla/tests:filecheck", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:statusor", + ], +) diff --git a/xla/service/gpu/gpu_hlo_schedule.cc b/xla/service/gpu/gpu_hlo_schedule.cc index 4f0d3fce842223..2504b431741f82 100644 --- a/xla/service/gpu/gpu_hlo_schedule.cc +++ b/xla/service/gpu/gpu_hlo_schedule.cc @@ -33,8 +33,6 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/match.h" #include "absl/strings/numbers.h" -#include "absl/strings/str_format.h" -#include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" @@ -50,6 +48,7 @@ limitations under the License. #include "xla/service/gpu/gpu_latency_hiding_scheduler.h" #include "xla/service/gpu/gpu_schedule_postprocessing.h" #include "xla/service/gpu/model/analytical_latency_estimator.h" +#include "xla/service/gpu/scheduling_instruction_annotator.h" #include "xla/service/hlo_memory_scheduler.h" #include "xla/service/hlo_pass_pipeline.h" #include "xla/service/latency_hiding_scheduler.h" @@ -507,6 +506,7 @@ absl::StatusOr ScheduleGpuModule( auto scheduler_core = std::make_unique( shape_size_in_bytes, async_tracker.get(), latency_estimator.get(), config); + pipeline.AddPass(); pipeline.AddPass( std::move(latency_estimator), std::move(async_tracker), std::move(scheduler_core), shape_size_in_bytes); diff --git a/xla/service/gpu/pipelined_p2p_rewriter.cc b/xla/service/gpu/pipelined_p2p_rewriter.cc index b8a760deec28ba..d0e841c4f9ebc1 100644 --- a/xla/service/gpu/pipelined_p2p_rewriter.cc +++ b/xla/service/gpu/pipelined_p2p_rewriter.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/dfs_hlo_visitor.h" @@ -414,6 +415,7 @@ bool InsertBeforeFirstCollectiveOp( } void CopyInstructionInfo(const HloInstruction* old_op, HloInstruction* new_op) { + new_op->SetAndSanitizeName(absl::StrCat(old_op->name(), ".clone")); new_op->set_metadata(old_op->metadata()); new_op->add_frontend_attributes(old_op->frontend_attributes()); new_op->CopyBackendConfigFrom(old_op); diff --git a/xla/service/gpu/pipelined_p2p_rewriter_test.cc b/xla/service/gpu/pipelined_p2p_rewriter_test.cc index e7b263eee7c867..a0d58306cfa93b 100644 --- a/xla/service/gpu/pipelined_p2p_rewriter_test.cc +++ b/xla/service/gpu/pipelined_p2p_rewriter_test.cc @@ -211,8 +211,8 @@ TEST_F(PipelinedP2pRewriterTest, SendRecvPipelined1) { CHECK: %get-tuple-element = get-tuple-element(%param.1), index=1 CHECK: %get-tuple-element.1 = get-tuple-element(%param.1), index=2 CHECK: %count.1 = get-tuple-element(%param.1), index=0 - CHECK: %recv-done = recv-done(%get-tuple-element), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} - CHECK: %recv-data = get-tuple-element(%recv-done), index=0 + CHECK: %recv-done.p.clone = recv-done(%get-tuple-element), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} + CHECK: %recv-data = get-tuple-element(%recv-done.p.clone), index=0 CHECK: %c1 = constant(1) CHECK: %new-count = add(%count.1, %c1) CHECK: %replica = replica-id() @@ -227,7 +227,7 @@ TEST_F(PipelinedP2pRewriterTest, SendRecvPipelined1) { CHECK: %s = dot(%c, %d), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1} CHECK: %send-data = add(%c, %s) CHECK: %after-all = after-all() - CHECK: %send-done = send-done(%get-tuple-element.1), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} + CHECK: %send-done.p.clone = send-done(%get-tuple-element.1), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} CHECK{LITERAL}: %recv = recv(%after-all), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}"} CHECK{LITERAL}: %send = send(%send-data, %after-all), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}"} CHECK: ROOT %tuple = tuple(%new-count, %recv, %send) @@ -248,13 +248,13 @@ TEST_F(PipelinedP2pRewriterTest, SendRecvPipelined1) { CHECK{LITERAL}: %recv.1 = recv(%after-all.1), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}"} CHECK{LITERAL}: %send.1 = send(%init, %after-all.1), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}"} CHECK: %while-init = tuple(%c0, %recv.1, %send.1) - CHECK: %while-result = while(%while-init), condition=%while-cond, body=%while-body, + CHECK: %while-result.p.clone = while(%while-init), condition=%while-cond, body=%while-body, CHECK-SAME{LITERAL}: backend_config={"known_trip_count":{"n":"25"}} - CHECK: %get-tuple-element.2 = get-tuple-element(%while-result), index=1 - CHECK: %get-tuple-element.3 = get-tuple-element(%while-result), index=2 - CHECK: %recv-done.1 = recv-done(%get-tuple-element.2), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} - CHECK: %send-done.1 = send-done(%get-tuple-element.3), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} - CHECK: ROOT %entry-result = get-tuple-element(%recv-done.1), index=0 + CHECK: %get-tuple-element.2 = get-tuple-element(%while-result.p.clone), index=1 + CHECK: %get-tuple-element.3 = get-tuple-element(%while-result.p.clone), index=2 + CHECK: %recv-done.1.p.clone = recv-done(%get-tuple-element.2), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} + CHECK: %send-done.1.p.clone = send-done(%get-tuple-element.3), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} + CHECK: ROOT %entry-result = get-tuple-element(%recv-done.1.p.clone), index=0 CHECK: })"; TF_ASSERT_OK_AND_ASSIGN(auto module, @@ -592,10 +592,10 @@ TEST_F(PipelinedP2pRewriterTest, SendRecvPipelined2) { CHECK: %get-tuple-element.2 = get-tuple-element(%param.1), index=3 CHECK: %get-tuple-element.3 = get-tuple-element(%param.1), index=4 CHECK: %count.1 = get-tuple-element(%param.1), index=0 - CHECK: %recv-done = recv-done(%get-tuple-element), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} - CHECK: %recv-data.0 = get-tuple-element(%recv-done), index=0 - CHECK: %recv-done.1 = recv-done(%get-tuple-element.2), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"} - CHECK: %recv-data.1 = get-tuple-element(%recv-done.1), index=0 + CHECK: %recv-done.p.clone = recv-done(%get-tuple-element), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} + CHECK: %recv-data.0 = get-tuple-element(%recv-done.p.clone), index=0 + CHECK: %recv-done.1.p.clone = recv-done(%get-tuple-element.2), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"} + CHECK: %recv-data.1 = get-tuple-element(%recv-done.1.p.clone), index=0 CHECK: %replica = replica-id() CHECK: %constant0 = constant(0) CHECK: %compare0 = compare(%replica, %constant0), direction=EQ @@ -614,8 +614,8 @@ TEST_F(PipelinedP2pRewriterTest, SendRecvPipelined2) { CHECK: %s = dot(%c, %d), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1} CHECK: %send-data = add(%c, %s) CHECK: %after-all = after-all() - CHECK: %send-done = send-done(%get-tuple-element.1), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} - CHECK: %send-done.1 = send-done(%get-tuple-element.3), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"} + CHECK: %send-done.p.clone = send-done(%get-tuple-element.1), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} + CHECK: %send-done.1.p.clone = send-done(%get-tuple-element.3), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"} CHECK{LITERAL}: %recv = recv(%after-all), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{3,0}}"} CHECK{LITERAL}: %send = send(%send-data, %after-all), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{3,0}}"} CHECK: %after-all.1 = after-all() @@ -642,21 +642,21 @@ TEST_F(PipelinedP2pRewriterTest, SendRecvPipelined2) { CHECK{LITERAL}: %recv.3 = recv(%after-all.3), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}"} CHECK{LITERAL}: %send.3 = send(%init, %after-all.3), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}"} CHECK: %while-init = tuple(%c0, %recv.2, %send.2, %recv.3, %send.3) - CHECK{LITERAL}: %while-result = while(%while-init), condition=%while-cond, body=%while-body, backend_config={"known_trip_count":{"n":"25"}} - CHECK: %get-tuple-element.4 = get-tuple-element(%while-result), index=1 - CHECK: %get-tuple-element.5 = get-tuple-element(%while-result), index=2 - CHECK: %get-tuple-element.6 = get-tuple-element(%while-result), index=3 - CHECK: %get-tuple-element.7 = get-tuple-element(%while-result), index=4 - CHECK: %recv-done.2 = recv-done(%get-tuple-element.4), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} - CHECK: %recv-data.3 = get-tuple-element(%recv-done.2), index=0 - CHECK: %recv-done.3 = recv-done(%get-tuple-element.6), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"} - CHECK: %recv-data.4 = get-tuple-element(%recv-done.3), index=0 + CHECK{LITERAL}: %while-result.p.clone = while(%while-init), condition=%while-cond, body=%while-body, backend_config={"known_trip_count":{"n":"25"}} + CHECK: %get-tuple-element.4 = get-tuple-element(%while-result.p.clone), index=1 + CHECK: %get-tuple-element.5 = get-tuple-element(%while-result.p.clone), index=2 + CHECK: %get-tuple-element.6 = get-tuple-element(%while-result.p.clone), index=3 + CHECK: %get-tuple-element.7 = get-tuple-element(%while-result.p.clone), index=4 + CHECK: %recv-done.2.p.clone = recv-done(%get-tuple-element.4), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} + CHECK: %recv-data.3 = get-tuple-element(%recv-done.2.p.clone), index=0 + CHECK: %recv-done.3.p.clone = recv-done(%get-tuple-element.6), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"} + CHECK: %recv-data.4 = get-tuple-element(%recv-done.3.p.clone), index=0 CHECK: %replica.1 = replica-id() CHECK: %constant0.1 = constant(0) CHECK: %compare0.1 = compare(%replica.1, %constant0.1), direction=EQ CHECK: %compare.1 = broadcast(%compare0.1), dimensions={} - CHECK: %send-done.2 = send-done(%get-tuple-element.5), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} - CHECK: %send-done.3 = send-done(%get-tuple-element.7), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"} + CHECK: %send-done.2.p.clone = send-done(%get-tuple-element.5), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} + CHECK: %send-done.3.p.clone = send-done(%get-tuple-element.7), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"} CHECK: ROOT %entry-result = select(%compare.1, %recv-data.3, %recv-data.4) CHECK: })"; diff --git a/xla/service/gpu/scheduling_instruction_annotator.cc b/xla/service/gpu/scheduling_instruction_annotator.cc new file mode 100644 index 00000000000000..fbf1b2c5c58eb2 --- /dev/null +++ b/xla/service/gpu/scheduling_instruction_annotator.cc @@ -0,0 +1,67 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/scheduling_instruction_annotator.h" + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "tsl/platform/statusor.h" + +namespace xla::gpu { +namespace { + +// Populates `OpMetadata`'s `scheduling_name` field for all of the instructions +// belonging to `computation`. +absl::StatusOr AnnotateSchedulingInstructionNames( + HloComputation& computation) { + bool changed = false; + for (HloInstruction* inst : computation.instructions()) { + if (!inst->metadata().scheduling_name().empty()) { + continue; + } + inst->set_metadata_scheduling_name(std::string(inst->name())); + changed = true; + } + return changed; +} + +} // namespace + +absl::StatusOr SchedulingInstructionAnnotator::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + CHECK(module->has_schedule()) + << "The pass is supposed to run in the beginning of post-scheduling!"; + bool changed = false; + + // We visit computations in the order of callees to callers, as information is + // propagated from calles to callers. + for (HloComputation* computation : + module->MakeComputationPostOrder(execution_threads)) { + TF_ASSIGN_OR_RETURN(bool result, + AnnotateSchedulingInstructionNames(*computation)); + changed |= result; + } + + return changed; +} + +} // namespace xla::gpu diff --git a/xla/service/gpu/scheduling_instruction_annotator.h b/xla/service/gpu/scheduling_instruction_annotator.h new file mode 100644 index 00000000000000..3f9b769d3b85f0 --- /dev/null +++ b/xla/service/gpu/scheduling_instruction_annotator.h @@ -0,0 +1,44 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_SCHEDULING_INSTRUCTION_ANNOTATOR_H_ +#define XLA_SERVICE_GPU_SCHEDULING_INSTRUCTION_ANNOTATOR_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_pass_interface.h" + +namespace xla::gpu { + +// The pass amends the `OpMetadata` with instruction name present at the +// scheduling time. This is later being used to make sure instructions are not +// renamed post scheduling. Enforcing this is necessary because otherwise +class SchedulingInstructionAnnotator : public HloModulePass { + public: + absl::string_view name() const override { + return "scheduling-instruction-annotator"; + } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_SCHEDULING_INSTRUCTION_ANNOTATOR_H_ diff --git a/xla/service/gpu/scheduling_instruction_annotator_test.cc b/xla/service/gpu/scheduling_instruction_annotator_test.cc new file mode 100644 index 00000000000000..146607f790da52 --- /dev/null +++ b/xla/service/gpu/scheduling_instruction_annotator_test.cc @@ -0,0 +1,97 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/scheduling_instruction_annotator.h" + +#include + +#include +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/tests/filecheck.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla::gpu { +namespace { + +using SchedulingInstructionAnnotatorTest = HloTestBase; + +TEST_F(SchedulingInstructionAnnotatorTest, + AnnotatesAllInstructionsWithTheirRespectiveNames) { + constexpr absl::string_view kHloString = R"( + HloModule module, is_scheduled=true + + ENTRY entry { + p0 = f32[1] parameter(0) + p1 = f32[1] parameter(1) + add0 = f32[1] add(p0,p1) + ROOT exp0 = f32[1] exponential(add0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + + SchedulingInstructionAnnotator pass; + TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get())); + + ASSERT_TRUE(changed); + for (const auto* comp : module->computations()) { + for (const auto* instruction : comp->instructions()) { + EXPECT_EQ(instruction->name(), instruction->metadata().scheduling_name()); + } + } + constexpr absl::string_view kExpected = R"( +// CHECK: %[[P0:.+]] = {{.*}} parameter(0) +// CHECK-SAME: scheduling_name="[[P0]]" +// CHECK: %[[P1:.+]] = {{.*}} parameter(1) +// CHECK-SAME: scheduling_name="[[P1]]" +// CHECK: %[[ADD0:.+]] = {{.*}} add(%[[P0]], %[[P1]]) +// CHECK-SAME: scheduling_name="[[ADD0]]" +// CHECK: ROOT %[[EXP0:.+]] = {{.*}} exponential(%[[ADD0]]) +// CHECK-SAME: scheduling_name="[[EXP0]]" + )"; + TF_ASSERT_OK_AND_ASSIGN( + bool filecheck_matches, + RunFileCheck( + module->ToString(HloPrintOptions().set_print_operand_shape(false)), + kExpected)); + EXPECT_TRUE(filecheck_matches); +} + +TEST_F(SchedulingInstructionAnnotatorTest, + DoesNotAnnotateAllInstructionsWithTheirRespectiveNames) { + constexpr absl::string_view kHloString = R"( + HloModule module, is_scheduled=true + + ENTRY entry { + p0 = f32[1] parameter(0), metadata={scheduling_name="p0"} + p1 = f32[1] parameter(1), metadata={scheduling_name="p1"} + add0 = f32[1] add(p0,p1), metadata={scheduling_name="add0"} + ROOT exp0 = f32[1] exponential(add0), metadata={scheduling_name="exp0"} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + + SchedulingInstructionAnnotator pass; + TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get())); + + EXPECT_FALSE(changed); +} + +} // namespace +} // namespace xla::gpu From eef7ac771b7b3a654c758c140fe7cf249d3140f8 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Fri, 26 Jul 2024 07:49:51 -0700 Subject: [PATCH 197/376] [XLA:FFI] Add API version to XLA_FFI_Api struct, and check at runtime. This exposes the FFI API version as part of the API struct and adds a check at runtime to make sure that the major versions match. Eventually we will want to include logic to check the minor version like in Pjrt. It would probably be a somewhat better user experience to error at registration time rather than call time, but we don't currently have enough metadata at registration time. It would be probably be possible to update the handler bundle provided when registering to make this work, but this was a much less invasive change, and since we already check the struct size at runtime here it seemed reasonable to include the version check in the same place. I will update with benchmark info once I get a chance to run it. PiperOrigin-RevId: 656399434 --- xla/ffi/api/api.h | 13 +++++++++++++ xla/ffi/api/c_api.h | 1 + xla/ffi/ffi_api.cc | 21 +++++++++++++++++---- xla/ffi/ffi_api.h | 4 ++++ xla/ffi/ffi_test.cc | 13 +++++++++++++ 5 files changed, 48 insertions(+), 4 deletions(-) diff --git a/xla/ffi/api/api.h b/xla/ffi/api/api.h index ce4b8c5dec229f..7675c3ab58f8a8 100644 --- a/xla/ffi/api/api.h +++ b/xla/ffi/api/api.h @@ -1297,6 +1297,19 @@ class Handler : public Ffi { call_frame->struct_size)) return err; + // Check the API versions. + auto api_version = call_frame->api->api_version; + if (api_version.major_version != XLA_FFI_API_MAJOR || + api_version.minor_version != XLA_FFI_API_MINOR) { + return InvalidArgument( + call_frame->api, + StrCat("FFI handler's API version (", XLA_FFI_API_MAJOR, ".", + XLA_FFI_API_MINOR, + ") does not match the framework's API version (", + api_version.major_version, ".", api_version.minor_version, + ")")); + } + // Check that handler is called during correct execution stage. if (XLA_FFI_PREDICT_FALSE(call_frame->stage != static_cast(stage))) { diff --git a/xla/ffi/api/c_api.h b/xla/ffi/api/c_api.h index bf617b087c6c77..3c9efb42eab9e9 100644 --- a/xla/ffi/api/c_api.h +++ b/xla/ffi/api/c_api.h @@ -529,6 +529,7 @@ struct XLA_FFI_Api { size_t struct_size; void* priv; + XLA_FFI_Api_Version api_version; XLA_FFI_InternalApi* internal_api; _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Error_Create); diff --git a/xla/ffi/ffi_api.cc b/xla/ffi/ffi_api.cc index c750cf72a9752e..f402ed24b32dcd 100644 --- a/xla/ffi/ffi_api.cc +++ b/xla/ffi/ffi_api.cc @@ -97,11 +97,12 @@ absl::Status TakeStatus(XLA_FFI_Error* error) { return status; } -absl::Status Call(Ffi& handler, CallFrame& call_frame, - const CallOptions& options, ExecutionStage stage) { +absl::Status CallWithApi(const XLA_FFI_Api* api, Ffi& handler, + CallFrame& call_frame, const CallOptions& options, + ExecutionStage stage) { XLA_FFI_ExecutionContext ctx = CreateExecutionContext(options); - XLA_FFI_CallFrame ffi_call_frame = call_frame.Build( - GetXlaFfiApi(), &ctx, static_cast(stage)); + XLA_FFI_CallFrame ffi_call_frame = + call_frame.Build(api, &ctx, static_cast(stage)); XLA_FFI_Error* status = nullptr; try { status = handler.Call(&ffi_call_frame); @@ -111,6 +112,11 @@ absl::Status Call(Ffi& handler, CallFrame& call_frame, return TakeStatus(status); } +absl::Status Call(Ffi& handler, CallFrame& call_frame, + const CallOptions& options, ExecutionStage stage) { + return CallWithApi(GetXlaFfiApi(), handler, call_frame, options, stage); +} + absl::Status Call(XLA_FFI_Handler* handler, CallFrame& call_frame, const CallOptions& options, XLA_FFI_ExecutionStage stage) { XLA_FFI_ExecutionContext ctx = CreateExecutionContext(options); @@ -551,6 +557,13 @@ static XLA_FFI_Api api = { XLA_FFI_Api_STRUCT_SIZE, /*priv=*/nullptr, + XLA_FFI_Api_Version{ + XLA_FFI_Api_Version_STRUCT_SIZE, + /*priv=*/nullptr, + XLA_FFI_API_MAJOR, + XLA_FFI_API_MINOR, + }, + &internal_api, XLA_FFI_Error_Create, diff --git a/xla/ffi/ffi_api.h b/xla/ffi/ffi_api.h index 8b2242d4297238..7a6e5aa3df9506 100644 --- a/xla/ffi/ffi_api.h +++ b/xla/ffi/ffi_api.h @@ -62,6 +62,10 @@ struct CallOptions { // `error` if it's not nullptr; returns OK status otherwise. absl::Status TakeStatus(XLA_FFI_Error* error); +absl::Status CallWithApi(const XLA_FFI_Api* api, Ffi& handler, + CallFrame& call_frame, const CallOptions& options = {}, + ExecutionStage stage = ExecutionStage::kExecute); + absl::Status Call(Ffi& handler, CallFrame& call_frame, const CallOptions& options = {}, ExecutionStage stage = ExecutionStage::kExecute); diff --git a/xla/ffi/ffi_test.cc b/xla/ffi/ffi_test.cc index 044f526f64392d..63f5dbf30e20d2 100644 --- a/xla/ffi/ffi_test.cc +++ b/xla/ffi/ffi_test.cc @@ -843,6 +843,19 @@ TEST(FfiTest, AllowRegisterDuplicateWhenEqual) { TF_ASSERT_OK(status); } +TEST(FfiTest, ApiVersion) { + auto handler = Ffi::Bind().To([]() { return absl::OkStatus(); }); + CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0); + auto call_frame = builder.Build(); + auto api = GetXlaFfiApi(); + XLA_FFI_Api api_copy = *api; + api_copy.api_version.major_version += 1; + auto status = CallWithApi(&api_copy, *handler, call_frame); + EXPECT_TRUE(absl::StrContains(status.message(), "FFI handler's API version")) + << "status.message():\n" + << status.message() << "\n"; +} + //===----------------------------------------------------------------------===// // Performance benchmarks are below. //===----------------------------------------------------------------------===// From 127bac3ea7aa36069b12d689e0f8f9382a20b212 Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Fri, 26 Jul 2024 08:30:23 -0700 Subject: [PATCH 198/376] [XLA:GPU] Add HLO-based pipeline parallelism test for #microbatches > #devices PiperOrigin-RevId: 656411214 --- .../collective_pipeline_parallelism_test.cc | 228 +++++++++++++++--- 1 file changed, 192 insertions(+), 36 deletions(-) diff --git a/xla/tests/collective_pipeline_parallelism_test.cc b/xla/tests/collective_pipeline_parallelism_test.cc index abf0e4739fa3ca..627c652b78cf49 100644 --- a/xla/tests/collective_pipeline_parallelism_test.cc +++ b/xla/tests/collective_pipeline_parallelism_test.cc @@ -114,7 +114,7 @@ XLA_TEST_F(CollectivePipelineParallelismTest, inputs_a.push_back(LiteralUtil::CreateR2({{val, val}, {val, val}})); } Literal input_b_replicated = LiteralUtil::CreateR2({{0, 0}, {0, 1}}); - std::vector> inputs; + std::vector> inputs; for (int64_t i = 0; i < kNumReplicas; ++i) { inputs.push_back({&inputs_a[i], &input_b_replicated}); } @@ -128,6 +128,31 @@ XLA_TEST_F(CollectivePipelineParallelismTest, LiteralTestUtil::ExpectR2Equal({{0, 0}, {1, 1}}, results[3]); } +// Helper functions for pipeline parallelism tests where each stage scales the +// input by some factor. +absl::StatusOr CreateLinearLayerWeights(int64_t size, float factor) { + return LiteralUtil::CreateLiteralWithGenerator( + ShapeUtil::MakeShape(F32, {size, size}), + [&](absl::Span idx) -> float { + return idx[0] == idx[1] ? factor : 0.0; + }); +}; +absl::StatusOr CreateZeroInputR2(int64_t microbatches, int64_t size) { + return LiteralUtil::CreateLiteralWithGenerator( + ShapeUtil::MakeShape(F32, {microbatches, size}), + [&](absl::Span idx) -> float { return 0.0; }); +}; +absl::StatusOr CreateFingerprintInput(int64_t microbatches, + int64_t size, + float factor = 1.0) { + return LiteralUtil::CreateLiteralWithGenerator( + ShapeUtil::MakeShape(F32, {microbatches, size}), + [&](absl::Span idx) -> float { + float fingerprint = 1.0 * idx[0] + 0.0001 * idx[1]; + return factor * fingerprint; + }); +}; + // Naive implementation of pipeline parallelism: // - 4 devices // - 4 microbatches @@ -236,39 +261,28 @@ XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch4Replica4) { // We assign the weights to the replicas such that the layers scale the input // data by 1.0, 2.0, 3.0 and 4.0. The combined effect is to scale the input // data by 24.0. - auto generate_scale_weights = [&](float factor) -> absl::StatusOr { - return LiteralUtil::CreateLiteralWithGenerator( - ShapeUtil::MakeShape(F32, {16, 16}), - [&](absl::Span idx) -> float { - return idx[0] == idx[1] ? factor : 0.0; - }); - }; - TF_ASSERT_OK_AND_ASSIGN(Literal weights_r0, generate_scale_weights(1.0)); - TF_ASSERT_OK_AND_ASSIGN(Literal weights_r1, generate_scale_weights(2.0)); - TF_ASSERT_OK_AND_ASSIGN(Literal weights_r2, generate_scale_weights(3.0)); - TF_ASSERT_OK_AND_ASSIGN(Literal weights_r3, generate_scale_weights(4.0)); + const int64_t kInputSize = 16; + TF_ASSERT_OK_AND_ASSIGN(Literal weights_r0, + CreateLinearLayerWeights(kInputSize, 1.0)); + TF_ASSERT_OK_AND_ASSIGN(Literal weights_r1, + CreateLinearLayerWeights(kInputSize, 2.0)); + TF_ASSERT_OK_AND_ASSIGN(Literal weights_r2, + CreateLinearLayerWeights(kInputSize, 3.0)); + TF_ASSERT_OK_AND_ASSIGN(Literal weights_r3, + CreateLinearLayerWeights(kInputSize, 4.0)); // Only the first replica holds the input to the pipeline in this naive // implementation. The remaining replicas get zero/dummy input. - auto generate_zero_input = [&]() -> absl::StatusOr { - return LiteralUtil::CreateLiteralWithGenerator( - ShapeUtil::MakeShape(F32, {4, 16}), - [&](absl::Span idx) -> float { return 0.0; }); - }; - auto generate_fingerprint_input = [&]() -> absl::StatusOr { - return LiteralUtil::CreateLiteralWithGenerator( - ShapeUtil::MakeShape(F32, {4, 16}), - [&](absl::Span idx) -> float { - return 1.0 * idx[0] + 0.0001 * idx[1]; - }); - }; - TF_ASSERT_OK_AND_ASSIGN(Literal real_input, generate_fingerprint_input()); - TF_ASSERT_OK_AND_ASSIGN(Literal fake_input, generate_zero_input()); - - std::vector> args = {{&weights_r0, &real_input}, - {&weights_r1, &fake_input}, - {&weights_r2, &fake_input}, - {&weights_r3, &fake_input}}; + const int64_t kMicrobatches = 4; + TF_ASSERT_OK_AND_ASSIGN(Literal real_input, + CreateFingerprintInput(kMicrobatches, kInputSize)); + TF_ASSERT_OK_AND_ASSIGN(Literal fake_input, + CreateZeroInputR2(kMicrobatches, kInputSize)); + + std::vector> args = {{&weights_r0, &real_input}, + {&weights_r1, &fake_input}, + {&weights_r2, &fake_input}, + {&weights_r3, &fake_input}}; TF_ASSERT_OK_AND_ASSIGN( std::vector results, ExecuteReplicated(std::move(module), args, kNumReplicas, @@ -276,13 +290,155 @@ XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch4Replica4) { // Check pipeline output for last replica. // The combined effect of the pipeline is to scale the input data by 24.0. + const float kExpectedFactor = 1.0 * 2.0 * 3.0 * 4.0; TF_ASSERT_OK_AND_ASSIGN( Literal expected_output, - (LiteralUtil::CreateLiteralWithGenerator( - ShapeUtil::MakeShape(F32, {4, 16}), - [&](absl::Span multi_index) -> float { - return real_input.Get(multi_index) * 1.0 * 2.0 * 3.0 * 4.0; - }))); + CreateFingerprintInput(kMicrobatches, kInputSize, kExpectedFactor)); + EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected_output, results[3], + ErrorSpec{1e-5, 1e-5})); +} + +// Naive implementation of pipeline parallelism: +// - 4 devices +// - 5 microbatches +// - no circular repeat +// - no disabled collectives +// - no collective pipelining +// +// Every stage of the pipeline is a single linear layer. +XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch5Replica4) { + const absl::string_view kModuleStr = R"( + HloModule test + + get_circ_buffer_index { + offset = u32[] parameter(0) + index = u32[] parameter(1) + size = u32[] parameter(2) + t0 = u32[] add(offset, index) + t1 = u32[] divide(t0, size) + t2 = u32[] multiply(t1, size) + ROOT t4 = u32[] subtract(t0, t2) + } + + is_input_replica { + replica_id = u32[] replica-id() + c0 = u32[] constant(0) + ROOT predicate = pred[] compare(replica_id, c0), direction=EQ + } + + is_output_replica { + replica_id = u32[] replica-id() + c1 = u32[] constant(1) + ROOT predicate = pred[] compare(replica_id, c1), direction=EQ + } + + while_condition { + tuple = (f32[16,16], f32[5,16], f32[5,16], f32[16], u32[]) parameter(0) + i = u32[] get-tuple-element(tuple), index=4 + n = u32[] constant(8) + ROOT predicate = pred[] compare(i, n), direction=LT + } + + while_body { + tuple = (f32[16,16], f32[5,16], f32[5,16], f32[16], u32[]) parameter(0) + weights = f32[16,16] get-tuple-element(tuple), index=0 + input = f32[5,16] get-tuple-element(tuple), index=1 + output = f32[5,16] get-tuple-element(tuple), index=2 + tmp = f32[16] get-tuple-element(tuple), index=3 + i = u32[] get-tuple-element(tuple), index=4 + + c1 = u32[] constant(1) + c2 = u32[] constant(2) + c0 = u32[] constant(0) + c5 = u32[] constant(5) + + input_idx = u32[] call(c0, i, c5), to_apply=get_circ_buffer_index + input_slice = f32[1,16] dynamic-slice(input, input_idx, c0), + dynamic_slice_sizes={1,16} + input_slice_ = f32[16] reshape(input_slice) + + prev_stage_slice = f32[16] collective-permute(tmp), + source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}} + + read_input = pred[] call(), to_apply=is_input_replica + compute_in = f32[16] select(read_input, input_slice_, prev_stage_slice) + + compute_out = f32[16] dot(weights, compute_in), lhs_contracting_dims={1}, + rhs_contracting_dims={0} + + output_index = u32[] call(c2, i, c5), to_apply=get_circ_buffer_index + output_slice = f32[1,16] reshape(compute_out) + output_ = f32[5,16] dynamic-update-slice(output, output_slice, output_index, + c0) + + i_ = add(i, c1) + + ROOT tuple1 = (f32[16,16], f32[5,16], f32[5,16], f32[16], u32[]) + tuple(weights, input, output_, compute_out, i_) + } + + ENTRY main { + weights = f32[16,16] parameter(0) + input = f32[5,16] parameter(1) + + cf0 = f32[] constant(0) + output = f32[5,16] broadcast(cf0), dimensions={} + tmp = f32[16] broadcast(cf0), dimensions={} + c0 = u32[] constant(0) + + tuple = (f32[16,16], f32[5,16], f32[5,16], f32[16], u32[]) + tuple(weights, input, output, tmp, c0) + tuple_ = (f32[16,16], f32[5,16], f32[5,16], f32[16], u32[]) while(tuple), + condition=while_condition, body=while_body + + ROOT output_ = f32[5,16] get-tuple-element(tuple_), index=2 + } + )"; + + const int64_t kNumReplicas = 4; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + + // This pipeline consists of 4 layers, each of which is a single linear layer. + // We assign the weights to the replicas such that the layers scale the input + // data by 1.0, 2.0, 3.0 and 4.0. The combined effect is to scale the input + // data by 24.0. + const int64_t kInputSize = 16; + TF_ASSERT_OK_AND_ASSIGN(Literal weights_r0, + CreateLinearLayerWeights(kInputSize, 1.0)); + TF_ASSERT_OK_AND_ASSIGN(Literal weights_r1, + CreateLinearLayerWeights(kInputSize, 2.0)); + TF_ASSERT_OK_AND_ASSIGN(Literal weights_r2, + CreateLinearLayerWeights(kInputSize, 3.0)); + TF_ASSERT_OK_AND_ASSIGN(Literal weights_r3, + CreateLinearLayerWeights(kInputSize, 4.0)); + + // Only the first replica holds the input to the pipeline in this naive + // implementation. The remaining replicas get zero/dummy input. + const int64_t kMicrobatches = 5; + TF_ASSERT_OK_AND_ASSIGN(Literal real_input, + CreateFingerprintInput(kMicrobatches, kInputSize)); + TF_ASSERT_OK_AND_ASSIGN(Literal fake_input, + CreateZeroInputR2(kMicrobatches, kInputSize)); + + // Check pipeline output for last replica. + // The combined effect of the pipeline is to scale the input data by 24.0. + const float kExpectedFactor = 1.0 * 2.0 * 3.0 * 4.0; + TF_ASSERT_OK_AND_ASSIGN( + Literal expected_output, + CreateFingerprintInput(kMicrobatches, kInputSize, kExpectedFactor)); + std::vector> args = {{&weights_r0, &real_input}, + {&weights_r1, &fake_input}, + {&weights_r2, &fake_input}, + {&weights_r3, &fake_input}}; + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), args, kNumReplicas, + /*run_hlo_passes=*/true)); EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected_output, results[3], ErrorSpec{1e-5, 1e-5})); } From 6f81609a96fce4f14d66f0cf96f0eed3e6107935 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 26 Jul 2024 09:33:26 -0700 Subject: [PATCH 199/376] [XLA][HostOffloader] Remove redundant copies to and from host for host offloaded computation outputs The simple algorithm tracks usages of all outputs of each host offloaded computation. For each: - If they are ONLY used on the host and they are outputs of the entry computation, it sets the memory space to Host. - If they are ONLY used on the host, but are temporaries, no changes are made. - For cases replaced, if a MoveToHost is found (NOTE: that the algorithm does not explicitly check that any exist nor that all paths lead to a MoveToHost) for an output that is only used on the host, we simply replace the usage. PiperOrigin-RevId: 656430411 --- xla/service/BUILD | 2 + xla/service/host_offloader.cc | 214 +++++++++++++++++++- xla/service/host_offloader.h | 19 +- xla/service/host_offloader_test.cc | 309 +++++++++++++++++++++++++++++ 4 files changed, 541 insertions(+), 3 deletions(-) diff --git a/xla/service/BUILD b/xla/service/BUILD index 09676869ddab91..cc03a42fb2f019 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -6414,6 +6414,7 @@ cc_library( "//xla:util", "//xla/hlo/ir:hlo", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -6442,6 +6443,7 @@ xla_cc_test( "//xla:util", "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", + "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", diff --git a/xla/service/host_offloader.cc b/xla/service/host_offloader.cc index 95c97e94c704da..d2ab69ee684d7a 100644 --- a/xla/service/host_offloader.cc +++ b/xla/service/host_offloader.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/service/host_offloader.h" #include +#include #include #include #include @@ -26,15 +27,19 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal_util.h" #include "xla/service/call_graph.h" @@ -975,17 +980,224 @@ absl::StatusOr HostOffloader::ApplySchedulingFix( return changed; } +namespace { + +bool IsHostAsyncStart(const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kAsyncStart && + instruction->async_execution_thread() == HloInstruction::kHostThread; +} + +absl::Status ValidateAsyncComputationStructure(HloComputation* computation) { + for (HloInstruction* instr : computation->instructions()) { + if (instr->opcode() == HloOpcode::kParameter || instr->IsRoot()) { + continue; + } + + return absl::InternalError( + absl::StrCat("Unexpected instruction found in async computation: ", + instr->ToString())); + } + + return absl::OkStatus(); +} + +// Updates memory space for all outputs of the host offloaded computation +// (associated with `call_start`) that are ONLY used on host. NOTE: We also +// remove redundant copies to host, if any. +absl::StatusOr UpdateMemorySpaceForHostOffloadedOutputs( + HloInstruction* call_start, + absl::flat_hash_map>& + host_instr) { + // Keep track of MoveToHost instructions that need to be removed. + std::vector to_replace; + + HloComputation* called_computation = call_start->async_wrapped_computation(); + TF_RETURN_IF_ERROR(ValidateAsyncComputationStructure(called_computation)); + HloInstruction* root = called_computation->root_instruction(); + Shape* root_shape = root->mutable_shape(); + + for (auto& pair : host_instr) { + std::vector& instruction_and_shape_indexes = + pair.second; + + for (InstructionAndShapeIndex& instr_and_shape : + instruction_and_shape_indexes) { + // If instruction is MoveToHost, we will replace usage. + if (instr_and_shape.instruction->IsCustomCall( + host_memory_offload_annotations::kMoveToHostCustomCallTarget)) { + to_replace.emplace_back(instr_and_shape); + continue; + } + + SetMemorySpace(ShapeUtil::GetMutableSubshape( + instr_and_shape.instruction->mutable_shape(), + instr_and_shape.shape_index), + Layout::kHostMemorySpace); + } + + // Update the memory space for the output of the computation call itself. + size_t index = pair.first; + SetMemorySpace(root_shape->mutable_tuple_shapes(index), + Layout::kHostMemorySpace); + } + + // Remove MoveToHost usage. + for (InstructionAndShapeIndex& instr_and_shape : to_replace) { + HloInstruction* pred = instr_and_shape.instruction->mutable_operand(0); + TF_RETURN_IF_ERROR(instr_and_shape.instruction->ReplaceAllUsesWith(pred)); + } + + return !host_instr.empty(); +} + +constexpr int64_t kShapeTupleOutputIndexInAsyncStart = 1; + +// Additional checks (does not run IsValidDuringPureMemoryOffload) to determine +// if the respective tensor can be on host. +bool ExtraCheckForValidUsageOnHostForHostOffloadedOutputs( + const Shape& entry_computation_shape, + InstructionAndShapeIndex& instruction_and_shape_index) { + HloInstruction* instruction = instruction_and_shape_index.instruction; + ShapeIndex& shape_index = instruction_and_shape_index.shape_index; + + // We respect entry computation layout. So for the cases where the + // outputs are not expected on host, we bail. + if (instruction->IsRoot() && instruction->parent()->IsEntryComputation()) { + if (ShapeUtil::GetSubshape(entry_computation_shape, shape_index) + .layout() + .memory_space() != Layout::kHostMemorySpace) { + return false; + } + } + + // For custom calls, we conservatively only accept MoveToHost. + // For MoveToDevice, this could be re-considered, or done as part of a + // generic redundant copies removal. + if (instruction->opcode() == HloOpcode::kCustomCall && + instruction->custom_call_target() != + host_memory_offload_annotations::kMoveToHostCustomCallTarget) { + return false; + } + + // TODO(b/347101407): To also consider host async computations, as we + // extend GetSuccessors to properly treat it. + if (instruction->opcode() == HloOpcode::kAsyncStart || + instruction->opcode() == HloOpcode::kAsyncDone) { + return false; + } + + return true; +} + +} // namespace + +absl::StatusOr HostOffloader::HandleRedundantCopiesBackToHost( + const HloModule* module, HloInstruction* instruction) { + HloAsyncInstruction* call_start = Cast(instruction); + + CHECK_EQ(call_start->users().size(), 1); + HloInstruction* call_done = call_start->users()[0]; + + absl::flat_hash_map> + host_instrs; + const Shape& entry_computation_shape = + module->entry_computation_layout().result_layout().shape(); + + // We collect all usages per output index, stopping at any non host + // instruction. + const Shape& done_shape = call_done->shape(); + for (size_t index = 0; index < done_shape.tuple_shapes_size(); index++) { + ShapeIndex output_shape_index = {static_cast(index)}; + std::queue queue; + queue.push(InstructionAndShapeIndex(call_done, output_shape_index)); + + // async-start packs the (inputs, outputs, context) in a tuple. + ShapeIndex start_shape_index = {kShapeTupleOutputIndexInAsyncStart, + static_cast(index)}; + + // TODO(b/347101407): Start from async-start and trace through the + // computation as well in GetSuccessors instead of having to manually add + // async-done and update the async computation separately. + host_instrs[index].push_back( + InstructionAndShapeIndex(call_start, start_shape_index)); + host_instrs[index].push_back( + InstructionAndShapeIndex(call_done, output_shape_index)); + + bool host_only = true; + // Keep track if the output of the host offloading computation is also an + // output of the entry computation. Temporaries are conservatively kept on + // HBM. + // + // TODO(b/347101407): Better use AliasAnalysis here to trace host compute + // outputs to entry compute outputs instead. NOTE: The current algorithm + // only tracks accepted host offloading operations which operate on the same + // tensor. + bool entry_compute_output = false; + + while (!queue.empty() && host_only) { + InstructionAndShapeIndex instruction_and_shape_index = queue.front(); + queue.pop(); + + TF_ASSIGN_OR_RETURN(std::vector successors, + GetSuccessors(InstructionAndShapeIndex( + instruction_and_shape_index.instruction, + instruction_and_shape_index.shape_index))); + + // Check if any of the successors needs to be on device. + for (InstructionAndShapeIndex& successor : successors) { + if (!IsValidDuringPureMemoryOffload(successor.instruction) || + !ExtraCheckForValidUsageOnHostForHostOffloadedOutputs( + entry_computation_shape, successor)) { + host_only = false; + break; + } + + if (successor.instruction->IsRoot() && + successor.instruction->parent()->IsEntryComputation()) { + entry_compute_output = true; + } + + queue.push(successor); + host_instrs[index].emplace_back(successor); + } + } + + if (!host_only || !entry_compute_output) { + host_instrs.erase(index); + } + } + + // Update memory space for the host_offloading outputs that never get used on + // device. + return UpdateMemorySpaceForHostOffloadedOutputs(call_start, host_instrs); +} + absl::StatusOr HostOffloader::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; + + // First remove redundant copies to and from host (conservatively) starting + // from the outputs of the host offloaded computations. Iterate over all + // instructions and look for XLA host offload annotations. + bool changed_in_loop; + for (HloComputation* computation : + module->MakeNonfusionComputations(execution_threads)) { + for (HloInstruction* instruction : computation->instructions()) { + if (IsHostAsyncStart(instruction)) { + TF_ASSIGN_OR_RETURN(changed_in_loop, HandleRedundantCopiesBackToHost( + module, instruction)); + changed = changed || changed_in_loop; + } + } + } + TF_ASSIGN_OR_RETURN(const bool input_streaming_changed_module, HandleInputStreaming(module->entry_computation())); changed = changed || input_streaming_changed_module; // Since we're modifying the graph as we iterate over it, any time we change // it, we need to re-run the loop. - bool changed_in_loop; do { changed_in_loop = false; for (HloComputation* computation : diff --git a/xla/service/host_offloader.h b/xla/service/host_offloader.h index 880cda3d77b621..994c40fb62bc7c 100644 --- a/xla/service/host_offloader.h +++ b/xla/service/host_offloader.h @@ -67,8 +67,17 @@ bool operator==(const InstructionAndShapeIndex& lhs, // tensors along each path have their memory space set as host memory space. If // a MoveToHost custom call is paired with a DynamicUpdateSlice, the // DynamicUpdateSlice will write into host memory space. Otherwise, a copy from -// device to host will be inserted. All MoveToHost and MoveToDevice custom calls -// are removed by the end of this pass. +// device to host will be inserted. +// +// If an output of a host offloaded computation is only used on host, the memory +// space of the usages are updated to reflect it and no copies to and from host +// are performed. Any MoveToHost instructions for outputs used only on host, are +// removed. +// TODO(b/347101407): A better approach could be to remove redundant copies in a +// generalized fashion. Should also be moved out of Host Offloader. +// +// All MoveToHost and MoveToDevice custom calls are removed by the end of this +// pass. class HostOffloader : public HloModulePass { public: explicit HostOffloader(int64_t host_memory_space_color) @@ -167,6 +176,12 @@ class HostOffloader : public HloModulePass { absl::StatusOr ApplySchedulingFix( HloModule* module, const absl::flat_hash_set& execution_threads); + + // Starting from the outputs of the host offloaded computation, track all + // their usages. For the outputs that are ONLY used on host, remove redundant + // copies to and from host, as well as update the memory space. + absl::StatusOr HandleRedundantCopiesBackToHost( + const HloModule* module, HloInstruction* instruction); }; } // namespace xla diff --git a/xla/service/host_offloader_test.cc b/xla/service/host_offloader_test.cc index 85cc7742b3ce45..46e8aa003d1666 100644 --- a/xla/service/host_offloader_test.cc +++ b/xla/service/host_offloader_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -37,6 +38,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tests/verified_hlo_module.h" #include "xla/util.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" @@ -49,6 +51,7 @@ namespace { class HostOffloaderTest : public HloTestBase { protected: static constexpr int64_t kHostMemorySpaceColor{5}; + static constexpr int64_t kHbmMemorySpaceColor{0}; absl::StatusOr RunHostOffloader(HloModule* module, bool after_layout = false) { @@ -3364,6 +3367,312 @@ ENTRY main { EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get())); } +TEST_F(HostOffloaderTest, BasicAsyncHostOffloadedCall_RemoveRedundantCopies) { + const std::string& hlo_string = R"( +HloModule m, entry_computation_layout={(f32[4096]{0:S(5)})->(f32[4096]{0:S(5)}, f32[4096]{0:S(5)})} + +%async_computation { + %param_0 = f32[4096] parameter(0) + ROOT %offloaded-custom-call = (f32[4096], f32[4096]) custom-call(%param_0), custom_call_target="HostExecute" +}, execution_thread="host" + +ENTRY %main { + %a = f32[4096] parameter(0) + %async-start = ((f32[4096]), (f32[4096], f32[4096]), u32[]) async-start(%a), async_execution_thread="host", calls=%async_computation + %async-done = (f32[4096], f32[4096]) custom-call-done(%async-start) + %gte_0 = f32[4096] get-tuple-element(%async-done), index=0 + %gte_1 = f32[4096] get-tuple-element(%async-done), index=1 + %gte_0_host = f32[4096] custom-call(%gte_0), custom_call_target="MoveToHost" + %gte_1_host = f32[4096] custom-call(%gte_1), custom_call_target="MoveToHost" + ROOT %tuple = (f32[4096], f32[4096]) tuple(%gte_0_host, %gte_1_host) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); + EXPECT_TRUE(changed); + + HloInstruction* async_start = FindInstruction(module.get(), "async-start"); + ASSERT_NE(async_start, nullptr); + HloInstruction* async_done = FindInstruction(module.get(), "async-done"); + ASSERT_NE(async_done, nullptr); + + HloInstruction* gte_0 = FindInstruction(module.get(), "gte_0"); + ASSERT_NE(gte_0, nullptr); + TestShapeHasMemorySpace(gte_0->shape(), kHostMemorySpaceColor); + HloInstruction* gte_1 = FindInstruction(module.get(), "gte_1"); + ASSERT_NE(gte_1, nullptr); + TestShapeHasMemorySpace(gte_1->shape(), kHostMemorySpaceColor); + + HloInstruction* gte_0_host = FindInstruction(module.get(), "gte_0_host"); + ASSERT_EQ(gte_0_host, nullptr); + HloInstruction* gte_1_host = FindInstruction(module.get(), "gte_1_host"); + ASSERT_EQ(gte_1_host, nullptr); + + // Check all set of successors. + HloInstruction* tuple = FindInstruction(module.get(), "tuple"); + ASSERT_NE(tuple, nullptr); + std::vector expected = {gte_0, gte_1}; + EXPECT_THAT(tuple->operands(), + ::testing::UnorderedElementsAreArray(expected)); +} + +TEST_F(HostOffloaderTest, + BasicAsyncHostOffloadedCall_NoChangesWhenEntryLayoutExpectsHBM) { + const std::string& hlo_string = R"( +HloModule m, entry_computation_layout={(f32[4096]{0:S(5)})->(f32[4096]{0:S(0)}, f32[4096]{0:S(0)})} + +%async_computation { + %param_0 = f32[4096] parameter(0) + ROOT %offloaded-custom-call = (f32[4096], f32[4096]) custom-call(%param_0), custom_call_target="HostExecute" +}, execution_thread="host" + +ENTRY %main { + %a = f32[4096] parameter(0) + %async-start = ((f32[4096]), (f32[4096], f32[4096]), u32[]) async-start(%a), async_execution_thread="host", calls=%async_computation + %async-done = (f32[4096], f32[4096]) custom-call-done(%async-start) + %gte_0 = f32[4096] get-tuple-element(%async-done), index=0 + %gte_1 = f32[4096] get-tuple-element(%async-done), index=1 + ROOT %tuple = (f32[4096], f32[4096]) tuple(%gte_0, %gte_1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK(RunHostOffloader(module.get())); + + HloInstruction* async_start = FindInstruction(module.get(), "async-start"); + ASSERT_NE(async_start, nullptr); + HloInstruction* async_done = FindInstruction(module.get(), "async-done"); + ASSERT_NE(async_done, nullptr); + + HloInstruction* gte_0 = FindInstruction(module.get(), "gte_0"); + ASSERT_NE(gte_0, nullptr); + TestShapeHasMemorySpace(gte_0->shape(), kHbmMemorySpaceColor); + HloInstruction* gte_1 = FindInstruction(module.get(), "gte_1"); + ASSERT_NE(gte_1, nullptr); + TestShapeHasMemorySpace(gte_1->shape(), kHbmMemorySpaceColor); +} + +TEST_F(HostOffloaderTest, + BasicAsyncHostOffloadedCall_RemoveOnlyRedundantCopies) { + const std::string& hlo_string = R"( +HloModule m, entry_computation_layout={(f32[4096]{0:S(5)})->(f32[4096]{0:S(5)}, f32[4096]{0:S(5)})} + +%add { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add_res = f32[] add(%lhs, %rhs) +} + +%async_computation { + %param_0 = f32[4096] parameter(0) + ROOT %offloaded-custom-call = (f32[4096], f32[4096]) custom-call(%param_0), custom_call_target="HostExecute" +}, execution_thread="host" + +ENTRY %main { + %a = f32[4096] parameter(0) + %async-start = ((f32[4096]), (f32[4096], f32[4096]), u32[]) async-start(%a), async_execution_thread="host", calls=%async_computation + %async-done = (f32[4096], f32[4096]) custom-call-done(%async-start) + %gte_0 = f32[4096] get-tuple-element(%async-done), index=0 + %gte_1 = f32[4096] get-tuple-element(%async-done), index=1 + %sum = f32[4096] add(%gte_0, %gte_0) + %gte_0_host = f32[4096] custom-call(%gte_0), custom_call_target="MoveToHost" + %gte_1_host = f32[4096] custom-call(%gte_1), custom_call_target="MoveToHost" + ROOT %tuple = (f32[4096]{0:S(5)}, f32[4096]) tuple(%gte_0_host, %gte_1_host) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); + EXPECT_TRUE(changed); + + HloInstruction* async_start = FindInstruction(module.get(), "async-start"); + ASSERT_NE(async_start, nullptr); + HloInstruction* async_done = FindInstruction(module.get(), "async-done"); + ASSERT_NE(async_done, nullptr); + + HloInstruction* gte_0 = FindInstruction(module.get(), "gte_0"); + ASSERT_NE(gte_0, nullptr); + TestShapeHasMemorySpace(gte_0->shape(), kHbmMemorySpaceColor); + HloInstruction* gte_1 = FindInstruction(module.get(), "gte_1"); + ASSERT_NE(gte_1, nullptr); + TestShapeHasMemorySpace(gte_1->shape(), kHostMemorySpaceColor); + + // Since gte_0 is used on device (we do not take dead code into account here + // ..) gte_0 will be copied to device and be moved to host. + HloInstruction* gte_0_host = FindInstruction(module.get(), "gte_0_host"); + ASSERT_EQ(gte_0_host, nullptr); // replaced with copy + HloInstruction* copy = FindInstruction(module.get(), "copy"); + ASSERT_NE(copy, nullptr); + EXPECT_EQ(copy->operands()[0], gte_0); + + HloInstruction* gte_1_host = FindInstruction(module.get(), "gte_1_host"); + ASSERT_EQ(gte_1_host, nullptr); +} + +TEST_F(HostOffloaderTest, + AsyncHostOffloadedCall_nonEntryPoint_RemoveRedundantCopies) { + const std::string& hlo_string = R"( +HloModule m, entry_computation_layout={(f32[4096]{0:S(5)})->(f32[4096]{0:S(5)}, f32[4096]{0:S(5)})} + +%async_computation { + %param_0 = f32[4096] parameter(0) + ROOT %offloaded-custom-call = (f32[4096], f32[4096]) custom-call(%param_0), custom_call_target="HostExecute" +}, execution_thread="host" + +%non_async_computation { + %param_0 = f32[4096] parameter(0) + %async-start = ((f32[4096]), (f32[4096], f32[4096]), u32[]) async-start(%param_0), async_execution_thread="host", calls=%async_computation + %async-done = (f32[4096], f32[4096]) custom-call-done(%async-start) + %gte_0 = f32[4096] get-tuple-element(%async-done), index=0 + %gte_1 = f32[4096] get-tuple-element(%async-done), index=1 + %gte_0_host = f32[4096] custom-call(%gte_0), custom_call_target="MoveToHost" + %gte_1_host = f32[4096] custom-call(%gte_1), custom_call_target="MoveToHost" + ROOT %tuple_non_async = (f32[4096]{0:S(5)}, f32[4096]) tuple(%gte_0_host, %gte_1_host) +} + +ENTRY %main { + %a = f32[4096] parameter(0) + %call = (f32[4096], f32[4096]) call(%a), to_apply=%non_async_computation + %call_0 = f32[4096] get-tuple-element(%call), index=0 + %call_1 = f32[4096] get-tuple-element(%call), index=1 + ROOT %tuple = (f32[4096], f32[4096]) tuple(%call_0, %call_1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); + EXPECT_TRUE(changed); + + HloInstruction* async_start = FindInstruction(module.get(), "async-start"); + ASSERT_NE(async_start, nullptr); + HloInstruction* async_done = FindInstruction(module.get(), "async-done"); + ASSERT_NE(async_done, nullptr); + + HloInstruction* gte_0 = FindInstruction(module.get(), "gte_0"); + ASSERT_NE(gte_0, nullptr); + TestShapeHasMemorySpace(gte_0->shape(), kHostMemorySpaceColor); + HloInstruction* gte_1 = FindInstruction(module.get(), "gte_1"); + ASSERT_NE(gte_1, nullptr); + TestShapeHasMemorySpace(gte_1->shape(), kHostMemorySpaceColor); + + HloInstruction* gte_0_host = FindInstruction(module.get(), "gte_0_host"); + ASSERT_EQ(gte_0_host, nullptr); + HloInstruction* gte_1_host = FindInstruction(module.get(), "gte_1_host"); + ASSERT_EQ(gte_1_host, nullptr); + + HloInstruction* tuple_non_async = + FindInstruction(module.get(), "tuple_non_async"); + ASSERT_NE(tuple_non_async, nullptr); + std::vector expected = {gte_0, gte_1}; + EXPECT_THAT(tuple_non_async->operands(), + ::testing::UnorderedElementsAreArray(expected)); + + // Check the main output is on host. + HloInstruction* tuple = FindInstruction(module.get(), "tuple"); + ASSERT_NE(tuple, nullptr); + TestShapeHasMemorySpace(tuple->shape().tuple_shapes(0), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(tuple->shape().tuple_shapes(1), + kHostMemorySpaceColor); +} + +TEST_F(HostOffloaderTest, + AsyncHostOffloadedCall_passedToCall_RemoveRedundantCopies) { + const std::string& hlo_string = R"( +HloModule m, entry_computation_layout={(f32[4096]{0:S(5)})->(f32[4096]{0:S(5)}, f32[4096]{0:S(5)})} + +%async_computation { + %param_0 = f32[4096] parameter(0) + ROOT %offloaded-custom-call = (f32[4096], f32[4096]) custom-call(%param_0), custom_call_target="HostExecute" +}, execution_thread="host" + +%non_async_computation { + %param_0_non_async = f32[4096] parameter(0) + %param_1_non_async = f32[4096] parameter(1) + ROOT %tuple_non_async = (f32[4096], f32[4096]) tuple(%param_0_non_async, %param_1_non_async) +} + +ENTRY %main { + %a = f32[4096] parameter(0) + %async-start = ((f32[4096]), (f32[4096], f32[4096]), u32[]) async-start(%a), async_execution_thread="host", calls=%async_computation + %async-done = (f32[4096], f32[4096]) custom-call-done(%async-start) + %gte_0 = f32[4096] get-tuple-element(%async-done), index=0 + %gte_1 = f32[4096] get-tuple-element(%async-done), index=1 + %call = (f32[4096], f32[4096]) call(%gte_0, %gte_1), to_apply=%non_async_computation + %call_0 = f32[4096] get-tuple-element(%call), index=0 + %call_1 = f32[4096] get-tuple-element(%call), index=1 + %call_0_host = f32[4096] custom-call(%call_0), custom_call_target="MoveToHost" + %call_1_host = f32[4096] custom-call(%call_1), custom_call_target="MoveToHost" + ROOT %tuple = (f32[4096], f32[4096]) tuple(%call_0_host, %call_1_host) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); + EXPECT_TRUE(changed); + + HloInstruction* async_start = FindInstruction(module.get(), "async-start"); + ASSERT_NE(async_start, nullptr); + HloInstruction* async_done = FindInstruction(module.get(), "async-done"); + ASSERT_NE(async_done, nullptr); + + HloInstruction* gte_0 = FindInstruction(module.get(), "gte_0"); + ASSERT_NE(gte_0, nullptr); + TestShapeHasMemorySpace(gte_0->shape(), kHostMemorySpaceColor); + HloInstruction* gte_1 = FindInstruction(module.get(), "gte_1"); + ASSERT_NE(gte_1, nullptr); + TestShapeHasMemorySpace(gte_1->shape(), kHostMemorySpaceColor); + + HloInstruction* call_0 = FindInstruction(module.get(), "call_0"); + ASSERT_NE(call_0, nullptr); + HloInstruction* call_1 = FindInstruction(module.get(), "call_1"); + ASSERT_NE(call_1, nullptr); + + HloInstruction* call_0_host = FindInstruction(module.get(), "call_0_host"); + ASSERT_EQ(call_0_host, nullptr); + HloInstruction* call_1_host = FindInstruction(module.get(), "call_1_host"); + ASSERT_EQ(call_1_host, nullptr); + + HloInstruction* param_0_non_async = + FindInstruction(module.get(), "param_0_non_async"); + ASSERT_NE(param_0_non_async, nullptr); + TestShapeHasMemorySpace(param_0_non_async->shape(), kHostMemorySpaceColor); + HloInstruction* param_1_non_async = + FindInstruction(module.get(), "param_1_non_async"); + ASSERT_NE(param_1_non_async, nullptr); + TestShapeHasMemorySpace(param_1_non_async->shape(), kHostMemorySpaceColor); + + HloInstruction* tuple_non_async = + FindInstruction(module.get(), "tuple_non_async"); + ASSERT_NE(tuple_non_async, nullptr); + std::vector expected_operands = {param_0_non_async, + param_1_non_async}; + EXPECT_THAT(tuple_non_async->operands(), + ::testing::UnorderedElementsAreArray(expected_operands)); + + HloInstruction* tuple = FindInstruction(module.get(), "tuple"); + ASSERT_NE(tuple, nullptr); + TestShapeHasMemorySpace(tuple->shape().tuple_shapes(0), + kHostMemorySpaceColor); + TestShapeHasMemorySpace(tuple->shape().tuple_shapes(1), + kHostMemorySpaceColor); + + std::vector expected = {call_0, call_1}; + EXPECT_THAT(tuple->operands(), + ::testing::UnorderedElementsAreArray(expected)); +} + } // namespace } // namespace xla From 3774418c07f3a034d82e4fd10bf905e15d32525c Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Fri, 26 Jul 2024 09:40:40 -0700 Subject: [PATCH 200/376] [XLA:GPU] Move test helper functions to literal utils. Move helper functions to literal utils: CreateFull, MakeScalarMatrixR2, CreateFingerprintMatixR2. PiperOrigin-RevId: 656432716 --- xla/literal_util.h | 50 +++++++++++- .../collective_pipeline_parallelism_test.cc | 77 +++++-------------- 2 files changed, 66 insertions(+), 61 deletions(-) diff --git a/xla/literal_util.h b/xla/literal_util.h index 24cccb58438c17..f6e5f581802480 100644 --- a/xla/literal_util.h +++ b/xla/literal_util.h @@ -128,6 +128,9 @@ class LiteralUtil { template static Literal CreateFullWithDescendingLayout( absl::Span dimensions, NativeT value); + template + static Literal CreateFull(absl::Span dimensions, + NativeT value); // Creates a new literal from an Array type. The variants not ending with // WithLayout use the default XLA layout for the literal's linear @@ -175,10 +178,20 @@ class LiteralUtil { std::initializer_list> values, int64_t projection_p, int64_t projection_z); - // Returns an identity matrix (rank 2) with the given row and column count. + // Returns a scalar matrix (rank 2) of the given size and scalar value. + template + static Literal MakeScalarMatrixR2(int64_t size, NativeT scalar); + + // Returns an identity matrix (rank 2) of the given size. template static Literal MakeIdentityR2(int64_t size); + // Creates fingerprint input where each entry encodes its row and column + // scaled by the given scale. + template + static Literal CreateFingerprintMatixR2(int64_t m, int64_t n, + NativeT scale = 1); + // Returns a tuple literal composed of given literals. Data is copied from the // given elements into the returned literal. static Literal MakeTuple(absl::Span elements); @@ -516,12 +529,32 @@ template return CreateFromArrayWithLayout(values, layout); } -// Returns an identity matrix (rank 2) with the given row and column count. +// Creates a squared scalar matrix of given size. template -/* static */ Literal LiteralUtil::MakeIdentityR2(int64_t size) { +/* static */ Literal LiteralUtil::MakeScalarMatrixR2(int64_t size, + NativeT scalar) { Array2D array(size, size, 0); for (int64_t i = 0; i < size; ++i) { - array(i, i) = 1; + array(i, i) = scalar; + } + return CreateR2FromArray2D(array); +} + +template +/* static */ Literal LiteralUtil::MakeIdentityR2(int64_t size) { + return MakeScalarMatrixR2(size, 1); +} + +template +/* static */ Literal LiteralUtil::CreateFingerprintMatixR2(int64_t m, int64_t n, + NativeT scale) { + NativeT row_factor = log10(m) + 1; + NativeT col_factor = log10(n) + 1; + Array2D array(m, n, 0); + for (int64_t i = 0; i < m; ++i) { + for (int64_t j = 0; j < n; ++j) { + array(i, i) = scale * (row_factor * i + col_factor * j); + } } return CreateR2FromArray2D(array); } @@ -535,6 +568,15 @@ template return literal; } +template +/* static */ Literal LiteralUtil::CreateFull( + absl::Span dimensions, NativeT value) { + Literal literal(ShapeUtil::MakeShape( + primitive_util::NativeToPrimitiveType(), dimensions)); + literal.PopulateWithValue(value); + return literal; +} + template /* static */ absl::StatusOr LiteralUtil::CreateLiteralWithGenerator( const Shape& shape, diff --git a/xla/tests/collective_pipeline_parallelism_test.cc b/xla/tests/collective_pipeline_parallelism_test.cc index 627c652b78cf49..509bb9d2cfcf22 100644 --- a/xla/tests/collective_pipeline_parallelism_test.cc +++ b/xla/tests/collective_pipeline_parallelism_test.cc @@ -20,14 +20,12 @@ limitations under the License. #include #include "absl/log/log.h" -#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/error_spec.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/hlo_module_config.h" -#include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" @@ -128,31 +126,6 @@ XLA_TEST_F(CollectivePipelineParallelismTest, LiteralTestUtil::ExpectR2Equal({{0, 0}, {1, 1}}, results[3]); } -// Helper functions for pipeline parallelism tests where each stage scales the -// input by some factor. -absl::StatusOr CreateLinearLayerWeights(int64_t size, float factor) { - return LiteralUtil::CreateLiteralWithGenerator( - ShapeUtil::MakeShape(F32, {size, size}), - [&](absl::Span idx) -> float { - return idx[0] == idx[1] ? factor : 0.0; - }); -}; -absl::StatusOr CreateZeroInputR2(int64_t microbatches, int64_t size) { - return LiteralUtil::CreateLiteralWithGenerator( - ShapeUtil::MakeShape(F32, {microbatches, size}), - [&](absl::Span idx) -> float { return 0.0; }); -}; -absl::StatusOr CreateFingerprintInput(int64_t microbatches, - int64_t size, - float factor = 1.0) { - return LiteralUtil::CreateLiteralWithGenerator( - ShapeUtil::MakeShape(F32, {microbatches, size}), - [&](absl::Span idx) -> float { - float fingerprint = 1.0 * idx[0] + 0.0001 * idx[1]; - return factor * fingerprint; - }); -}; - // Naive implementation of pipeline parallelism: // - 4 devices // - 4 microbatches @@ -262,22 +235,18 @@ XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch4Replica4) { // data by 1.0, 2.0, 3.0 and 4.0. The combined effect is to scale the input // data by 24.0. const int64_t kInputSize = 16; - TF_ASSERT_OK_AND_ASSIGN(Literal weights_r0, - CreateLinearLayerWeights(kInputSize, 1.0)); - TF_ASSERT_OK_AND_ASSIGN(Literal weights_r1, - CreateLinearLayerWeights(kInputSize, 2.0)); - TF_ASSERT_OK_AND_ASSIGN(Literal weights_r2, - CreateLinearLayerWeights(kInputSize, 3.0)); - TF_ASSERT_OK_AND_ASSIGN(Literal weights_r3, - CreateLinearLayerWeights(kInputSize, 4.0)); + Literal weights_r0 = LiteralUtil::MakeScalarMatrixR2(kInputSize, 1.0); + Literal weights_r1 = LiteralUtil::MakeScalarMatrixR2(kInputSize, 2.0); + Literal weights_r2 = LiteralUtil::MakeScalarMatrixR2(kInputSize, 3.0); + Literal weights_r3 = LiteralUtil::MakeScalarMatrixR2(kInputSize, 4.0); // Only the first replica holds the input to the pipeline in this naive // implementation. The remaining replicas get zero/dummy input. const int64_t kMicrobatches = 4; - TF_ASSERT_OK_AND_ASSIGN(Literal real_input, - CreateFingerprintInput(kMicrobatches, kInputSize)); - TF_ASSERT_OK_AND_ASSIGN(Literal fake_input, - CreateZeroInputR2(kMicrobatches, kInputSize)); + Literal real_input = + LiteralUtil::CreateFingerprintMatixR2(kMicrobatches, kInputSize); + Literal fake_input = + LiteralUtil::CreateFull({kMicrobatches, kInputSize}, 0.0); std::vector> args = {{&weights_r0, &real_input}, {&weights_r1, &fake_input}, @@ -291,9 +260,8 @@ XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch4Replica4) { // Check pipeline output for last replica. // The combined effect of the pipeline is to scale the input data by 24.0. const float kExpectedFactor = 1.0 * 2.0 * 3.0 * 4.0; - TF_ASSERT_OK_AND_ASSIGN( - Literal expected_output, - CreateFingerprintInput(kMicrobatches, kInputSize, kExpectedFactor)); + Literal expected_output = LiteralUtil::CreateFingerprintMatixR2( + kMicrobatches, kInputSize, kExpectedFactor); EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected_output, results[3], ErrorSpec{1e-5, 1e-5})); } @@ -408,29 +376,24 @@ XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch5Replica4) { // data by 1.0, 2.0, 3.0 and 4.0. The combined effect is to scale the input // data by 24.0. const int64_t kInputSize = 16; - TF_ASSERT_OK_AND_ASSIGN(Literal weights_r0, - CreateLinearLayerWeights(kInputSize, 1.0)); - TF_ASSERT_OK_AND_ASSIGN(Literal weights_r1, - CreateLinearLayerWeights(kInputSize, 2.0)); - TF_ASSERT_OK_AND_ASSIGN(Literal weights_r2, - CreateLinearLayerWeights(kInputSize, 3.0)); - TF_ASSERT_OK_AND_ASSIGN(Literal weights_r3, - CreateLinearLayerWeights(kInputSize, 4.0)); + Literal weights_r0 = LiteralUtil::MakeScalarMatrixR2(kInputSize, 1.0); + Literal weights_r1 = LiteralUtil::MakeScalarMatrixR2(kInputSize, 2.0); + Literal weights_r2 = LiteralUtil::MakeScalarMatrixR2(kInputSize, 3.0); + Literal weights_r3 = LiteralUtil::MakeScalarMatrixR2(kInputSize, 4.0); // Only the first replica holds the input to the pipeline in this naive // implementation. The remaining replicas get zero/dummy input. const int64_t kMicrobatches = 5; - TF_ASSERT_OK_AND_ASSIGN(Literal real_input, - CreateFingerprintInput(kMicrobatches, kInputSize)); - TF_ASSERT_OK_AND_ASSIGN(Literal fake_input, - CreateZeroInputR2(kMicrobatches, kInputSize)); + Literal real_input = + LiteralUtil::CreateFingerprintMatixR2(kMicrobatches, kInputSize); + Literal fake_input = + LiteralUtil::CreateFull({kMicrobatches, kInputSize}, 0.0); // Check pipeline output for last replica. // The combined effect of the pipeline is to scale the input data by 24.0. const float kExpectedFactor = 1.0 * 2.0 * 3.0 * 4.0; - TF_ASSERT_OK_AND_ASSIGN( - Literal expected_output, - CreateFingerprintInput(kMicrobatches, kInputSize, kExpectedFactor)); + Literal expected_output = LiteralUtil::CreateFingerprintMatixR2( + kMicrobatches, kInputSize, /*scale=*/kExpectedFactor); std::vector> args = {{&weights_r0, &real_input}, {&weights_r1, &fake_input}, {&weights_r2, &fake_input}, From 3a5144f06523801bf55452da709739a7b0220b65 Mon Sep 17 00:00:00 2001 From: pizzud Date: Fri, 26 Jul 2024 10:02:17 -0700 Subject: [PATCH 201/376] hlo_evaluator: Don't dereference a disengaged optional. We can't look at error_detail if there's no value there. PiperOrigin-RevId: 656439229 --- xla/hlo/evaluator/BUILD | 2 ++ xla/hlo/evaluator/hlo_evaluator.cc | 48 ++++++++++++------------- xla/hlo/evaluator/hlo_evaluator.h | 20 +++++++++++ xla/hlo/evaluator/hlo_evaluator_test.cc | 25 +++++++++++++ 4 files changed, 70 insertions(+), 25 deletions(-) diff --git a/xla/hlo/evaluator/BUILD b/xla/hlo/evaluator/BUILD index dc7407cf9ce4ce..3963765132a5bc 100644 --- a/xla/hlo/evaluator/BUILD +++ b/xla/hlo/evaluator/BUILD @@ -144,8 +144,10 @@ xla_cc_test( "//xla/tests:test_utils", "//xla/tests:xla_internal_test_main", # fixdeps: keep "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:endian", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", diff --git a/xla/hlo/evaluator/hlo_evaluator.cc b/xla/hlo/evaluator/hlo_evaluator.cc index 25457eadbb3415..9b51dca7721011 100644 --- a/xla/hlo/evaluator/hlo_evaluator.cc +++ b/xla/hlo/evaluator/hlo_evaluator.cc @@ -205,25 +205,6 @@ absl::Status Apply(Literal& literal, F&& literal_generator) { literal.shape().element_type()); } -constexpr absl::string_view kEvalErrorDetailUrl = "EvalErrorDetailUrl"; - -// Use this class to represent the precise details of the error to enable -// special treatment. -enum class EvalErrorDetail : uint32_t { - // The evaluation result depends on dynamic values such as parameters and - // infeed. Therefore, the HLO's value cannot be statically evaluated. - kDynamicValueDependence = 0, -}; - -std::optional ParseEvalErrorDetail(const absl::Status& error) { - auto error_detail = error.GetPayload(kEvalErrorDetailUrl); - if (!error_detail.has_value() && error_detail->empty()) { - return std::nullopt; - } - return static_cast( - absl::little_endian::Load32(error_detail->Flatten().data())); -} - absl::Status MakeEvalErrorDueToParamOrInfeed( const HloInstruction& eval_instruction) { absl::Status error = absl::FailedPreconditionError(absl::StrCat( @@ -231,11 +212,12 @@ absl::Status MakeEvalErrorDueToParamOrInfeed( ") since it depends on infeed or parameters to its parent computation (", eval_instruction.parent()->name(), ").")); std::string error_payload; - error_payload.resize(sizeof(EvalErrorDetail)); + error_payload.resize(sizeof(internal::EvalErrorDetail)); absl::little_endian::Store32( const_cast(error_payload.data()), - static_cast(EvalErrorDetail::kDynamicValueDependence)); - error.SetPayload(kEvalErrorDetailUrl, absl::Cord(error_payload)); + static_cast( + internal::EvalErrorDetail::kDynamicValueDependence)); + error.SetPayload(internal::kEvalErrorDetailUrl, absl::Cord(error_payload)); return error; } @@ -263,10 +245,11 @@ std::optional GetInstructionValueAsInteger( } } - std::optional eval_error_detail = - ParseEvalErrorDetail(static_value.status()); + std::optional eval_error_detail = + internal::ParseEvalErrorDetail(static_value.status()); if (eval_error_detail.has_value() && - *eval_error_detail == EvalErrorDetail::kDynamicValueDependence) { + *eval_error_detail == + internal::EvalErrorDetail::kDynamicValueDependence) { return DynamicOrStaticInteger{std::nullopt}; } return std::nullopt; @@ -550,6 +533,21 @@ std::optional EvaluateWhileLoopParamInitValue( } // namespace +namespace internal { + +constexpr absl::string_view kEvalErrorDetailUrl = "EvalErrorDetailUrl"; + +std::optional ParseEvalErrorDetail(const absl::Status& error) { + auto error_detail = error.GetPayload(kEvalErrorDetailUrl); + if (!error_detail.has_value() || error_detail->empty()) { + return std::nullopt; + } + return static_cast( + absl::little_endian::Load32(error_detail->Flatten().data())); +} + +} // namespace internal + std::optional HandleNoopLoopCondition( const ParamIndexAndValue& parameter_index_and_value, const HloInstruction* while_operand, const HloComputation* while_body) { diff --git a/xla/hlo/evaluator/hlo_evaluator.h b/xla/hlo/evaluator/hlo_evaluator.h index 48864dd05c66bf..2f91c39c857c9c 100644 --- a/xla/hlo/evaluator/hlo_evaluator.h +++ b/xla/hlo/evaluator/hlo_evaluator.h @@ -19,6 +19,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "Eigen/Core" #include "xla/comparison_util.h" #include "xla/hlo/ir/dfs_hlo_visitor.h" @@ -28,6 +29,7 @@ limitations under the License. #include "tsl/platform/errors.h" #define _USE_MATH_DEFINES +#include #include #include #include @@ -515,6 +517,24 @@ class HloEvaluator : public ConstDfsHloVisitorWithDefault { std::unique_ptr> MatmulArray2D(const Array2D& lhs, const Array2D& rhs); + +// Functionality exposed for testing. Do not rely on anything in this namespace +// outside this file. +namespace internal { + +// Use this class to represent the precise details of the error to enable +// special treatment. +enum class EvalErrorDetail : uint32_t { + // The evaluation result depends on dynamic values such as parameters and + // infeed. Therefore, the HLO's value cannot be statically evaluated. + kDynamicValueDependence = 0, +}; + +extern const absl::string_view kEvalErrorDetailUrl; + +std::optional ParseEvalErrorDetail(const absl::Status& error); + +} // namespace internal } // namespace xla #endif // XLA_HLO_EVALUATOR_HLO_EVALUATOR_H_ diff --git a/xla/hlo/evaluator/hlo_evaluator_test.cc b/xla/hlo/evaluator/hlo_evaluator_test.cc index 6be1d4c72f38ef..72dc6f84c4ade6 100644 --- a/xla/hlo/evaluator/hlo_evaluator_test.cc +++ b/xla/hlo/evaluator/hlo_evaluator_test.cc @@ -27,8 +27,10 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/base/internal/endian.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" @@ -5802,5 +5804,28 @@ TEST_F(HloEvaluatorTest, SimpleConvTraced) { EXPECT_EQ(macs_traced, macs_expected); } +TEST(EvalErrorTest, OK) { + EXPECT_EQ(std::nullopt, internal::ParseEvalErrorDetail(absl::OkStatus())); +} + +TEST(EvalErrorTest, NoPayload) { + EXPECT_EQ(std::nullopt, + internal::ParseEvalErrorDetail(absl::InternalError("hmm"))); +} + +TEST(EvalErrorTest, Payload) { + absl::Status s = absl::InternalError("hmm"); + std::string payload; + payload.resize(sizeof(internal::EvalErrorDetail)); + absl::little_endian::Store32( + const_cast(payload.data()), + static_cast( + internal::EvalErrorDetail::kDynamicValueDependence)); + s.SetPayload(internal::kEvalErrorDetailUrl, absl::Cord(payload)); + + EXPECT_EQ(internal::ParseEvalErrorDetail(s), + internal::EvalErrorDetail::kDynamicValueDependence); +} + } // namespace } // namespace xla From 375a3bcacd3362ec4d098a63f40fdb832ef8b6fe Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 26 Jul 2024 10:08:19 -0700 Subject: [PATCH 202/376] Fix data race in coordination_service_test.cc PiperOrigin-RevId: 656442100 --- .../coordination/coordination_service_test.cc | 83 ++++++++++++------- 1 file changed, 52 insertions(+), 31 deletions(-) diff --git a/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc b/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc index 9d02ce3641d4eb..6133d19ef72380 100644 --- a/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc +++ b/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc @@ -516,21 +516,28 @@ TEST_F(CoordinateTwoTasksTest, EnableCoordinationService(/*has_service_to_client_connection=*/false); ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); - std::vector statuses; - statuses.reserve(2); - for (const CoordinatedTask& task : {task_0_, task_1_}) { - coord_service_->PollForErrorAsync( - task, [&](const absl::Status& status) { statuses.push_back(status); }); - } + // Use notifications to guarantee the ordering of operations across threads. + absl::Notification n0, n1; + + // The heartbeat error below should be propagated to all tasks. + absl::StatusCode expected_error_code = absl::StatusCode::kUnavailable; + coord_service_->PollForErrorAsync(task_0_, [&](const absl::Status& status) { + EXPECT_THAT(status, StatusIs(expected_error_code)); + n0.Notify(); + }); + coord_service_->PollForErrorAsync(task_1_, [&](const absl::Status& status) { + EXPECT_THAT(status, StatusIs(expected_error_code)); + n1.Notify(); + }); // No heartbeat for a while, leader consider the task as stale and propagate // the error to the tasks. Env::Default()->SleepForMicroseconds( absl::ToInt64Microseconds(2 * kHeartbeatTimeout)); - // The heartbeat error is propagated through error polling. - EXPECT_EQ(statuses.size(), 2); - EXPECT_THAT(statuses, Each(StatusIs(absl::StatusCode::kUnavailable))); + // Make sure the StatusCallbacks are called. + n0.WaitForNotification(); + n1.WaitForNotification(); } TEST_F(CoordinateTwoTasksTest, @@ -538,13 +545,19 @@ TEST_F(CoordinateTwoTasksTest, EnableCoordinationService(/*has_service_to_client_connection=*/false); ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); - std::vector statuses; - statuses.reserve(2); - - for (const CoordinatedTask& task : {task_0_, task_1_}) { - coord_service_->PollForErrorAsync( - task, [&](const absl::Status& status) { statuses.push_back(status); }); - } + // Use notifications to guarantee the ordering of operations across threads. + absl::Notification n0, n1; + + // The heartbeat error from `task_1_` below should be propagated to all tasks. + absl::StatusCode expected_error_code = absl::StatusCode::kUnavailable; + coord_service_->PollForErrorAsync(task_0_, [&](const absl::Status& status) { + EXPECT_THAT(status, StatusIs(expected_error_code, HasSubstr("task:1"))); + n0.Notify(); + }); + coord_service_->PollForErrorAsync(task_1_, [&](const absl::Status& status) { + EXPECT_THAT(status, StatusIs(expected_error_code, HasSubstr("task:1"))); + n1.Notify(); + }); // Use a factor of 0.9 to avoid accidental timeout. const int64_t sleeping_time = @@ -557,10 +570,9 @@ TEST_F(CoordinateTwoTasksTest, TF_EXPECT_OK(coord_service_->RecordHeartbeat(task_0_, incarnation_0_)); Env::Default()->SleepForMicroseconds(sleeping_time); - // The heartbeat error is propagated through error polling. - EXPECT_EQ(statuses.size(), 2); - EXPECT_THAT(statuses, Each(StatusIs(absl::StatusCode::kUnavailable, - HasSubstr("task:")))); + // Make sure the StatusCallbacks are called. + n0.WaitForNotification(); + n1.WaitForNotification(); } TEST_F(CoordinateTwoTasksTest, ReportedErrorCanPropagateThroughErrorPolling) { @@ -1596,30 +1608,39 @@ TEST_F(CoordinateTwoTasksTest, BarrierFailsAfterErrorPollingResponse) { EnableCoordinationService(/*has_service_to_client_connection=*/false); ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); - std::vector statuses; - statuses.reserve(2); - for (const CoordinatedTask& task : {task_0_, task_1_}) { - coord_service_->PollForErrorAsync( - task, [&](const absl::Status& status) { statuses.push_back(status); }); - } + // Use notifications to guarantee the ordering of operations across threads. + absl::Notification n0, n1; + + // The heartbeat error below should be propagated to all tasks. + absl::StatusCode expected_error_code = absl::StatusCode::kUnavailable; + coord_service_->PollForErrorAsync(task_0_, [&](const absl::Status& status) { + EXPECT_THAT(status, StatusIs(expected_error_code)); + n0.Notify(); + }); + coord_service_->PollForErrorAsync(task_1_, [&](const absl::Status& status) { + EXPECT_THAT(status, StatusIs(expected_error_code)); + n1.Notify(); + }); + // No heartbeat for a while, leader consider the task as stale. The error will // be propagated through error polling. Env::Default()->SleepForMicroseconds( absl::ToInt64Microseconds(2 * kHeartbeatTimeout)); - EXPECT_EQ(statuses.size(), 2); - EXPECT_THAT(statuses, Each(StatusIs(absl::StatusCode::kUnavailable))); + // Make sure the StatusCallbacks are called before the barrier is called. + n0.WaitForNotification(); + n1.WaitForNotification(); - absl::Notification n0; + absl::Notification n_barrier; absl::Status barrier_status; // Barrier should fail when called after the error is propagated. coord_service_->BarrierAsync("barrier_id", absl::Seconds(5), task_0_, /*participating_tasks=*/{}, [&](absl::Status s) { barrier_status = s; - n0.Notify(); + n_barrier.Notify(); }); - n0.WaitForNotification(); + n_barrier.WaitForNotification(); EXPECT_TRUE(absl::IsInternal(barrier_status)) << barrier_status; } From 65416b5387982d26aa9239ff58b40297d9bcb603 Mon Sep 17 00:00:00 2001 From: Greg Olechwierowicz Date: Fri, 26 Jul 2024 10:30:41 -0700 Subject: [PATCH 203/376] [XLA:GPU] Keep outer custom call name in the profiler. PiperOrigin-RevId: 656450025 --- xla/service/gpu/fusions/custom.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xla/service/gpu/fusions/custom.cc b/xla/service/gpu/fusions/custom.cc index 30c92b9387f53a..3a95abfa402021 100644 --- a/xla/service/gpu/fusions/custom.cc +++ b/xla/service/gpu/fusions/custom.cc @@ -419,7 +419,7 @@ absl::StatusOr EmitGemm( GemmConfig::For(static_cast(&custom_call))); std::unique_ptr thunk; - auto thunk_info = Thunk::ThunkInfo::WithProfileAnnotation(&custom_call); + auto thunk_info = Thunk::ThunkInfo::WithProfileAnnotation(&fusion); if (absl::c_any_of(slice_instrs, [&](auto slice_instr) { return DynCastOrNull(slice_instr) != @@ -656,7 +656,7 @@ absl::StatusOr EmitCustomCall( } std::unique_ptr thunk; - auto thunk_info = Thunk::ThunkInfo::WithProfileAnnotation(&custom_call); + auto thunk_info = Thunk::ThunkInfo::WithProfileAnnotation(&fusion); auto ffi_thunk = [&](Slices ops, Slices res) { auto& called_computations = custom_call.called_computations(); From a58e6a83ae4b22eabdc2a29570d906c383a87b9e Mon Sep 17 00:00:00 2001 From: Zixuan Jiang Date: Fri, 26 Jul 2024 10:54:00 -0700 Subject: [PATCH 204/376] Add mutable_debug_options() in HloModuleConfig. PiperOrigin-RevId: 656458332 --- xla/service/hlo_module_config.h | 1 + 1 file changed, 1 insertion(+) diff --git a/xla/service/hlo_module_config.h b/xla/service/hlo_module_config.h index a428c9bccba7a0..99408524ed32d1 100644 --- a/xla/service/hlo_module_config.h +++ b/xla/service/hlo_module_config.h @@ -226,6 +226,7 @@ class HloModuleConfig { std::string compilation_cache_key() const; const DebugOptions& debug_options() const { return debug_options_; } + DebugOptions& mutable_debug_options() { return debug_options_; } void set_debug_options(const DebugOptions& debug_options) { debug_options_ = debug_options; } From 4df18bd7e06e47da0723edf3944f343c40b36362 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 26 Jul 2024 11:44:58 -0700 Subject: [PATCH 205/376] Rolling back due to breakages. Reverts 6f81609a96fce4f14d66f0cf96f0eed3e6107935 PiperOrigin-RevId: 656476970 --- xla/service/BUILD | 2 - xla/service/host_offloader.cc | 214 +------------------- xla/service/host_offloader.h | 19 +- xla/service/host_offloader_test.cc | 309 ----------------------------- 4 files changed, 3 insertions(+), 541 deletions(-) diff --git a/xla/service/BUILD b/xla/service/BUILD index cc03a42fb2f019..09676869ddab91 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -6414,7 +6414,6 @@ cc_library( "//xla:util", "//xla/hlo/ir:hlo", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -6443,7 +6442,6 @@ xla_cc_test( "//xla:util", "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", diff --git a/xla/service/host_offloader.cc b/xla/service/host_offloader.cc index d2ab69ee684d7a..95c97e94c704da 100644 --- a/xla/service/host_offloader.cc +++ b/xla/service/host_offloader.cc @@ -16,7 +16,6 @@ limitations under the License. #include "xla/service/host_offloader.h" #include -#include #include #include #include @@ -27,19 +26,15 @@ limitations under the License. #include #include "absl/algorithm/container.h" -#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" -#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal_util.h" #include "xla/service/call_graph.h" @@ -980,224 +975,17 @@ absl::StatusOr HostOffloader::ApplySchedulingFix( return changed; } -namespace { - -bool IsHostAsyncStart(const HloInstruction* instruction) { - return instruction->opcode() == HloOpcode::kAsyncStart && - instruction->async_execution_thread() == HloInstruction::kHostThread; -} - -absl::Status ValidateAsyncComputationStructure(HloComputation* computation) { - for (HloInstruction* instr : computation->instructions()) { - if (instr->opcode() == HloOpcode::kParameter || instr->IsRoot()) { - continue; - } - - return absl::InternalError( - absl::StrCat("Unexpected instruction found in async computation: ", - instr->ToString())); - } - - return absl::OkStatus(); -} - -// Updates memory space for all outputs of the host offloaded computation -// (associated with `call_start`) that are ONLY used on host. NOTE: We also -// remove redundant copies to host, if any. -absl::StatusOr UpdateMemorySpaceForHostOffloadedOutputs( - HloInstruction* call_start, - absl::flat_hash_map>& - host_instr) { - // Keep track of MoveToHost instructions that need to be removed. - std::vector to_replace; - - HloComputation* called_computation = call_start->async_wrapped_computation(); - TF_RETURN_IF_ERROR(ValidateAsyncComputationStructure(called_computation)); - HloInstruction* root = called_computation->root_instruction(); - Shape* root_shape = root->mutable_shape(); - - for (auto& pair : host_instr) { - std::vector& instruction_and_shape_indexes = - pair.second; - - for (InstructionAndShapeIndex& instr_and_shape : - instruction_and_shape_indexes) { - // If instruction is MoveToHost, we will replace usage. - if (instr_and_shape.instruction->IsCustomCall( - host_memory_offload_annotations::kMoveToHostCustomCallTarget)) { - to_replace.emplace_back(instr_and_shape); - continue; - } - - SetMemorySpace(ShapeUtil::GetMutableSubshape( - instr_and_shape.instruction->mutable_shape(), - instr_and_shape.shape_index), - Layout::kHostMemorySpace); - } - - // Update the memory space for the output of the computation call itself. - size_t index = pair.first; - SetMemorySpace(root_shape->mutable_tuple_shapes(index), - Layout::kHostMemorySpace); - } - - // Remove MoveToHost usage. - for (InstructionAndShapeIndex& instr_and_shape : to_replace) { - HloInstruction* pred = instr_and_shape.instruction->mutable_operand(0); - TF_RETURN_IF_ERROR(instr_and_shape.instruction->ReplaceAllUsesWith(pred)); - } - - return !host_instr.empty(); -} - -constexpr int64_t kShapeTupleOutputIndexInAsyncStart = 1; - -// Additional checks (does not run IsValidDuringPureMemoryOffload) to determine -// if the respective tensor can be on host. -bool ExtraCheckForValidUsageOnHostForHostOffloadedOutputs( - const Shape& entry_computation_shape, - InstructionAndShapeIndex& instruction_and_shape_index) { - HloInstruction* instruction = instruction_and_shape_index.instruction; - ShapeIndex& shape_index = instruction_and_shape_index.shape_index; - - // We respect entry computation layout. So for the cases where the - // outputs are not expected on host, we bail. - if (instruction->IsRoot() && instruction->parent()->IsEntryComputation()) { - if (ShapeUtil::GetSubshape(entry_computation_shape, shape_index) - .layout() - .memory_space() != Layout::kHostMemorySpace) { - return false; - } - } - - // For custom calls, we conservatively only accept MoveToHost. - // For MoveToDevice, this could be re-considered, or done as part of a - // generic redundant copies removal. - if (instruction->opcode() == HloOpcode::kCustomCall && - instruction->custom_call_target() != - host_memory_offload_annotations::kMoveToHostCustomCallTarget) { - return false; - } - - // TODO(b/347101407): To also consider host async computations, as we - // extend GetSuccessors to properly treat it. - if (instruction->opcode() == HloOpcode::kAsyncStart || - instruction->opcode() == HloOpcode::kAsyncDone) { - return false; - } - - return true; -} - -} // namespace - -absl::StatusOr HostOffloader::HandleRedundantCopiesBackToHost( - const HloModule* module, HloInstruction* instruction) { - HloAsyncInstruction* call_start = Cast(instruction); - - CHECK_EQ(call_start->users().size(), 1); - HloInstruction* call_done = call_start->users()[0]; - - absl::flat_hash_map> - host_instrs; - const Shape& entry_computation_shape = - module->entry_computation_layout().result_layout().shape(); - - // We collect all usages per output index, stopping at any non host - // instruction. - const Shape& done_shape = call_done->shape(); - for (size_t index = 0; index < done_shape.tuple_shapes_size(); index++) { - ShapeIndex output_shape_index = {static_cast(index)}; - std::queue queue; - queue.push(InstructionAndShapeIndex(call_done, output_shape_index)); - - // async-start packs the (inputs, outputs, context) in a tuple. - ShapeIndex start_shape_index = {kShapeTupleOutputIndexInAsyncStart, - static_cast(index)}; - - // TODO(b/347101407): Start from async-start and trace through the - // computation as well in GetSuccessors instead of having to manually add - // async-done and update the async computation separately. - host_instrs[index].push_back( - InstructionAndShapeIndex(call_start, start_shape_index)); - host_instrs[index].push_back( - InstructionAndShapeIndex(call_done, output_shape_index)); - - bool host_only = true; - // Keep track if the output of the host offloading computation is also an - // output of the entry computation. Temporaries are conservatively kept on - // HBM. - // - // TODO(b/347101407): Better use AliasAnalysis here to trace host compute - // outputs to entry compute outputs instead. NOTE: The current algorithm - // only tracks accepted host offloading operations which operate on the same - // tensor. - bool entry_compute_output = false; - - while (!queue.empty() && host_only) { - InstructionAndShapeIndex instruction_and_shape_index = queue.front(); - queue.pop(); - - TF_ASSIGN_OR_RETURN(std::vector successors, - GetSuccessors(InstructionAndShapeIndex( - instruction_and_shape_index.instruction, - instruction_and_shape_index.shape_index))); - - // Check if any of the successors needs to be on device. - for (InstructionAndShapeIndex& successor : successors) { - if (!IsValidDuringPureMemoryOffload(successor.instruction) || - !ExtraCheckForValidUsageOnHostForHostOffloadedOutputs( - entry_computation_shape, successor)) { - host_only = false; - break; - } - - if (successor.instruction->IsRoot() && - successor.instruction->parent()->IsEntryComputation()) { - entry_compute_output = true; - } - - queue.push(successor); - host_instrs[index].emplace_back(successor); - } - } - - if (!host_only || !entry_compute_output) { - host_instrs.erase(index); - } - } - - // Update memory space for the host_offloading outputs that never get used on - // device. - return UpdateMemorySpaceForHostOffloadedOutputs(call_start, host_instrs); -} - absl::StatusOr HostOffloader::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; - - // First remove redundant copies to and from host (conservatively) starting - // from the outputs of the host offloaded computations. Iterate over all - // instructions and look for XLA host offload annotations. - bool changed_in_loop; - for (HloComputation* computation : - module->MakeNonfusionComputations(execution_threads)) { - for (HloInstruction* instruction : computation->instructions()) { - if (IsHostAsyncStart(instruction)) { - TF_ASSIGN_OR_RETURN(changed_in_loop, HandleRedundantCopiesBackToHost( - module, instruction)); - changed = changed || changed_in_loop; - } - } - } - TF_ASSIGN_OR_RETURN(const bool input_streaming_changed_module, HandleInputStreaming(module->entry_computation())); changed = changed || input_streaming_changed_module; // Since we're modifying the graph as we iterate over it, any time we change // it, we need to re-run the loop. + bool changed_in_loop; do { changed_in_loop = false; for (HloComputation* computation : diff --git a/xla/service/host_offloader.h b/xla/service/host_offloader.h index 994c40fb62bc7c..880cda3d77b621 100644 --- a/xla/service/host_offloader.h +++ b/xla/service/host_offloader.h @@ -67,17 +67,8 @@ bool operator==(const InstructionAndShapeIndex& lhs, // tensors along each path have their memory space set as host memory space. If // a MoveToHost custom call is paired with a DynamicUpdateSlice, the // DynamicUpdateSlice will write into host memory space. Otherwise, a copy from -// device to host will be inserted. -// -// If an output of a host offloaded computation is only used on host, the memory -// space of the usages are updated to reflect it and no copies to and from host -// are performed. Any MoveToHost instructions for outputs used only on host, are -// removed. -// TODO(b/347101407): A better approach could be to remove redundant copies in a -// generalized fashion. Should also be moved out of Host Offloader. -// -// All MoveToHost and MoveToDevice custom calls are removed by the end of this -// pass. +// device to host will be inserted. All MoveToHost and MoveToDevice custom calls +// are removed by the end of this pass. class HostOffloader : public HloModulePass { public: explicit HostOffloader(int64_t host_memory_space_color) @@ -176,12 +167,6 @@ class HostOffloader : public HloModulePass { absl::StatusOr ApplySchedulingFix( HloModule* module, const absl::flat_hash_set& execution_threads); - - // Starting from the outputs of the host offloaded computation, track all - // their usages. For the outputs that are ONLY used on host, remove redundant - // copies to and from host, as well as update the memory space. - absl::StatusOr HandleRedundantCopiesBackToHost( - const HloModule* module, HloInstruction* instruction); }; } // namespace xla diff --git a/xla/service/host_offloader_test.cc b/xla/service/host_offloader_test.cc index 46e8aa003d1666..85cc7742b3ce45 100644 --- a/xla/service/host_offloader_test.cc +++ b/xla/service/host_offloader_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -38,7 +37,6 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/verified_hlo_module.h" #include "xla/util.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" @@ -51,7 +49,6 @@ namespace { class HostOffloaderTest : public HloTestBase { protected: static constexpr int64_t kHostMemorySpaceColor{5}; - static constexpr int64_t kHbmMemorySpaceColor{0}; absl::StatusOr RunHostOffloader(HloModule* module, bool after_layout = false) { @@ -3367,312 +3364,6 @@ ENTRY main { EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get())); } -TEST_F(HostOffloaderTest, BasicAsyncHostOffloadedCall_RemoveRedundantCopies) { - const std::string& hlo_string = R"( -HloModule m, entry_computation_layout={(f32[4096]{0:S(5)})->(f32[4096]{0:S(5)}, f32[4096]{0:S(5)})} - -%async_computation { - %param_0 = f32[4096] parameter(0) - ROOT %offloaded-custom-call = (f32[4096], f32[4096]) custom-call(%param_0), custom_call_target="HostExecute" -}, execution_thread="host" - -ENTRY %main { - %a = f32[4096] parameter(0) - %async-start = ((f32[4096]), (f32[4096], f32[4096]), u32[]) async-start(%a), async_execution_thread="host", calls=%async_computation - %async-done = (f32[4096], f32[4096]) custom-call-done(%async-start) - %gte_0 = f32[4096] get-tuple-element(%async-done), index=0 - %gte_1 = f32[4096] get-tuple-element(%async-done), index=1 - %gte_0_host = f32[4096] custom-call(%gte_0), custom_call_target="MoveToHost" - %gte_1_host = f32[4096] custom-call(%gte_1), custom_call_target="MoveToHost" - ROOT %tuple = (f32[4096], f32[4096]) tuple(%gte_0_host, %gte_1_host) -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_string)); - - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); - EXPECT_TRUE(changed); - - HloInstruction* async_start = FindInstruction(module.get(), "async-start"); - ASSERT_NE(async_start, nullptr); - HloInstruction* async_done = FindInstruction(module.get(), "async-done"); - ASSERT_NE(async_done, nullptr); - - HloInstruction* gte_0 = FindInstruction(module.get(), "gte_0"); - ASSERT_NE(gte_0, nullptr); - TestShapeHasMemorySpace(gte_0->shape(), kHostMemorySpaceColor); - HloInstruction* gte_1 = FindInstruction(module.get(), "gte_1"); - ASSERT_NE(gte_1, nullptr); - TestShapeHasMemorySpace(gte_1->shape(), kHostMemorySpaceColor); - - HloInstruction* gte_0_host = FindInstruction(module.get(), "gte_0_host"); - ASSERT_EQ(gte_0_host, nullptr); - HloInstruction* gte_1_host = FindInstruction(module.get(), "gte_1_host"); - ASSERT_EQ(gte_1_host, nullptr); - - // Check all set of successors. - HloInstruction* tuple = FindInstruction(module.get(), "tuple"); - ASSERT_NE(tuple, nullptr); - std::vector expected = {gte_0, gte_1}; - EXPECT_THAT(tuple->operands(), - ::testing::UnorderedElementsAreArray(expected)); -} - -TEST_F(HostOffloaderTest, - BasicAsyncHostOffloadedCall_NoChangesWhenEntryLayoutExpectsHBM) { - const std::string& hlo_string = R"( -HloModule m, entry_computation_layout={(f32[4096]{0:S(5)})->(f32[4096]{0:S(0)}, f32[4096]{0:S(0)})} - -%async_computation { - %param_0 = f32[4096] parameter(0) - ROOT %offloaded-custom-call = (f32[4096], f32[4096]) custom-call(%param_0), custom_call_target="HostExecute" -}, execution_thread="host" - -ENTRY %main { - %a = f32[4096] parameter(0) - %async-start = ((f32[4096]), (f32[4096], f32[4096]), u32[]) async-start(%a), async_execution_thread="host", calls=%async_computation - %async-done = (f32[4096], f32[4096]) custom-call-done(%async-start) - %gte_0 = f32[4096] get-tuple-element(%async-done), index=0 - %gte_1 = f32[4096] get-tuple-element(%async-done), index=1 - ROOT %tuple = (f32[4096], f32[4096]) tuple(%gte_0, %gte_1) -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_string)); - - TF_ASSERT_OK(RunHostOffloader(module.get())); - - HloInstruction* async_start = FindInstruction(module.get(), "async-start"); - ASSERT_NE(async_start, nullptr); - HloInstruction* async_done = FindInstruction(module.get(), "async-done"); - ASSERT_NE(async_done, nullptr); - - HloInstruction* gte_0 = FindInstruction(module.get(), "gte_0"); - ASSERT_NE(gte_0, nullptr); - TestShapeHasMemorySpace(gte_0->shape(), kHbmMemorySpaceColor); - HloInstruction* gte_1 = FindInstruction(module.get(), "gte_1"); - ASSERT_NE(gte_1, nullptr); - TestShapeHasMemorySpace(gte_1->shape(), kHbmMemorySpaceColor); -} - -TEST_F(HostOffloaderTest, - BasicAsyncHostOffloadedCall_RemoveOnlyRedundantCopies) { - const std::string& hlo_string = R"( -HloModule m, entry_computation_layout={(f32[4096]{0:S(5)})->(f32[4096]{0:S(5)}, f32[4096]{0:S(5)})} - -%add { - %lhs = f32[] parameter(0) - %rhs = f32[] parameter(1) - ROOT %add_res = f32[] add(%lhs, %rhs) -} - -%async_computation { - %param_0 = f32[4096] parameter(0) - ROOT %offloaded-custom-call = (f32[4096], f32[4096]) custom-call(%param_0), custom_call_target="HostExecute" -}, execution_thread="host" - -ENTRY %main { - %a = f32[4096] parameter(0) - %async-start = ((f32[4096]), (f32[4096], f32[4096]), u32[]) async-start(%a), async_execution_thread="host", calls=%async_computation - %async-done = (f32[4096], f32[4096]) custom-call-done(%async-start) - %gte_0 = f32[4096] get-tuple-element(%async-done), index=0 - %gte_1 = f32[4096] get-tuple-element(%async-done), index=1 - %sum = f32[4096] add(%gte_0, %gte_0) - %gte_0_host = f32[4096] custom-call(%gte_0), custom_call_target="MoveToHost" - %gte_1_host = f32[4096] custom-call(%gte_1), custom_call_target="MoveToHost" - ROOT %tuple = (f32[4096]{0:S(5)}, f32[4096]) tuple(%gte_0_host, %gte_1_host) -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_string)); - - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); - EXPECT_TRUE(changed); - - HloInstruction* async_start = FindInstruction(module.get(), "async-start"); - ASSERT_NE(async_start, nullptr); - HloInstruction* async_done = FindInstruction(module.get(), "async-done"); - ASSERT_NE(async_done, nullptr); - - HloInstruction* gte_0 = FindInstruction(module.get(), "gte_0"); - ASSERT_NE(gte_0, nullptr); - TestShapeHasMemorySpace(gte_0->shape(), kHbmMemorySpaceColor); - HloInstruction* gte_1 = FindInstruction(module.get(), "gte_1"); - ASSERT_NE(gte_1, nullptr); - TestShapeHasMemorySpace(gte_1->shape(), kHostMemorySpaceColor); - - // Since gte_0 is used on device (we do not take dead code into account here - // ..) gte_0 will be copied to device and be moved to host. - HloInstruction* gte_0_host = FindInstruction(module.get(), "gte_0_host"); - ASSERT_EQ(gte_0_host, nullptr); // replaced with copy - HloInstruction* copy = FindInstruction(module.get(), "copy"); - ASSERT_NE(copy, nullptr); - EXPECT_EQ(copy->operands()[0], gte_0); - - HloInstruction* gte_1_host = FindInstruction(module.get(), "gte_1_host"); - ASSERT_EQ(gte_1_host, nullptr); -} - -TEST_F(HostOffloaderTest, - AsyncHostOffloadedCall_nonEntryPoint_RemoveRedundantCopies) { - const std::string& hlo_string = R"( -HloModule m, entry_computation_layout={(f32[4096]{0:S(5)})->(f32[4096]{0:S(5)}, f32[4096]{0:S(5)})} - -%async_computation { - %param_0 = f32[4096] parameter(0) - ROOT %offloaded-custom-call = (f32[4096], f32[4096]) custom-call(%param_0), custom_call_target="HostExecute" -}, execution_thread="host" - -%non_async_computation { - %param_0 = f32[4096] parameter(0) - %async-start = ((f32[4096]), (f32[4096], f32[4096]), u32[]) async-start(%param_0), async_execution_thread="host", calls=%async_computation - %async-done = (f32[4096], f32[4096]) custom-call-done(%async-start) - %gte_0 = f32[4096] get-tuple-element(%async-done), index=0 - %gte_1 = f32[4096] get-tuple-element(%async-done), index=1 - %gte_0_host = f32[4096] custom-call(%gte_0), custom_call_target="MoveToHost" - %gte_1_host = f32[4096] custom-call(%gte_1), custom_call_target="MoveToHost" - ROOT %tuple_non_async = (f32[4096]{0:S(5)}, f32[4096]) tuple(%gte_0_host, %gte_1_host) -} - -ENTRY %main { - %a = f32[4096] parameter(0) - %call = (f32[4096], f32[4096]) call(%a), to_apply=%non_async_computation - %call_0 = f32[4096] get-tuple-element(%call), index=0 - %call_1 = f32[4096] get-tuple-element(%call), index=1 - ROOT %tuple = (f32[4096], f32[4096]) tuple(%call_0, %call_1) -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_string)); - - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); - EXPECT_TRUE(changed); - - HloInstruction* async_start = FindInstruction(module.get(), "async-start"); - ASSERT_NE(async_start, nullptr); - HloInstruction* async_done = FindInstruction(module.get(), "async-done"); - ASSERT_NE(async_done, nullptr); - - HloInstruction* gte_0 = FindInstruction(module.get(), "gte_0"); - ASSERT_NE(gte_0, nullptr); - TestShapeHasMemorySpace(gte_0->shape(), kHostMemorySpaceColor); - HloInstruction* gte_1 = FindInstruction(module.get(), "gte_1"); - ASSERT_NE(gte_1, nullptr); - TestShapeHasMemorySpace(gte_1->shape(), kHostMemorySpaceColor); - - HloInstruction* gte_0_host = FindInstruction(module.get(), "gte_0_host"); - ASSERT_EQ(gte_0_host, nullptr); - HloInstruction* gte_1_host = FindInstruction(module.get(), "gte_1_host"); - ASSERT_EQ(gte_1_host, nullptr); - - HloInstruction* tuple_non_async = - FindInstruction(module.get(), "tuple_non_async"); - ASSERT_NE(tuple_non_async, nullptr); - std::vector expected = {gte_0, gte_1}; - EXPECT_THAT(tuple_non_async->operands(), - ::testing::UnorderedElementsAreArray(expected)); - - // Check the main output is on host. - HloInstruction* tuple = FindInstruction(module.get(), "tuple"); - ASSERT_NE(tuple, nullptr); - TestShapeHasMemorySpace(tuple->shape().tuple_shapes(0), - kHostMemorySpaceColor); - TestShapeHasMemorySpace(tuple->shape().tuple_shapes(1), - kHostMemorySpaceColor); -} - -TEST_F(HostOffloaderTest, - AsyncHostOffloadedCall_passedToCall_RemoveRedundantCopies) { - const std::string& hlo_string = R"( -HloModule m, entry_computation_layout={(f32[4096]{0:S(5)})->(f32[4096]{0:S(5)}, f32[4096]{0:S(5)})} - -%async_computation { - %param_0 = f32[4096] parameter(0) - ROOT %offloaded-custom-call = (f32[4096], f32[4096]) custom-call(%param_0), custom_call_target="HostExecute" -}, execution_thread="host" - -%non_async_computation { - %param_0_non_async = f32[4096] parameter(0) - %param_1_non_async = f32[4096] parameter(1) - ROOT %tuple_non_async = (f32[4096], f32[4096]) tuple(%param_0_non_async, %param_1_non_async) -} - -ENTRY %main { - %a = f32[4096] parameter(0) - %async-start = ((f32[4096]), (f32[4096], f32[4096]), u32[]) async-start(%a), async_execution_thread="host", calls=%async_computation - %async-done = (f32[4096], f32[4096]) custom-call-done(%async-start) - %gte_0 = f32[4096] get-tuple-element(%async-done), index=0 - %gte_1 = f32[4096] get-tuple-element(%async-done), index=1 - %call = (f32[4096], f32[4096]) call(%gte_0, %gte_1), to_apply=%non_async_computation - %call_0 = f32[4096] get-tuple-element(%call), index=0 - %call_1 = f32[4096] get-tuple-element(%call), index=1 - %call_0_host = f32[4096] custom-call(%call_0), custom_call_target="MoveToHost" - %call_1_host = f32[4096] custom-call(%call_1), custom_call_target="MoveToHost" - ROOT %tuple = (f32[4096], f32[4096]) tuple(%call_0_host, %call_1_host) -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_string)); - - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); - EXPECT_TRUE(changed); - - HloInstruction* async_start = FindInstruction(module.get(), "async-start"); - ASSERT_NE(async_start, nullptr); - HloInstruction* async_done = FindInstruction(module.get(), "async-done"); - ASSERT_NE(async_done, nullptr); - - HloInstruction* gte_0 = FindInstruction(module.get(), "gte_0"); - ASSERT_NE(gte_0, nullptr); - TestShapeHasMemorySpace(gte_0->shape(), kHostMemorySpaceColor); - HloInstruction* gte_1 = FindInstruction(module.get(), "gte_1"); - ASSERT_NE(gte_1, nullptr); - TestShapeHasMemorySpace(gte_1->shape(), kHostMemorySpaceColor); - - HloInstruction* call_0 = FindInstruction(module.get(), "call_0"); - ASSERT_NE(call_0, nullptr); - HloInstruction* call_1 = FindInstruction(module.get(), "call_1"); - ASSERT_NE(call_1, nullptr); - - HloInstruction* call_0_host = FindInstruction(module.get(), "call_0_host"); - ASSERT_EQ(call_0_host, nullptr); - HloInstruction* call_1_host = FindInstruction(module.get(), "call_1_host"); - ASSERT_EQ(call_1_host, nullptr); - - HloInstruction* param_0_non_async = - FindInstruction(module.get(), "param_0_non_async"); - ASSERT_NE(param_0_non_async, nullptr); - TestShapeHasMemorySpace(param_0_non_async->shape(), kHostMemorySpaceColor); - HloInstruction* param_1_non_async = - FindInstruction(module.get(), "param_1_non_async"); - ASSERT_NE(param_1_non_async, nullptr); - TestShapeHasMemorySpace(param_1_non_async->shape(), kHostMemorySpaceColor); - - HloInstruction* tuple_non_async = - FindInstruction(module.get(), "tuple_non_async"); - ASSERT_NE(tuple_non_async, nullptr); - std::vector expected_operands = {param_0_non_async, - param_1_non_async}; - EXPECT_THAT(tuple_non_async->operands(), - ::testing::UnorderedElementsAreArray(expected_operands)); - - HloInstruction* tuple = FindInstruction(module.get(), "tuple"); - ASSERT_NE(tuple, nullptr); - TestShapeHasMemorySpace(tuple->shape().tuple_shapes(0), - kHostMemorySpaceColor); - TestShapeHasMemorySpace(tuple->shape().tuple_shapes(1), - kHostMemorySpaceColor); - - std::vector expected = {call_0, call_1}; - EXPECT_THAT(tuple->operands(), - ::testing::UnorderedElementsAreArray(expected)); -} - } // namespace } // namespace xla From bb6a84c8deec487006a46adfd47ab9715652050d Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Fri, 26 Jul 2024 13:10:17 -0700 Subject: [PATCH 206/376] [IFRT] Add memory support to ShardingTest and generalize it into DeviceTest `ShardingTest` in `sharding_test_util` is now used at several IFRT tests when the tests uses IFRT devices without using real clients. This change extends it to support a "host" memory kind to enable memory-related tests, and generalize it as `DeviceTest` to match its current use cases beyond sharding tests. PiperOrigin-RevId: 656504683 --- xla/python/ifrt/BUILD | 19 +-- xla/python/ifrt/array_spec_test.cc | 6 +- .../ifrt/custom_call_program_serdes_test.cc | 6 +- xla/python/ifrt/device_test.cc | 6 +- xla/python/ifrt/device_test_util.cc | 153 ++++++++++++++++++ ...harding_test_util.h => device_test_util.h} | 16 +- xla/python/ifrt/remap_plan_test.cc | 6 +- xla/python/ifrt/sharding_serdes_test.cc | 6 +- xla/python/ifrt/sharding_test.cc | 22 +-- xla/python/ifrt/sharding_test_util.cc | 103 ------------ xla/python/pjrt_ifrt/BUILD | 4 +- .../pjrt_ifrt/xla_sharding_serdes_test.cc | 6 +- xla/python/pjrt_ifrt/xla_sharding_test.cc | 7 +- 13 files changed, 206 insertions(+), 154 deletions(-) create mode 100644 xla/python/ifrt/device_test_util.cc rename xla/python/ifrt/{sharding_test_util.h => device_test_util.h} (79%) delete mode 100644 xla/python/ifrt/sharding_test_util.cc diff --git a/xla/python/ifrt/BUILD b/xla/python/ifrt/BUILD index 20e61c21a55e7b..d140b795f9a944 100644 --- a/xla/python/ifrt/BUILD +++ b/xla/python/ifrt/BUILD @@ -230,8 +230,8 @@ xla_cc_test( size = "small", srcs = ["sharding_test.cc"], deps = [ + ":device_test_util", ":ifrt", - ":sharding_test_util", "//xla/python/ifrt/ir:sharding_param", "@com_google_googletest//:gtest_main", "@llvm-project//llvm:Support", @@ -260,16 +260,17 @@ cc_library( ) cc_library( - name = "sharding_test_util", + name = "device_test_util", testonly = True, - srcs = ["sharding_test_util.cc"], - hdrs = ["sharding_test_util.h"], + srcs = ["device_test_util.cc"], + hdrs = ["device_test_util.h"], deps = [ ":ifrt", ":mock", ":test_util", "//xla:util", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -465,11 +466,11 @@ xla_cc_test( name = "sharding_serdes_test", srcs = ["sharding_serdes_test.cc"], deps = [ + ":device_test_util", ":ifrt", ":serdes", ":serdes_proto_cc", ":sharding_serdes", - ":sharding_test_util", "@com_google_absl//absl/functional:bind_front", "@com_google_googletest//:gtest_main", "@tsl//tsl/platform:statusor", @@ -492,9 +493,9 @@ xla_cc_test( srcs = ["array_spec_test.cc"], deps = [ ":array_spec_proto_cc", + ":device_test_util", ":ifrt", ":sharding_serdes", - ":sharding_test_util", "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest_main", "@llvm-project//llvm:Support", @@ -513,8 +514,8 @@ xla_cc_test( srcs = ["device_test.cc"], deps = [ ":device_proto_cc", + ":device_test_util", ":ifrt", - ":sharding_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@com_google_googletest//:gtest_main", @@ -553,9 +554,9 @@ xla_cc_test( size = "small", srcs = ["remap_plan_test.cc"], deps = [ + ":device_test_util", ":ifrt", ":sharding_serdes", - ":sharding_test_util", "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", @@ -695,10 +696,10 @@ xla_cc_test( deps = [ ":custom_call_program", ":custom_call_program_serdes", + ":device_test_util", ":ifrt", ":program_serdes", ":serdes", - ":sharding_test_util", "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", diff --git a/xla/python/ifrt/array_spec_test.cc b/xla/python/ifrt/array_spec_test.cc index 0dfa4d327de562..37aec4b35648b4 100644 --- a/xla/python/ifrt/array_spec_test.cc +++ b/xla/python/ifrt/array_spec_test.cc @@ -20,18 +20,18 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "xla/python/ifrt/array_spec.pb.h" #include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_test_util.h" #include "xla/python/ifrt/dtype.h" #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" -#include "xla/python/ifrt/sharding_test_util.h" #include "tsl/platform/statusor.h" namespace xla { namespace ifrt { namespace { -class ArraySpecTest : public test_util::ShardingTest {}; +class ArraySpecTest : public test_util::DeviceTest {}; TEST_P(ArraySpecTest, ToFromProto) { auto device_list = GetDevices({0, 1}); @@ -64,7 +64,7 @@ TEST_P(ArraySpecTest, ToFromProto) { } INSTANTIATE_TEST_SUITE_P(NumDevices, ArraySpecTest, - testing::Values(test_util::ShardingTestParam{ + testing::Values(test_util::DeviceTestParam{ /*num_devices=*/2, /*num_addressable_devices=*/2})); diff --git a/xla/python/ifrt/custom_call_program_serdes_test.cc b/xla/python/ifrt/custom_call_program_serdes_test.cc index e942c4bce5d4ef..31a259378695cc 100644 --- a/xla/python/ifrt/custom_call_program_serdes_test.cc +++ b/xla/python/ifrt/custom_call_program_serdes_test.cc @@ -25,13 +25,13 @@ limitations under the License. #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/custom_call_program.h" #include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_test_util.h" #include "xla/python/ifrt/dtype.h" #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/program_serdes.h" #include "xla/python/ifrt/serdes.h" #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" -#include "xla/python/ifrt/sharding_test_util.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" @@ -45,7 +45,7 @@ using ::testing::MatchesRegex; using ::testing::SizeIs; using ::tsl::testing::StatusIs; -class CustomCallProgramSerDesTest : public test_util::ShardingTest {}; +class CustomCallProgramSerDesTest : public test_util::DeviceTest {}; TEST_P(CustomCallProgramSerDesTest, RoundTrip) { Shape shape0({10, 20}); @@ -117,7 +117,7 @@ TEST_P(CustomCallProgramSerDesTest, RoundTrip) { } INSTANTIATE_TEST_SUITE_P(NumDevices, CustomCallProgramSerDesTest, - testing::Values(test_util::ShardingTestParam{ + testing::Values(test_util::DeviceTestParam{ /*num_devices=*/2, /*num_addressable_devices=*/2})); diff --git a/xla/python/ifrt/device_test.cc b/xla/python/ifrt/device_test.cc index cef9d05d8347d8..85fae8d7e4f2d0 100644 --- a/xla/python/ifrt/device_test.cc +++ b/xla/python/ifrt/device_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/synchronization/blocking_counter.h" #include "xla/python/ifrt/device.pb.h" -#include "xla/python/ifrt/sharding_test_util.h" +#include "xla/python/ifrt/device_test_util.h" #include "tsl/platform/cpu_info.h" #include "tsl/platform/env.h" #include "tsl/platform/statusor.h" @@ -34,7 +34,7 @@ namespace xla { namespace ifrt { namespace { -class DeviceListTest : public test_util::ShardingTest {}; +class DeviceListTest : public test_util::DeviceTest {}; TEST_P(DeviceListTest, ToFromProto) { auto device_list = GetDevices({0, 1}); @@ -89,7 +89,7 @@ TEST_P(DeviceListTest, EqualityTest) { } INSTANTIATE_TEST_SUITE_P(NumDevices, DeviceListTest, - testing::Values(test_util::ShardingTestParam{ + testing::Values(test_util::DeviceTestParam{ /*num_devices=*/2, /*num_addressable_devices=*/2})); diff --git a/xla/python/ifrt/device_test_util.cc b/xla/python/ifrt/device_test_util.cc new file mode 100644 index 00000000000000..a40e3a7423dbd1 --- /dev/null +++ b/xla/python/ifrt/device_test_util.cc @@ -0,0 +1,153 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/python/ifrt/device_test_util.h" + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/mock.h" +#include "xla/python/ifrt/test_util.h" +#include "xla/util.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace ifrt { +namespace test_util { + +namespace { + +using ::testing::Return; +using ::testing::ReturnPointee; +using ::testing::ReturnRef; + +// Internal state of a client for device tests. +struct DeviceTestClientState { + Client* client; + + // Shared MemoryKind objects. + MemoryKind host_memory_kind; + + // Mapping from a memory ID to the mock memory object. + absl::flat_hash_map> memory_map; + std::vector memories; + std::vector> memory_devices; + + // Mapping from a device ID to the mock device object. + absl::flat_hash_map> device_map; + // Raw pointers to mock devices. + std::vector devices; + std::vector addressable_devices; + std::vector> device_memories; +}; + +// Creates a mock client for device tests. The client will have a specified +// number of fake addressable and non-addressable devices. Client implements +// `devices()` and `LookupDevice()`. Device implements `id()`, with an +// arbitrary deterministic device ids assigned. Each device has "host" memory +// (which is also its default memory), and each memory has a single device. +std::shared_ptr MakeDeviceTestClient(int num_devices, + int num_addressable_devices) { + CHECK_GE(num_devices, num_addressable_devices); + auto state = std::make_shared(); + + state->host_memory_kind = MemoryKind("host"); + + state->memory_map.reserve(num_devices); + state->memories.reserve(num_devices); + state->memory_devices.resize(num_devices); + + state->device_map.reserve(num_devices); + state->devices.reserve(num_devices); + state->addressable_devices.reserve(num_addressable_devices); + state->device_memories.resize(num_devices); + + for (int i = 0; i < num_devices; ++i) { + const bool addressable = i < num_addressable_devices; + auto memory = std::make_unique(); + ON_CALL(*memory, Id).WillByDefault(Return(MemoryId(i + 10))); + ON_CALL(*memory, Kind).WillByDefault(ReturnRef(state->host_memory_kind)); + // memory_devices will be filled in at the end of the loop. + ON_CALL(*memory, Devices) + .WillByDefault(ReturnPointee(&state->memory_devices[i])); + state->memories.push_back(memory.get()); + state->memory_map.insert({MemoryId(i + 10), std::move(memory)}); + + auto device = std::make_unique(); + // client will be filled in at the end of the loop. + ON_CALL(*device, client).WillByDefault(ReturnPointee(&state->client)); + ON_CALL(*device, Id).WillByDefault(Return(DeviceId(i + 10))); + ON_CALL(*device, IsAddressable).WillByDefault(Return(addressable)); + ON_CALL(*device, DebugString) + .WillByDefault(Return(absl::StrCat("device(", i + 10, ")"))); + ON_CALL(*device, DefaultMemory).WillByDefault(Return(state->memories[i])); + // device_memories will be filled in at the end of the loop. + ON_CALL(*device, Memories) + .WillByDefault(ReturnPointee(&state->device_memories[i])); + state->devices.push_back(device.get()); + if (addressable) { + state->addressable_devices.push_back(device.get()); + } + state->device_map.insert({DeviceId(i + 10), std::move(device)}); + + state->device_memories[i] = absl::MakeConstSpan(&state->memories[i], 1); + state->memory_devices[i] = absl::MakeConstSpan(&state->devices[i], 1); + } + + auto client = std::make_shared(); + state->client = client.get(); + ON_CALL(*client, devices) + .WillByDefault( + [state]() -> absl::Span { return state->devices; }); + ON_CALL(*client, addressable_devices) + .WillByDefault([state]() -> absl::Span { + return state->addressable_devices; + }); + ON_CALL(*client, LookupDevice) + .WillByDefault([state](DeviceId device_id) -> absl::StatusOr { + auto it = state->device_map.find(device_id); + if (it == state->device_map.end()) { + return InvalidArgument("Unexpected device id: %d", device_id.value()); + } + return it->second.get(); + }); + ON_CALL(*client, GetTopologyForDevices).WillByDefault([](const DeviceList&) { + return nullptr; + }); + return client; +} + +} // namespace + +void DeviceTest::SetUp() { + const auto [num_devices, num_addressable_devices] = GetParam(); + client_ = MakeDeviceTestClient(num_devices, num_addressable_devices); +} + +DeviceList DeviceTest::GetDevices(absl::Span device_indices) { + return test_util::GetDevices(client_.get(), device_indices).value(); +} + +} // namespace test_util +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt/sharding_test_util.h b/xla/python/ifrt/device_test_util.h similarity index 79% rename from xla/python/ifrt/sharding_test_util.h rename to xla/python/ifrt/device_test_util.h index b3118118209270..86f2611eb96a80 100644 --- a/xla/python/ifrt/sharding_test_util.h +++ b/xla/python/ifrt/device_test_util.h @@ -13,28 +13,30 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_PYTHON_IFRT_SHARDING_TEST_UTIL_H_ -#define XLA_PYTHON_IFRT_SHARDING_TEST_UTIL_H_ +#ifndef XLA_PYTHON_IFRT_DEVICE_TEST_UTIL_H_ +#define XLA_PYTHON_IFRT_DEVICE_TEST_UTIL_H_ #include +#include "absl/types/span.h" #include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/device.h" #include "tsl/platform/test.h" namespace xla { namespace ifrt { namespace test_util { -// Parameters for ShardingTest. +// Parameters for DeviceTest. // Requests `num_devices` total devices, where `num_addressable_devices` of them // are addressable, and the rest of devices are non-addressable. -struct ShardingTestParam { +struct DeviceTestParam { int num_devices; int num_addressable_devices; }; -// Test fixture for sharding tests. -class ShardingTest : public testing::TestWithParam { +// Test fixture for device tests. +class DeviceTest : public testing::TestWithParam { public: void SetUp() override; Client* client() { return client_.get(); } @@ -52,4 +54,4 @@ class ShardingTest : public testing::TestWithParam { } // namespace ifrt } // namespace xla -#endif // XLA_PYTHON_IFRT_SHARDING_TEST_UTIL_H_ +#endif // XLA_PYTHON_IFRT_DEVICE_TEST_UTIL_H_ diff --git a/xla/python/ifrt/remap_plan_test.cc b/xla/python/ifrt/remap_plan_test.cc index 69865ad659419e..b7e53ffe5983f7 100644 --- a/xla/python/ifrt/remap_plan_test.cc +++ b/xla/python/ifrt/remap_plan_test.cc @@ -25,11 +25,11 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "xla/python/ifrt/array_spec.h" #include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_test_util.h" #include "xla/python/ifrt/dtype.h" #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" -#include "xla/python/ifrt/sharding_test_util.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" @@ -43,7 +43,7 @@ using ::testing::HasSubstr; using ::testing::SizeIs; using ::tsl::testing::StatusIs; -class RemapPlanTest : public test_util::ShardingTest {}; +class RemapPlanTest : public test_util::DeviceTest {}; TEST_P(RemapPlanTest, ToFromProto) { RemapPlan plan; @@ -408,7 +408,7 @@ TEST_P(RemapPlanTest, InvalidOutputDevices) { } INSTANTIATE_TEST_SUITE_P(NumDevices, RemapPlanTest, - testing::Values(test_util::ShardingTestParam{ + testing::Values(test_util::DeviceTestParam{ /*num_devices=*/4, /*num_addressable_devices=*/4})); diff --git a/xla/python/ifrt/sharding_serdes_test.cc b/xla/python/ifrt/sharding_serdes_test.cc index d1d075aa4f50b2..be6508a35c67c8 100644 --- a/xla/python/ifrt/sharding_serdes_test.cc +++ b/xla/python/ifrt/sharding_serdes_test.cc @@ -20,12 +20,12 @@ limitations under the License. #include #include "absl/functional/bind_front.h" #include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/device_test_util.h" #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/serdes.h" #include "xla/python/ifrt/serdes.pb.h" #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" -#include "xla/python/ifrt/sharding_test_util.h" #include "tsl/platform/statusor.h" namespace xla { @@ -34,7 +34,7 @@ namespace { using ::testing::ElementsAreArray; -class ShardingSerDesTest : public test_util::ShardingTest {}; +class ShardingSerDesTest : public test_util::DeviceTest {}; TEST_P(ShardingSerDesTest, SingleDeviceShardingRoundTrip) { auto sharding = SingleDeviceSharding::Create( @@ -138,7 +138,7 @@ TEST_P(ShardingSerDesTest, ConcreteEvenShardingRoundTrip) { } INSTANTIATE_TEST_SUITE_P(NumDevices, ShardingSerDesTest, - testing::Values(test_util::ShardingTestParam{ + testing::Values(test_util::DeviceTestParam{ /*num_devices=*/2, /*num_addressable_devices=*/2})); diff --git a/xla/python/ifrt/sharding_test.cc b/xla/python/ifrt/sharding_test.cc index 43004e36ad18fe..a5c4bf7d15df2c 100644 --- a/xla/python/ifrt/sharding_test.cc +++ b/xla/python/ifrt/sharding_test.cc @@ -24,12 +24,12 @@ limitations under the License. #include #include "llvm/Support/Casting.h" #include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_test_util.h" #include "xla/python/ifrt/index.h" #include "xla/python/ifrt/index_domain.h" #include "xla/python/ifrt/ir/sharding_param.h" #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/shape.h" -#include "xla/python/ifrt/sharding_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" @@ -45,11 +45,11 @@ using ::testing::SizeIs; using ::tsl::testing::IsOkAndHolds; using ::tsl::testing::StatusIs; -class SingleDeviceShardingTest : public test_util::ShardingTest {}; -class OpaqueShardingTest : public test_util::ShardingTest {}; -class ConcreteShardingTest : public test_util::ShardingTest {}; -class ConcreteEvenShardingTest : public test_util::ShardingTest {}; -class ShardingParamShardingTest : public test_util::ShardingTest {}; +class SingleDeviceShardingTest : public test_util::DeviceTest {}; +class OpaqueShardingTest : public test_util::DeviceTest {}; +class ConcreteShardingTest : public test_util::DeviceTest {}; +class ConcreteEvenShardingTest : public test_util::DeviceTest {}; +class ShardingParamShardingTest : public test_util::DeviceTest {}; TEST_P(SingleDeviceShardingTest, IsFullyReplicated) { auto device_list = GetDevices({0}); @@ -817,23 +817,23 @@ TEST_P(ShardingParamShardingTest, IndexDomainWithReplication) { } INSTANTIATE_TEST_SUITE_P(NumDevices, SingleDeviceShardingTest, - testing::Values(test_util::ShardingTestParam{ + testing::Values(test_util::DeviceTestParam{ /*num_devices=*/6, /*num_addressable_devices=*/6})); INSTANTIATE_TEST_SUITE_P(NumDevices, OpaqueShardingTest, - testing::Values(test_util::ShardingTestParam{ + testing::Values(test_util::DeviceTestParam{ /*num_devices=*/6, /*num_addressable_devices=*/6})); INSTANTIATE_TEST_SUITE_P(NumDevices, ConcreteShardingTest, - testing::Values(test_util::ShardingTestParam{ + testing::Values(test_util::DeviceTestParam{ /*num_devices=*/6, /*num_addressable_devices=*/6})); INSTANTIATE_TEST_SUITE_P(NumDevices, ConcreteEvenShardingTest, - testing::Values(test_util::ShardingTestParam{ + testing::Values(test_util::DeviceTestParam{ /*num_devices=*/6, /*num_addressable_devices=*/6})); INSTANTIATE_TEST_SUITE_P(NumDevices, ShardingParamShardingTest, - testing::Values(test_util::ShardingTestParam{ + testing::Values(test_util::DeviceTestParam{ /*num_devices=*/6, /*num_addressable_devices=*/4})); diff --git a/xla/python/ifrt/sharding_test_util.cc b/xla/python/ifrt/sharding_test_util.cc deleted file mode 100644 index db8199a1ed3476..00000000000000 --- a/xla/python/ifrt/sharding_test_util.cc +++ /dev/null @@ -1,103 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/python/ifrt/sharding_test_util.h" - -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/types/span.h" -#include "xla/python/ifrt/device.h" -#include "xla/python/ifrt/mock.h" -#include "xla/python/ifrt/test_util.h" -#include "xla/util.h" -#include "tsl/platform/test.h" - -namespace xla { -namespace ifrt { -namespace test_util { - -namespace { - -using ::testing::Return; - -// Internal state of a client for sharding tests. -struct ShardingTestClientState { - // Mapping from a device ID to the mock device object. - absl::flat_hash_map> device_map; - // Raw pointers to mock devices. - std::vector devices; -}; - -// Creates a mock client for sharding tests. The client will have a specified -// number of fake addressable and non-addressable devices. Client implements -// `devices()` and `LookupDevice()`. Device implements `id()`, with an arbitrary -// deterministic device ids assigned. -std::shared_ptr MakeShardingTestClient( - int num_devices, int num_addressable_devices) { - auto state = std::make_shared(); - state->device_map.reserve(num_devices); - state->devices.reserve(num_devices); - - for (int i = 0; i < num_addressable_devices; ++i) { - auto device = std::make_unique(); - ON_CALL(*device, Id).WillByDefault(Return(DeviceId(i + 10))); - ON_CALL(*device, IsAddressable).WillByDefault(Return(true)); - ON_CALL(*device, DebugString) - .WillByDefault(Return(absl::StrCat("device(", i + 10, ")"))); - state->devices.push_back(device.get()); - state->device_map.insert({DeviceId(i + 10), std::move(device)}); - } - for (int i = num_addressable_devices; i < num_devices; ++i) { - auto device = std::make_unique(); - ON_CALL(*device, Id).WillByDefault(Return(DeviceId(i + 10))); - ON_CALL(*device, IsAddressable).WillByDefault(Return(false)); - state->devices.push_back(device.get()); - state->device_map.insert({DeviceId(i + 10), std::move(device)}); - } - - auto client = std::make_shared(); - ON_CALL(*client, devices) - .WillByDefault( - [state]() -> absl::Span { return state->devices; }); - ON_CALL(*client, LookupDevice) - .WillByDefault([state](DeviceId device_id) -> absl::StatusOr { - auto it = state->device_map.find(device_id); - if (it == state->device_map.end()) { - return InvalidArgument("Unexpected device id: %d", device_id.value()); - } - return it->second.get(); - }); - return client; -} - -} // namespace - -void ShardingTest::SetUp() { - const auto [num_devices, num_addressable_devices] = GetParam(); - client_ = MakeShardingTestClient(num_devices, num_addressable_devices); -} - -DeviceList ShardingTest::GetDevices(absl::Span device_indices) { - return test_util::GetDevices(client_.get(), device_indices).value(); -} - -} // namespace test_util -} // namespace ifrt -} // namespace xla diff --git a/xla/python/pjrt_ifrt/BUILD b/xla/python/pjrt_ifrt/BUILD index ce9502d19f0dd5..62c118aa67dee2 100644 --- a/xla/python/pjrt_ifrt/BUILD +++ b/xla/python/pjrt_ifrt/BUILD @@ -117,9 +117,9 @@ xla_cc_test( ":xla_sharding_serdes", "//xla/hlo/ir:hlo", "//xla/python/ifrt", + "//xla/python/ifrt:device_test_util", "//xla/python/ifrt:serdes", "//xla/python/ifrt:sharding_serdes", - "//xla/python/ifrt:sharding_test_util", "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", @@ -174,7 +174,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/hlo/ir:tile_assignment", "//xla/python/ifrt", - "//xla/python/ifrt:sharding_test_util", + "//xla/python/ifrt:device_test_util", "//xla/python/ifrt:tuple_impl_test_lib", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", diff --git a/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc b/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc index c3d6f011975eca..b5affee221f6b1 100644 --- a/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc +++ b/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc @@ -22,10 +22,10 @@ limitations under the License. #include "absl/types/span.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/device_test_util.h" #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/serdes.h" #include "xla/python/ifrt/sharding.h" -#include "xla/python/ifrt/sharding_test_util.h" #include "xla/python/pjrt_ifrt/xla_sharding.h" #include "tsl/platform/statusor.h" @@ -35,7 +35,7 @@ namespace { using ::testing::ElementsAreArray; -class XlaShardingSerDesTest : public test_util::ShardingTest {}; +class XlaShardingSerDesTest : public test_util::DeviceTest {}; TEST_P(XlaShardingSerDesTest, HloShardingRoundTrip) { auto device_list = GetDevices({0, 1}); @@ -56,7 +56,7 @@ TEST_P(XlaShardingSerDesTest, HloShardingRoundTrip) { } INSTANTIATE_TEST_SUITE_P(NumDevices, XlaShardingSerDesTest, - testing::Values(test_util::ShardingTestParam{ + testing::Values(test_util::DeviceTestParam{ .num_devices = 2, .num_addressable_devices = 2})); } // namespace diff --git a/xla/python/pjrt_ifrt/xla_sharding_test.cc b/xla/python/pjrt_ifrt/xla_sharding_test.cc index 881eb75b78ec1e..177e7dead03944 100644 --- a/xla/python/pjrt_ifrt/xla_sharding_test.cc +++ b/xla/python/pjrt_ifrt/xla_sharding_test.cc @@ -15,7 +15,6 @@ limitations under the License. #include "xla/python/pjrt_ifrt/xla_sharding.h" -#include #include #include #include @@ -25,12 +24,12 @@ limitations under the License. #include "absl/types/span.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/ir/tile_assignment.h" +#include "xla/python/ifrt/device_test_util.h" #include "xla/python/ifrt/index.h" #include "xla/python/ifrt/index_domain.h" #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" -#include "xla/python/ifrt/sharding_test_util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/status_matchers.h" @@ -47,7 +46,7 @@ using ::testing::SizeIs; using ::tsl::testing::IsOkAndHolds; using ::tsl::testing::StatusIs; -class HloShardingTest : public test_util::ShardingTest {}; +class HloShardingTest : public test_util::DeviceTest {}; TEST_P(HloShardingTest, IsFullyReplicated) { auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); @@ -482,7 +481,7 @@ TEST_P(HloShardingTest, DisassembleFailsWithDynamicShape) { } INSTANTIATE_TEST_SUITE_P(NumDevices, HloShardingTest, - testing::Values(test_util::ShardingTestParam{ + testing::Values(test_util::DeviceTestParam{ .num_devices = 6, .num_addressable_devices = 4})); } // namespace From d66fab2f6f820dc91844a92c273af97e0fe14981 Mon Sep 17 00:00:00 2001 From: Harsha H S Date: Fri, 26 Jul 2024 13:57:05 -0700 Subject: [PATCH 207/376] PR #15369: [ROCm] Fix build break due to 7cad716 Imported from GitHub PR https://github.com/openxla/xla/pull/15369 Copybara import of the project: -- a1f189a0ff7ef5f88ed45d1e8738b22a4d0029a5 by Harsha HS : [ROCm] Fix build break due to 7cad716 Merging this change closes #15369 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/15369 from ROCm:ci_fix_build_break_20240726 a1f189a0ff7ef5f88ed45d1e8738b22a4d0029a5 PiperOrigin-RevId: 656519619 --- xla/service/gpu/BUILD | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 5eac3b82ebf1d3..1b97905c38eaa8 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -2907,10 +2907,10 @@ xla_cc_test( cc_library( name = "custom_kernel_fusion_autotuner", - srcs = if_cuda_is_configured(["custom_kernel_fusion_autotuner.cc"]), - hdrs = if_cuda_is_configured(["custom_kernel_fusion_autotuner.h"]), + srcs = if_gpu_is_configured(["custom_kernel_fusion_autotuner.cc"]), + hdrs = if_gpu_is_configured(["custom_kernel_fusion_autotuner.h"]), local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - deps = if_cuda_is_configured([ + deps = if_gpu_is_configured([ ":autotuner_compile_util", ":autotuner_util", ":backend_configs_cc", From ba30db721cf865cb7d3b537a8fc4bd80ea457b8c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 26 Jul 2024 14:06:27 -0700 Subject: [PATCH 208/376] [xla:gpu] Add experimental flag to enable command buffers while profiling is active. PiperOrigin-RevId: 656522860 --- xla/debug_options_flags.cc | 10 ++ xla/service/gpu/ir_emitter_unnested.cc | 4 +- xla/service/gpu/runtime/BUILD | 5 +- .../gpu/runtime/command_buffer_thunk.cc | 13 +- .../gpu/runtime/command_buffer_thunk.h | 7 +- .../gpu/runtime/command_buffer_thunk_test.cc | 111 ++++++++++++++++++ xla/xla.proto | 7 +- 7 files changed, 149 insertions(+), 8 deletions(-) diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc index 9b71d26ac6217a..750fb560cfa7fc 100644 --- a/xla/debug_options_flags.cc +++ b/xla/debug_options_flags.cc @@ -281,6 +281,8 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_autotune_gemm_rtol(0.1f); + opts.set_xla_enable_command_buffers_during_profiling(false); + return opts; } @@ -1830,6 +1832,14 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "version checks must be done by the user (e.g. if you want to use " "separate caches for different versions of XLA, please use different " "directories). Default: no cache.")); + flag_list->push_back(tsl::Flag( + "xla_enable_command_buffers_during_profiling", + bool_setter_for( + &DebugOptions::set_xla_enable_command_buffers_during_profiling), + debug_options->xla_enable_command_buffers_during_profiling(), + "Experimental: Enable command buffers while a profiling active. " + "By default, enabling profiling switches from command buffers to " + "op-by-op mode.")); } // NOLINT(readability/fn_size) // Allocates flag_values and flag_objects; this function must not be called more diff --git a/xla/service/gpu/ir_emitter_unnested.cc b/xla/service/gpu/ir_emitter_unnested.cc index 9082eaca99b815..a964f6bbd9d72c 100644 --- a/xla/service/gpu/ir_emitter_unnested.cc +++ b/xla/service/gpu/ir_emitter_unnested.cc @@ -580,7 +580,9 @@ absl::Status IrEmitterUnnested::EmitCommandBufferThunk( AddThunkToThunkSequence(std::make_unique( std::move(cmd_sequence), Thunk::ThunkInfo::WithProfileAnnotation(instr), - std::move(thunk_sequence))); + std::move(thunk_sequence), + ir_emitter_context_->debug_options() + .xla_enable_command_buffers_during_profiling())); return absl::OkStatus(); } diff --git a/xla/service/gpu/runtime/BUILD b/xla/service/gpu/runtime/BUILD index 5899d8e2999450..3b2345c858b91e 100644 --- a/xla/service/gpu/runtime/BUILD +++ b/xla/service/gpu/runtime/BUILD @@ -429,10 +429,10 @@ cc_library( deps = [ ":annotation", ":command_buffer_cmd", + ":sequential_thunk", # build_cleaner: keep ":thunk", "//xla/service:buffer_assignment", # build_cleaner: keep "//xla/service/gpu:buffer_allocations", # build_cleaner: keep - "//xla/service/gpu/runtime:sequential_thunk", # build_cleaner: keep "//xla/stream_executor", "//xla/stream_executor:command_buffer", "@com_google_absl//absl/base:core_headers", @@ -466,6 +466,8 @@ xla_test( deps = [ ":command_buffer_cmd", ":command_buffer_thunk", + ":memset_thunk", + ":sequential_thunk", ":thunk", "//xla:shape_util", "//xla:types", @@ -491,6 +493,7 @@ xla_test( "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", + "@tsl//tsl/profiler/lib:profiler_lock", ] + if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", ]), diff --git a/xla/service/gpu/runtime/command_buffer_thunk.cc b/xla/service/gpu/runtime/command_buffer_thunk.cc index a37913a57e352a..42d14071fcf4e1 100644 --- a/xla/service/gpu/runtime/command_buffer_thunk.cc +++ b/xla/service/gpu/runtime/command_buffer_thunk.cc @@ -29,6 +29,7 @@ limitations under the License. #include "xla/service/gpu/buffer_allocations.h" #include "xla/service/gpu/runtime/annotation.h" #include "xla/service/gpu/runtime/command_buffer_cmd.h" +#include "xla/service/gpu/runtime/sequential_thunk.h" #include "xla/service/gpu/runtime/thunk.h" #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/device_memory.h" @@ -54,12 +55,15 @@ CommandBufferThunk::ExecutorCommandBuffer::ExecutorCommandBuffer( std::unique_ptr command_buffer) : command_buffer(std::move(command_buffer)) {} -CommandBufferThunk::CommandBufferThunk(CommandBufferCmdSequence commands, - ThunkInfo thunk_info, - std::unique_ptr thunks) +CommandBufferThunk::CommandBufferThunk( + CommandBufferCmdSequence commands, ThunkInfo thunk_info, + std::unique_ptr thunks, + bool enable_command_buffers_during_profiling) : Thunk(Thunk::kCommandBuffer, std::move(thunk_info)), commands_(std::move(commands)), thunks_(std::move(thunks)), + enable_command_buffers_during_profiling_( + enable_command_buffers_during_profiling), state_(std::make_shared()) { // When we create a new command buffer thunk (which happens when we // instantiate a new Gpu executable) we evict command buffers for all @@ -199,7 +203,8 @@ absl::Status CommandBufferThunk::ExecuteOnStream(const ExecuteParams& params) { // TODO(b/290773547): Profiler (CUPTI) + CUDA graphs lead to memory // corruption. As a work around disable command buffers (CUDA graphs) and run // everything in op-by-op mode. - if (tsl::profiler::ProfilerLock::HasActiveSession() && thunks_) { + if (tsl::profiler::ProfilerLock::HasActiveSession() && thunks_ && + !enable_command_buffers_during_profiling_) { VLOG(1) << "Execute command buffer thunk as a regular thunk sequence " "because we detected active profiling session"; TF_RETURN_IF_ERROR(thunks_->ExecuteOnStream(params)); diff --git a/xla/service/gpu/runtime/command_buffer_thunk.h b/xla/service/gpu/runtime/command_buffer_thunk.h index a3cb4672c951b3..a0442f3d711023 100644 --- a/xla/service/gpu/runtime/command_buffer_thunk.h +++ b/xla/service/gpu/runtime/command_buffer_thunk.h @@ -38,7 +38,8 @@ namespace xla::gpu { class CommandBufferThunk : public Thunk { public: CommandBufferThunk(CommandBufferCmdSequence commands, ThunkInfo thunk_info, - std::unique_ptr thunks = nullptr); + std::unique_ptr thunks = nullptr, + bool enable_command_buffers_during_profiling = false); const std::unique_ptr& thunks() const { return thunks_; } @@ -128,6 +129,10 @@ class CommandBufferThunk : public Thunk { // bugs that lead to memory corruption when CUPTI traces CUDA graph execution. std::unique_ptr thunks_; + // When true, allows command buffers to be used while profiling active. + // TODO(b/355487968): Remove this option when validation complete. + bool enable_command_buffers_during_profiling_; + // Command buffer thunk state allocated in heap to allow global (per-process) // management of instantiated command buffers. std::shared_ptr state_; diff --git a/xla/service/gpu/runtime/command_buffer_thunk_test.cc b/xla/service/gpu/runtime/command_buffer_thunk_test.cc index 60c7a3e3ddfd85..9146213d72fe89 100644 --- a/xla/service/gpu/runtime/command_buffer_thunk_test.cc +++ b/xla/service/gpu/runtime/command_buffer_thunk_test.cc @@ -30,6 +30,8 @@ limitations under the License. #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/runtime/command_buffer_cmd.h" +#include "xla/service/gpu/runtime/memset_thunk.h" +#include "xla/service/gpu/runtime/sequential_thunk.h" #include "xla/service/gpu/runtime/thunk.h" #include "xla/service/platform_util.h" #include "xla/service/service_executable_run_options.h" @@ -51,6 +53,7 @@ limitations under the License. #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" +#include "tsl/profiler/lib/profiler_lock.h" #ifdef GOOGLE_CUDA #include "third_party/gpus/cuda/include/cuda.h" @@ -250,6 +253,114 @@ TEST(CommandBufferThunkTest, Memset32Cmd) { ASSERT_EQ(dst, std::vector(4, 84)); } +TEST(CommandBufferThunkTest, Memset32CmdCommandBuffersDisabledDuringProfiling) { + se::StreamExecutor* executor = GpuExecutor(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: a=42 + se::DeviceMemory a = executor->AllocateArray(length, 0); + + TF_ASSERT_OK(stream->Memset32(&a, 42, byte_length)); + + // Prepare buffer allocations for recording command buffer. + BufferAllocation alloc_a(/*index=*/0, byte_length, /*color=*/0); + BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); + + auto memset_thunk = + std::make_unique(Thunk::ThunkInfo(), 84, slice_a); + std::vector> thunks; + thunks.push_back(std::move(memset_thunk)); + auto seq_thunks = + std::make_unique(Thunk::ThunkInfo(), std::move(thunks)); + // Prepare commands sequence for constructing command buffer that should not + // be used. + CommandBufferCmdSequence commands; + commands.Emplace(s0, slice_a, int32_t{12}); + + constexpr bool kProfileCommandBuffersEnabled = false; + // Construct a thunk with command sequence. + CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(), + std::move(seq_thunks), + kProfileCommandBuffersEnabled); + + ServiceExecutableRunOptions run_options; + se::StreamExecutorMemoryAllocator allocator(executor); + BufferAllocations allocations({a}, 0, &allocator); + + Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( + run_options, allocations, stream.get(), stream.get(), nullptr, nullptr); + + TF_ASSERT_OK_AND_ASSIGN(auto profiler_lock, + tsl::profiler::ProfilerLock::Acquire()); + // Execute command buffer thunk and verify that it set the memory. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copy `a` data back to host. + std::vector dst(4, 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), a, byte_length)); + + ASSERT_EQ(dst, std::vector(4, 84)); +} + +TEST(CommandBufferThunkTest, Memset32CmdCommandBuffersEnabledDuringProfiling) { + se::StreamExecutor* executor = GpuExecutor(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: a=42 + se::DeviceMemory a = executor->AllocateArray(length, 0); + + TF_ASSERT_OK(stream->Memset32(&a, 42, byte_length)); + + // Prepare buffer allocations for recording command buffer. + BufferAllocation alloc_a(/*index=*/0, byte_length, /*color=*/0); + BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); + + auto memset_thunk = + std::make_unique(Thunk::ThunkInfo(), 84, slice_a); + std::vector> thunks; + thunks.push_back(std::move(memset_thunk)); + auto seq_thunks = + std::make_unique(Thunk::ThunkInfo(), std::move(thunks)); + // Prepare commands sequence for constructing command buffer that should not + // be used. + CommandBufferCmdSequence commands; + commands.Emplace(s0, slice_a, int32_t{12}); + + constexpr bool kProfileCommandBuffersEnabled = true; + // Construct a thunk with command sequence. + CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(), + std::move(seq_thunks), + kProfileCommandBuffersEnabled); + + ServiceExecutableRunOptions run_options; + se::StreamExecutorMemoryAllocator allocator(executor); + BufferAllocations allocations({a}, 0, &allocator); + + Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( + run_options, allocations, stream.get(), stream.get(), nullptr, nullptr); + + TF_ASSERT_OK_AND_ASSIGN(auto profiler_lock, + tsl::profiler::ProfilerLock::Acquire()); + // Execute command buffer thunk and verify that it set the memory. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copy `a` data back to host. + std::vector dst(4, 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), a, byte_length)); + + ASSERT_EQ(dst, std::vector(4, 12)); +} + TEST(CommandBufferThunkTest, Memset32CmdOnDifferentStreams) { se::StreamExecutor* executor = GpuExecutor(); diff --git a/xla/xla.proto b/xla/xla.proto index ad53966df160da..35e795b1680df8 100644 --- a/xla/xla.proto +++ b/xla/xla.proto @@ -873,7 +873,12 @@ message DebugOptions { // Relative precision for comparing different GEMM solutions float xla_gpu_autotune_gemm_rtol = 316; - // Next id: 317 + // Allow launching command buffers while profiling active. + // When disabled, execute in op-by-op mode. + // TODO(b/355487968): Remove this option when validation complete. + bool xla_enable_command_buffers_during_profiling = 317; + + // Next id: 318 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. From cf139009c9c30a7ee9d4b5085aaf0b6938da0b2b Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Fri, 26 Jul 2024 14:11:16 -0700 Subject: [PATCH 209/376] Remove cuda_library rule that's missing srcs. PiperOrigin-RevId: 656524392 --- xla/service/gpu/kernels/BUILD | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/xla/service/gpu/kernels/BUILD b/xla/service/gpu/kernels/BUILD index 6e0a0d44d1523b..9d04094c5fd7cc 100644 --- a/xla/service/gpu/kernels/BUILD +++ b/xla/service/gpu/kernels/BUILD @@ -412,17 +412,6 @@ cuda_library( ]), ) -cuda_library( - name = "cutlass_gemm_kernel_bf16xbf16_to_f32", - srcs = if_cuda_is_configured(["cutlass_gemm_kernel_bf16xbf16_to_f32.cu.cc"]), - copts = ["-Wno-unknown-attributes"], - deps = if_cuda_is_configured([ - ":cutlass_gemm_adaptor", - "@local_config_cuda//cuda:cuda_headers", - "@cutlass_archive//:cutlass", - ]), -) - #===--------------------------------------------------------------------------------------------===# # CUTLASS Gemm kernel libraries #===--------------------------------------------------------------------------------------------===# From 044f1b4ea3ebc76ef7031bb3968003f69df49bed Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Fri, 26 Jul 2024 14:22:33 -0700 Subject: [PATCH 210/376] [xla:cpu] Add support for exporting executables without jit-compiled kernels PiperOrigin-RevId: 656527910 --- xla/service/cpu/BUILD | 1 + xla/service/cpu/cpu_compiler.cc | 32 +++++-- .../cpu/cpu_instruction_fusion_test.cc | 11 ++- xla/service/cpu/tests/cpu_aot_export_test.cc | 94 +++++++++++-------- 4 files changed, 86 insertions(+), 52 deletions(-) diff --git a/xla/service/cpu/BUILD b/xla/service/cpu/BUILD index cf93548c61563d..cb2bf71ad8b440 100644 --- a/xla/service/cpu/BUILD +++ b/xla/service/cpu/BUILD @@ -1386,6 +1386,7 @@ xla_cc_test( "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/service/cpu/cpu_compiler.cc b/xla/service/cpu/cpu_compiler.cc index 73201a657bc3f8..7254f2b1380f03 100644 --- a/xla/service/cpu/cpu_compiler.cc +++ b/xla/service/cpu/cpu_compiler.cc @@ -1710,6 +1710,8 @@ CpuExecutableAotCompilationResult::LoadExecutable( std::unique_ptr module, HloModule::CreateFromProtoWithConfig(proto_.hlo_module())); + VLOG(2) << "Load XLA:CPU executable for module: " << module->name(); + // Recreate BufferAssignment from proto. TF_ASSIGN_OR_RETURN( std::unique_ptr buffer_assignment, @@ -1732,10 +1734,17 @@ CpuExecutableAotCompilationResult::LoadExecutable( // Create a named buffer from compiled object file. llvm::StringRef data(proto_.obj_file().data(), proto_.obj_file().size()); - auto obj_file = - llvm::MemoryBuffer::getMemBuffer(data, proto_.entry_function_name()); - cantFail((*jit)->AddObjFile(std::move(obj_file))); + // We might have an XLA:CPU executable that has only runtime thunks and + // doesn't have any corresponding object files. + if (data.empty()) { + VLOG(2) << "Loaded XLA:CPU executable does not have an object file"; + } else { + VLOG(2) << "Load XLA:CPU executable object file with entry function: " + << proto_.entry_function_name(); + cantFail((*jit)->AddObjFile( + llvm::MemoryBuffer::getMemBuffer(data, proto_.entry_function_name()))); + } std::unique_ptr cpu_executable; @@ -1822,18 +1831,23 @@ absl::StatusOr> CpuCompiler::Export( if (!cpu_executable) return Internal("Could not downcast Executable to CpuExecutable"); - if (cpu_executable->obj_files().size() != 1) { - return absl::InternalError( - absl::StrCat("Can't export CPU execuable, expected exactly one object " - "file but got: ", - cpu_executable->obj_files().size())); + if (cpu_executable->obj_files().size() > 1) { + return Internal( + "Can't export CPU executable %s, expected at most one object file but " + "got: %d", + cpu_executable->module().name(), cpu_executable->obj_files().size()); } + std::string_view obj_file = cpu_executable->obj_files().empty() + ? std::string_view("") + : cpu_executable->obj_files()[0]; + auto kind = cpu_executable->has_thunks() ? CompilationResultProto::KERNELS : CompilationResultProto::CLASSIC; + return {std::make_unique( &cpu_executable->module(), &cpu_executable->buffer_assignment(), - cpu_executable->module_name(), cpu_executable->obj_files()[0], kind)}; + cpu_executable->module_name(), obj_file, kind)}; } absl::StatusOr> diff --git a/xla/service/cpu/cpu_instruction_fusion_test.cc b/xla/service/cpu/cpu_instruction_fusion_test.cc index 5db0bebaaa9a2e..54cfb41ba39487 100644 --- a/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -20,18 +20,20 @@ limitations under the License. #include #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/service/transpose_folding.h" #include "xla/shape.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_utils.h" +#include "tsl/platform/statusor.h" namespace op = xla::testing::opcode_matchers; -namespace xla { -namespace cpu { +namespace xla::cpu { namespace { using InstructionFusionTest = HloTestBase; @@ -453,7 +455,6 @@ TEST_F(OpcodeFusionTest, Slice_Negate) { HloInstruction::CreateSlice(slice_shape, param0, {0}, {4}, {2})); builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(F32, {2}), HloOpcode::kNegate, slice1)); - auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); @@ -928,6 +929,6 @@ ENTRY main { EXPECT_TRUE(fused_something); EXPECT_THAT(module->entry_computation()->root_instruction(), op::Fusion()); } + } // namespace -} // namespace cpu -} // namespace xla +} // namespace xla::cpu diff --git a/xla/service/cpu/tests/cpu_aot_export_test.cc b/xla/service/cpu/tests/cpu_aot_export_test.cc index 39528634f7cd82..5f023fe2f03b6f 100644 --- a/xla/service/cpu/tests/cpu_aot_export_test.cc +++ b/xla/service/cpu/tests/cpu_aot_export_test.cc @@ -32,10 +32,45 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" -namespace xla { -namespace cpu { - -using CpuAotCompilationTest = HloTestBase; +namespace xla::cpu { + +class CpuAotCompilationTest : public HloTestBase { + protected: + void ExportAndLoad(std::string_view hlo_string) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + auto compiler = backend().compiler(); + auto name = absl::AsciiStrToUpper( + PlatformUtil::CanonicalPlatformName("host").value()); + TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform, + se::PlatformManager::PlatformWithName(name)); + TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * stream_exec, + platform->ExecutorForDevice(0)); + + // JIT compile executable + auto module_group = std::make_unique(std::move(module)); + TF_ASSERT_OK_AND_ASSIGN( + std::vector> executables, + compiler->Compile(std::move(module_group), {{stream_exec}}, nullptr)); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr exported_aot_result, + compiler->Export(executables[0].get())); + + // Serialize-deserialize AOT compilation result. + TF_ASSERT_OK_AND_ASSIGN(std::string serialized_aot_result, + exported_aot_result->SerializeAsString()); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr loaded_aot_result, + compiler->LoadAotCompilationResult(serialized_aot_result)); + + // Load Executable from AOT compilation result. + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr executable, + loaded_aot_result->LoadExecutable(compiler, stream_exec)); + } +}; TEST_F(CpuAotCompilationTest, ExportAndLoadExecutable) { const absl::string_view hlo_string = R"( @@ -46,39 +81,22 @@ TEST_F(CpuAotCompilationTest, ExportAndLoadExecutable) { ROOT b = f32[2, 2]{1,0} add(a, a) })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_string)); - - auto compiler = backend().compiler(); - auto name = absl::AsciiStrToUpper( - PlatformUtil::CanonicalPlatformName("host").value()); - TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform, - se::PlatformManager::PlatformWithName(name)); - TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * stream_exec, - platform->ExecutorForDevice(0)); - - // JIT compile executable - auto module_group = std::make_unique(std::move(module)); - TF_ASSERT_OK_AND_ASSIGN( - std::vector> executables, - compiler->Compile(std::move(module_group), {{stream_exec}}, nullptr)); - - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr exported_aot_result, - compiler->Export(executables[0].get())); - - // Serialize-deserialize AOT compilation result. - TF_ASSERT_OK_AND_ASSIGN(std::string serialized_aot_result, - exported_aot_result->SerializeAsString()); - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr loaded_aot_result, - compiler->LoadAotCompilationResult(serialized_aot_result)); - - // Load Executable from AOT compilation result. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr executable, - loaded_aot_result->LoadExecutable(compiler, stream_exec)); + ExportAndLoad(hlo_string); +} + +TEST_F(CpuAotCompilationTest, ExportAndLoadExecutableNoKernels) { + // Copy operation implemented in the runtime and this module does not have + // any jit compiled kernels. We test that we still can export and load such + // executable. + const absl::string_view hlo_string = R"( + HloModule Test + + ENTRY main { + a = f32[2, 2]{1,0} parameter(0) + ROOT b = f32[2, 2]{1,0} copy(a) + })"; + + ExportAndLoad(hlo_string); } -} // namespace cpu -} // namespace xla +} // namespace xla::cpu From 1353fc9ea8557d0895afbeb796a424ad5e87f9d3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 26 Jul 2024 14:24:40 -0700 Subject: [PATCH 211/376] Seen performance regression. Reverts a8425caae53c138b9ce89528a8d186a51baa718a PiperOrigin-RevId: 656528511 --- xla/service/spmd/spmd_partitioner.cc | 32 ------------- xla/service/spmd/spmd_partitioner_test.cc | 56 ++++++++--------------- 2 files changed, 18 insertions(+), 70 deletions(-) diff --git a/xla/service/spmd/spmd_partitioner.cc b/xla/service/spmd/spmd_partitioner.cc index daaf297f6885e4..c3fc8b1ab31c0a 100644 --- a/xla/service/spmd/spmd_partitioner.cc +++ b/xla/service/spmd/spmd_partitioner.cc @@ -2543,38 +2543,6 @@ absl::Status SpmdPartitioningVisitor::Postprocess(HloInstruction* hlo) { } absl::Status SpmdPartitioningVisitor::HandleElementwise(HloInstruction* hlo) { - bool operands_same_sharding = true; - for (int64_t i = 1; i < hlo->operand_count(); ++i) { - if (hlo->operand(i)->sharding() != hlo->operand(0)->sharding()) { - operands_same_sharding = false; - break; - } - } - - if (hlo->operand_count() > 1 && operands_same_sharding) { - // Do the element-wise operation. Then reshard the result to the specified - // sharding. - std::vector original_operands; - for (HloInstruction* operand : hlo->operands()) { - original_operands.push_back(GetPartitionedHlo(operand).hlo()); - } - - HloInstruction* result_with_operand_sharding = - b_.AddInstruction(hlo->CloneWithNewOperands( - MakePartitionedShape(hlo->shape(), hlo->operand(0)->sharding()), - original_operands)); - result_with_operand_sharding->set_sharding(hlo->operand(0)->sharding()); - SetPartitionedHlo(hlo, [&] { - return PartitionedHlo(result_with_operand_sharding, hlo->shape(), - MakePartitioningState()) - .Reshard(hlo->sharding()) - .hlo(); - }); - return absl::OkStatus(); - } - - // Reshard the operands to the result's sharding. Then do the element-wise - // operation. std::vector new_operands; for (HloInstruction* operand : hlo->operands()) { new_operands.push_back( diff --git a/xla/service/spmd/spmd_partitioner_test.cc b/xla/service/spmd/spmd_partitioner_test.cc index 0b4fcd373e1ae4..99a1d1f92dc951 100644 --- a/xla/service/spmd/spmd_partitioner_test.cc +++ b/xla/service/spmd/spmd_partitioner_test.cc @@ -652,30 +652,6 @@ ENTRY entry { op::Reshape(), op::Constant())))); } -TEST_P(SpmdPartitioningTest, TiledElementwiseOperandsSameSharding) { - absl::string_view hlo_string = R"( -HloModule module - -ENTRY entry { - a = f32[32,32] parameter(0), sharding={devices=[2,2]<=[4]} - b = f32[32,32] parameter(1), sharding={devices=[2,2]<=[4]} - c = f32[32,32] multiply(a, b), sharding={replicated} - ROOT d = f32[32,32] add(c, c), sharding={devices=[4,1]<=[4]} -})"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - PartitionComputation(hlo_string, /*num_devices=*/4)); - VLOG(1) << module->ToString(); - - auto multiply = FindInstruction(module.get(), "c.1"); - EXPECT_NE(multiply, nullptr); - EXPECT_THAT(multiply, op::Shape("f32[16,16]")); - EXPECT_THAT(multiply, op::Multiply(op::Parameter(0), op::Parameter(1))); - - HloInstruction* root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Shape("f32[8,32]")); - EXPECT_THAT(root, op::DynamicSlice(op::Add(), _, _)); -} - TEST_P(SpmdPartitioningTest, TiledAllReduce) { absl::string_view hlo_string = R"( HloModule module @@ -9149,9 +9125,9 @@ ENTRY entry { constant.1 = f32[6,3]{1,0} constant({{2,7,2},{2,9,2},{2,6,2},{3,7,2},{2,9,3},{2,3,2}}), sharding={devices=[1,2,2]<=[4] last_tile_dims={manual}} - multiply = f32[6,3]{1,0} multiply(constant, constant.1), + multiply = f32[6,3]{1,0} multiply(constant, constant.1), sharding={devices=[1,2,2]<=[4] last_tile_dims={manual}} - ROOT add = f32[6,3]{1,0} add(multiply, constant.1), + ROOT add = f32[6,3]{1,0} add(multiply, constant.1), sharding={devices=[1,1,2,2]<=[4] last_tile_dims={replicated, manual}} } )"; @@ -9170,12 +9146,13 @@ ENTRY entry { op::Constant(), op::Reshape())); auto multiply = AllOf(op::Shape("f32[6,2]"), op::Multiply(multiply_lhs, multiply_rhs)); - auto add = AllOf(op::Shape("f32[6,2]"), op::Add(multiply, multiply_rhs)); + auto replicated_lhs = + AllOf(op::Shape("f32[6,3]"), + op::Slice(op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(), multiply, op::Constant(), op::Reshape())))); const auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT( - root, AllOf(op::Shape("f32[6,3]"), - op::Slice(op::AllReduce(op::DynamicUpdateSlice( - op::Broadcast(), add, op::Constant(), op::Reshape()))))); + EXPECT_THAT(root, AllOf(op::Shape("f32[6,3]"), + op::Add(replicated_lhs, op::Constant()))); } TEST_P(SpmdPartitioningTest, ElementwiseTest_SubgroupSharding_ReplicateToTile) { @@ -9189,9 +9166,9 @@ ENTRY entry { constant.1 = f32[6,3]{1,0} constant({{2,7,2},{2,9,2},{2,6,2},{3,7,2},{2,9,3},{2,3,2}}), sharding={devices=[1,1,2,2]<=[4] last_tile_dims={replicated,manual}} - multiply = f32[6,3]{1,0} multiply(constant, constant.1), + multiply = f32[6,3]{1,0} multiply(constant, constant.1), sharding={devices=[1,1,2,2]<=[4] last_tile_dims={replicated,manual}} - ROOT add = f32[6,3]{1,0} add(multiply, constant.1), + ROOT add = f32[6,3]{1,0} add(multiply, constant.1), sharding={devices=[1,2,2]<=[4] last_tile_dims={manual}} } )"; @@ -9202,11 +9179,14 @@ ENTRY entry { auto multiply = AllOf(op::Shape("f32[6,3]"), op::Multiply(op::Constant(), op::Constant())); - auto add = AllOf(op::Shape("f32[6,3]"), op::Add(multiply, op::Constant())); - const auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, AllOf(op::Shape("f32[6,2]"), - op::DynamicSlice(op::Pad(add, op::Constant()), - op::Constant(), op::Reshape()))); + auto add_lhs = AllOf(op::Shape("f32[6,2]"), + op::DynamicSlice(op::Pad(multiply, op::Constant()), + op::Constant(), op::Reshape())); + auto add_rhs = AllOf(op::Shape("f32[6,2]"), + op::DynamicSlice(op::Pad(op::Constant(), op::Constant()), + op::Constant(), op::Reshape())); + const auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("f32[6,2]"), op::Add(add_lhs, add_rhs))); } TEST_P(SpmdPartitioningTest, From ba10aacbf9da62d1315d0578e9c9affabdf32aa7 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Fri, 26 Jul 2024 14:36:06 -0700 Subject: [PATCH 212/376] [xla:cpu] Add support for sorting inputs with non-descending layouts PiperOrigin-RevId: 656531960 --- xla/service/cpu/runtime/sort_thunk.cc | 34 +++++----- xla/service/cpu/runtime/sort_thunk_test.cc | 72 ++++++++++++++++++++++ 2 files changed, 92 insertions(+), 14 deletions(-) diff --git a/xla/service/cpu/runtime/sort_thunk.cc b/xla/service/cpu/runtime/sort_thunk.cc index 51d7c3ede0f624..a24a2272587b25 100644 --- a/xla/service/cpu/runtime/sort_thunk.cc +++ b/xla/service/cpu/runtime/sort_thunk.cc @@ -44,6 +44,7 @@ limitations under the License. #include "xla/service/buffer_assignment.h" #include "xla/service/cpu/runtime/thunk.h" #include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/util.h" @@ -61,7 +62,7 @@ static absl::Status VerifySortInputs(absl::Span inputs, return Internal("Inputs must not be empty"); } - // All inputs must have the same shape (ignoring element type) and layout. + // All inputs must have the same shape and layout (ignoring element type). auto equal = Shape::Equal().IgnoreElementType(); const Shape& shape = inputs[0].shape; @@ -80,12 +81,6 @@ static absl::Status VerifySortInputs(absl::Span inputs, absl::StrJoin(shape.dimensions(), ","), dimension); } - // We support only monotonic layouts with dim0 major. - if (!LayoutUtil::IsMonotonicWithDim0Major(shape.layout())) { - return Internal("Unsupported sort input layout %s", - shape.ToString(/*print_layout=*/true)); - } - return absl::OkStatus(); } @@ -340,18 +335,29 @@ struct SortDims { // We sort `outer_dim_size * inner_dim_size` vectors of length // `sort_dim_size`, by iterating over `data` memory and calling `std::sort` // (or `std::stable_sort`) on each (strided) slice of the buffer. -static SortDims GetSortDims(absl::Span dimensions, - int64_t dimension) { +static SortDims GetSortDims(const Shape& shape, int64_t dimension) { int64_t sort_dimension = - dimension >= 0 ? dimension : dimensions.size() + dimension; + dimension >= 0 ? dimension : shape.rank() + dimension; + + // We need to normalize shape + layout into a descending layout, so that we + // can compute access strides according to the physical layout. + Shape physical_shape = + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(shape); + + // Map `sort_dimension` from logical to physical. + auto logical_to_physical = LayoutUtil::MakeLogicalToPhysical(shape.layout()); + sort_dimension = logical_to_physical[sort_dimension]; auto product = [](absl::Span dims) { return absl::c_accumulate(dims, int64_t{1}, std::multiplies<>()); }; - int64_t outer_dim_size = product(dimensions.subspan(0, dimension)); + // Use physical dimensions to compute access strides. + absl::Span dimensions = physical_shape.dimensions(); + + int64_t outer_dim_size = product(dimensions.subspan(0, sort_dimension)); int64_t sort_dim_size = dimensions[sort_dimension]; - int64_t inner_dim_size = product(dimensions.subspan(dimension + 1)); + int64_t inner_dim_size = product(dimensions.subspan(sort_dimension + 1)); int64_t num_iterations = outer_dim_size * inner_dim_size; return SortDims{outer_dim_size, sort_dim_size, inner_dim_size, @@ -398,7 +404,7 @@ static absl::Status SortInplace(absl::Span data, SortThunk::LessThan* less_than) { // All inputs have the same dimensions and layout, so we can use the first // shape to get the sort dimensions. - SortDims sort_dims = GetSortDims(shapes[0].dimensions(), dimension); + SortDims sort_dims = GetSortDims(shapes[0], dimension); // Iterate over all the 1-dimensional slices of the buffers and sort them. for (int64_t i = 0; i < sort_dims.num_iterations; ++i) { @@ -499,7 +505,7 @@ tsl::AsyncValueRef SortThunk::Execute( data.back().size()); VLOG(3) << absl::StreamFormat(" sort input #%d: %s in slice %s (%p)", idx, - input.shape.ToString(), + input.shape.ToString(/*print_layout=*/true), input.slice.ToString(), data.back().opaque()); } diff --git a/xla/service/cpu/runtime/sort_thunk_test.cc b/xla/service/cpu/runtime/sort_thunk_test.cc index 81da44f07f7c01..4c7b2514a1c709 100644 --- a/xla/service/cpu/runtime/sort_thunk_test.cc +++ b/xla/service/cpu/runtime/sort_thunk_test.cc @@ -21,6 +21,8 @@ limitations under the License. #include #include "absl/status/statusor.h" +#include "xla/layout.h" +#include "xla/layout_util.h" #include "xla/service/buffer_assignment.h" #include "xla/service/cpu/runtime/buffer_allocations.h" #include "xla/service/cpu/runtime/thunk.h" @@ -165,6 +167,76 @@ TEST_P(SortThunkTest, Sort2D) { EXPECT_EQ(indices, expected_indices); } +TEST_P(SortThunkTest, Sort2DWithLayout) { + bool is_stable = GetParam(); + + std::vector buffers; + std::vector data = {4.0, 3.0, 2.0, 1.0}; + std::vector indices = {0, 1, 2, 3}; + + size_t size_in_bytes = data.size() * sizeof(float); + buffers.emplace_back(se::DeviceMemoryBase(data.data(), size_in_bytes)); + buffers.emplace_back(se::DeviceMemoryBase(indices.data(), size_in_bytes)); + + BufferAllocations allocations(buffers); + + BufferAllocation alloc0(0, size_in_bytes, 0); + BufferAllocation alloc1(1, size_in_bytes, 0); + + BufferAllocation::Slice slice0(&alloc0, 0, size_in_bytes); + BufferAllocation::Slice slice1(&alloc1, 0, size_in_bytes); + + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + *data_shape.mutable_layout() = LayoutUtil::MakeLayout({0, 1}); + + Shape indices_shape = ShapeUtil::MakeShape(S32, {2, 2}); + *indices_shape.mutable_layout() = LayoutUtil::MakeLayout({0, 1}); + + // Sort along the dimension `0`. + TF_ASSERT_OK_AND_ASSIGN( + auto sort_dim0, + SortThunk::Create({"sort"}, + {{slice0, data_shape}, {slice1, indices_shape}}, + /*dimension=*/0, is_stable, "less_than")); + + Thunk::ExecuteParams params; + params.buffer_allocations = &allocations; + + LessThanComparator less_than_comparator; + params.function_registry = &less_than_comparator; + + auto execute_event0 = sort_dim0->Execute(params); + tsl::BlockUntilReady(execute_event0); + ASSERT_FALSE(execute_event0.IsError()); + + std::vector expected_data = {3.0, 4.0, 1.0, 2.0}; + std::vector expected_indices = {1, 0, 3, 2}; + + EXPECT_EQ(data, expected_data); + EXPECT_EQ(indices, expected_indices); + + // Reset data and indices to make it unsorted along the dimension `1`. + data = {2.0, 4.0, 1.0, 3.0}; + indices = {0, 1, 2, 3}; + + TF_ASSERT_OK_AND_ASSIGN( + auto sort_dim1, + SortThunk::Create({"sort"}, + {{slice0, data_shape}, {slice1, indices_shape}}, + /*dimension=*/1, + /*is_stable=*/false, "less_than")); + + auto execute_event1 = sort_dim1->Execute(params); + tsl::BlockUntilReady(execute_event1); + ASSERT_FALSE(execute_event1.IsError()); + + expected_data = {1.0, 3.0, 2.0, 4.0}; + expected_indices = {2, 3, 0, 1}; + + EXPECT_EQ(data, expected_data); + EXPECT_EQ(indices, expected_indices); +} + INSTANTIATE_TEST_SUITE_P(SortThunk, SortThunkTest, testing::Bool(), testing::PrintToStringParamName()); From 475924e49f3440ec411e2e1afd90944bcf521614 Mon Sep 17 00:00:00 2001 From: Victor Stone Date: Fri, 26 Jul 2024 15:27:34 -0700 Subject: [PATCH 213/376] Improve how HostOffloadLegalize moves copies out of host-memory-only offloading. When a copy is moved over a shape changing bitcast or a dynamic-slice/slice, the shape of the copy also needs to change. PiperOrigin-RevId: 656546789 --- xla/service/host_offload_legalize.cc | 103 +++++++++++++++++++--- xla/service/host_offload_legalize_test.cc | 40 +++++++++ 2 files changed, 133 insertions(+), 10 deletions(-) diff --git a/xla/service/host_offload_legalize.cc b/xla/service/host_offload_legalize.cc index 1199112adeb9f7..bc850dd18d91d2 100644 --- a/xla/service/host_offload_legalize.cc +++ b/xla/service/host_offload_legalize.cc @@ -385,6 +385,13 @@ void UpdateInstructionLayout(const InstructionAndIndex& instruction_and_index, } } +Shape RemoveMajormostDimension(const Shape& shape) { + CHECK(shape.has_layout()) << "Shape must have layout."; + const int size = shape.layout().minor_to_major_size(); + const int64_t majormost_dim = shape.layout().minor_to_major(size - 1); + return ShapeUtil::DeleteDimension(majormost_dim, shape); +} + absl::Status MoveCopy( const InstructionAndIndex& copy_to_move_instruction_and_index, const CallGraph* call_graph, @@ -392,13 +399,27 @@ absl::Status MoveCopy( absl::flat_hash_set& to_remove) { HloInstruction* copy_to_move = copy_to_move_instruction_and_index.instruction; VLOG(5) << "Moving copy: " << copy_to_move->ToString(); - std::vector stack = {copy_to_move_instruction_and_index}; + struct InstructionAndShapes { + InstructionAndShapes(InstructionAndIndex idx, Shape s_before, Shape s_after) + : instruction_and_index(idx), + shape_before_copy(s_before), + shape_after_copy(s_after) {} + InstructionAndIndex instruction_and_index; + Shape shape_before_copy; + Shape shape_after_copy; + }; + std::vector stack = {InstructionAndShapes( + copy_to_move_instruction_and_index, copy_to_move->operand(0)->shape(), + copy_to_move->shape())}; while (!stack.empty()) { - InstructionAndIndex current_instruction_and_index = stack.back(); + InstructionAndShapes current_instruction_and_shapes = stack.back(); + InstructionAndIndex current_instruction_and_index = + current_instruction_and_shapes.instruction_and_index; stack.pop_back(); VLOG(5) << "Current top of stack: " << current_instruction_and_index.instruction->ToString() << " " << current_instruction_and_index.index; + // Get the users of the current instruction. absl::StatusOr> current_value_down = WalkDownMemoryOffload(current_instruction_and_index, *call_graph); if (!current_value_down.ok()) { @@ -406,14 +427,74 @@ absl::Status MoveCopy( << current_value_down.status(); break; } + for (InstructionAndIndex& instruction_and_index : current_value_down.value()) { HloInstruction* instruction = instruction_and_index.instruction; + Shape shape_before_copy = + current_instruction_and_shapes.shape_before_copy; + Shape shape_after_copy = current_instruction_and_shapes.shape_after_copy; VLOG(5) << "Evaluating successor: " << instruction->ToString(); const int index = instruction_and_index.index; + if (instruction->opcode() == HloOpcode::kBitcast) { + // For now, we only know how to move a copy over a bitcast which + // "reshapes" away the majormost dimension (which must be a degenerate + // dimension). + const Shape& before_bitcast_shape = instruction->operand(0)->shape(); + const Shape& after_bitcast_shape = instruction->shape(); + if (!Shape::Equal().IgnoreLayout()(copy_to_move->operand(0)->shape(), + copy_to_move->shape())) { + return absl::InternalError(absl::StrFormat( + "Expecting copy to only change instructions layout. Copy: %s", + copy_to_move->ToString())); + } + if (after_bitcast_shape.rank() != before_bitcast_shape.rank() - 1) { + return absl::InternalError( + absl::StrFormat("Only handling bitcasts which remove 0'th " + "dimension. This bitcast is \"%s\"", + instruction->ToString())); + } + if (!(ShapeUtil::IsEffectivelyMostMajorDimension(before_bitcast_shape, + 0) && + before_bitcast_shape.dimensions(0) == 1)) { + return absl::InternalError( + absl::StrFormat("Only handling bitcasts with majormost dimension " + "of size 1. This bitcast is \"%s\"", + instruction->ToString())); + } + const Shape new_bitcast_shape = + RemoveMajormostDimension(shape_before_copy); + VLOG(2) << absl::StreamFormat( + " Encountered bitcast \"%s\", updating current shape from %s to %s", + instruction->name(), shape_before_copy.ToString(true), + new_bitcast_shape.ToString(true)); + shape_before_copy = new_bitcast_shape; + const Shape new_copy_shape = RemoveMajormostDimension(shape_after_copy); + VLOG(2) << absl::StreamFormat( + " Also updating shape after copy from %s to %s", + shape_after_copy.ToString(true), new_copy_shape.ToString(true)); + shape_after_copy = new_copy_shape; + } else if (instruction->opcode() == HloOpcode::kSlice || + instruction->opcode() == HloOpcode::kDynamicSlice) { + // Since we're moving the copy over a Slice/DynamicSlice, we need to + // change the shape of the copy to match the shape of the result of the + // Slice/DynamicSlice. We want to maintain the layout of + // shape_after_copy though. + Shape new_copy_shape = instruction->shape(); + *new_copy_shape.mutable_layout() = shape_after_copy.layout(); + VLOG(2) << absl::StreamFormat( + " Encountered %s \"%s\", updating shape after copy from " + "%s to %s", + HloOpcodeString(instruction->opcode()), instruction->name(), + shape_after_copy.ToString(true), new_copy_shape.ToString(true)); + shape_after_copy = new_copy_shape; + } + + // Update the shape of this instruction as if the copy never happened. UpdateInstructionLayout(instruction_and_index, - copy_to_move->operand(0)->shape().layout()); + shape_before_copy.layout()); if (instruction->opcode() == HloOpcode::kParameter) { + // Also update the layout of the call site. std::vector callers = call_graph->GetComputationCallers(instruction->parent()); if (callers.size() != 1) { @@ -422,11 +503,12 @@ absl::Status MoveCopy( } HloInstruction* caller = callers[0]; UpdateInstructionLayout(InstructionAndIndex(caller, index), - copy_to_move->operand(0)->shape().layout()); + shape_before_copy.layout()); } CHECK_NE(instruction->opcode(), HloOpcode::kCopy) - << "Copies should be processed in order"; + << "Copies should be processed in reverse order so this never " + "happens"; if (absl::c_linear_search(kUsersOpcodes, instruction->opcode()) || instruction->IsCustomCall( host_memory_offload_annotations::kMoveToDeviceCustomCallTarget)) { @@ -444,12 +526,12 @@ absl::Status MoveCopy( instruction->shape(), {instruction})); } UpdateInstructionLayout(InstructionAndIndex(new_annotation, -1), - copy_to_move->operand(0)->shape().layout()); - Shape new_copy_shape = new_annotation->shape(); - *new_copy_shape.mutable_layout() = copy_to_move->shape().layout(); + shape_before_copy.layout()); + VLOG(3) << absl::StreamFormat("Creating copy with shape %s", + shape_after_copy.ToString(true)); HloInstruction* new_copy = instruction->AddInstruction(copy_to_move->CloneWithNewOperands( - new_copy_shape, {new_annotation})); + shape_after_copy, {new_annotation})); VLOG(2) << absl::StreamFormat("Inserting copy \"%s\" after \"%s\"", new_copy->name(), instruction->name()); std::vector users = instruction->users(); @@ -504,7 +586,8 @@ absl::Status MoveCopy( TF_RETURN_IF_ERROR(update_slice->ReplaceOperandWith(0, new_copy)); } } - stack.push_back(instruction_and_index); + stack.emplace_back(instruction_and_index, shape_before_copy, + shape_after_copy); } } VLOG(2) << absl::StreamFormat("Removing copy \"%s\"", diff --git a/xla/service/host_offload_legalize_test.cc b/xla/service/host_offload_legalize_test.cc index df1f06ddba74ec..096f9a10560b44 100644 --- a/xla/service/host_offload_legalize_test.cc +++ b/xla/service/host_offload_legalize_test.cc @@ -117,6 +117,7 @@ ENTRY main.24 { XLA_VLOG_LINES(1, module->ToString()); HloInstruction* custom_call = FindInstruction(module.get(), "custom-call.18"); + ASSERT_NE(custom_call, nullptr); EXPECT_EQ(custom_call->users()[0]->opcode(), HloOpcode::kCopy); EXPECT_EQ(custom_call->shape().layout(), LayoutUtil::MakeLayout({0, 1})); EXPECT_EQ(custom_call->users()[0]->shape().layout(), @@ -163,12 +164,14 @@ ENTRY main.24 { XLA_VLOG_LINES(1, module->ToString()); HloInstruction* custom_call = FindInstruction(module.get(), "custom-call.18"); + ASSERT_NE(custom_call, nullptr); EXPECT_EQ(custom_call->users()[0]->opcode(), HloOpcode::kCopy); EXPECT_EQ(custom_call->shape().layout(), LayoutUtil::MakeLayout({0, 1})); EXPECT_EQ(custom_call->users()[0]->shape().layout(), LayoutUtil::MakeLayout({1, 0})); custom_call = FindInstruction(module.get(), "custom-call.19"); + ASSERT_NE(custom_call, nullptr); EXPECT_EQ(custom_call->users()[0]->opcode(), HloOpcode::kCopy); EXPECT_EQ(custom_call->shape().layout(), LayoutUtil::MakeLayout({0, 1}, {}, {}, {}, {Tile{{8, 128}}})); @@ -206,6 +209,7 @@ ENTRY main.24 { HloInstruction* dus = FindInstruction(module.get(), "dynamic-update-slice.6830"); + ASSERT_NE(dus, nullptr); EXPECT_EQ(dus->operand(0)->shape().layout(), dus->operand(1)->shape().layout()); EXPECT_EQ(dus->shape().layout(), dus->operand(1)->shape().layout()); @@ -251,6 +255,7 @@ ENTRY main.24 { HloInstruction* dus = FindInstruction(module.get(), "dynamic-update-slice.6830"); + ASSERT_NE(dus, nullptr); EXPECT_EQ(dus->operand(0)->shape().layout(), dus->operand(1)->shape().layout()); EXPECT_EQ(dus->shape().layout(), dus->operand(1)->shape().layout()); @@ -369,6 +374,8 @@ ENTRY main { HloInstruction* copy = FindInstruction(module.get(), HloOpcode::kCopy); HloInstruction* consuming_while = FindInstruction(module.get(), "consuming_while"); + ASSERT_NE(copy, nullptr); + ASSERT_NE(consuming_while, nullptr); EXPECT_NE(copy, nullptr); EXPECT_NE(consuming_while, nullptr); EXPECT_EQ(copy->parent(), consuming_while->while_body()); @@ -480,6 +487,9 @@ ENTRY main { HloInstruction* copy_1 = FindInstruction(module.get(), "cp1.2"); HloInstruction* consuming_while = FindInstruction(module.get(), "consuming_while"); + ASSERT_NE(copy_0, nullptr); + ASSERT_NE(copy_1, nullptr); + ASSERT_NE(consuming_while, nullptr); EXPECT_NE(copy_0, nullptr); EXPECT_NE(copy_1, nullptr); EXPECT_NE(consuming_while, nullptr); @@ -488,6 +498,36 @@ ENTRY main { XLA_VLOG_LINES(1, module->ToString()); } +TEST_F(HostOffloadLegalizeTest, MoveCopyOverBitcast) { + const std::string& hlo_string = R"( +HloModule jit_f, entry_computation_layout={(bf16[1,1,16384,4,256]{4,3,2,1,0:T(4,128)(2,1)S(5)})->bf16[1,16384,4,256]{3,1,2,0:T(8,128)(2,1)}} + +ENTRY main { + param = bf16[1,1,16384,4,256]{4,3,2,1,0:T(4,128)(2,1)} parameter(0) + copy = bf16[1,1,16384,4,256]{4,2,3,1,0:T(8,128)(2,1)} copy(param) + bitcast = bf16[1,16384,4,256]{3,1,2,0:T(8,128)(2,1)} bitcast(copy) + custom-call = bf16[1,16384,4,256]{3,1,2,0:T(8,128)(2,1)} custom-call(bitcast), custom_call_target="MoveToDevice" + ROOT add = bf16[1,16384,4,256]{3,1,2,0:T(8,128)(2,1)} add(custom-call, custom-call) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloadLegalize(module.get())); + + EXPECT_TRUE(changed); + XLA_VLOG_LINES(1, module->ToString()); + HloInstruction* custom_call = FindInstruction(module.get(), "custom-call"); + EXPECT_EQ(custom_call->shape().layout(), + LayoutUtil::MakeLayout({3, 2, 1, 0}, {}, {}, {}, + {Tile{{4, 128}}, Tile{{2, 1}}})); + EXPECT_EQ(custom_call->users()[0]->opcode(), HloOpcode::kCopy); + EXPECT_EQ(custom_call->users()[0]->shape().layout(), + LayoutUtil::MakeLayout({3, 1, 2, 0}, {}, {}, {}, + {Tile{{8, 128}}, Tile{{2, 1}}})); +} + } // namespace } // namespace xla From 2149d629364bd9943c6df91e7250f92f2b267736 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 26 Jul 2024 15:57:27 -0700 Subject: [PATCH 214/376] Simplify and generalize the strategy generation code for convolution ops. Rather than explicitly generating strategies corresponding to different sets of dimensions being sharded, we now generate strategies in a more principled and general manner. PiperOrigin-RevId: 656554274 --- .../auto_sharding_dot_handler.cc | 225 +++++++++--------- 1 file changed, 110 insertions(+), 115 deletions(-) diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc index dbefe93f8fd04f..3c62712ab41e27 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc @@ -302,6 +302,9 @@ class ConvHandler : public HandlerBase { void SplitDepthwise(bool forward); + void GenerateConvolutionShardingStrategiesFromOutputSharding( + const DimMap& output_dim_map); + absl::Status RegisterStrategies(); // Dimension information @@ -875,6 +878,73 @@ ConvHandler::ConvHandler(std::unique_ptr& strategy_group, out_out_channel_dim_ = conv_dnums_.output_feature_dimension(); } +void ConvHandler::GenerateConvolutionShardingStrategiesFromOutputSharding( + const DimMap& output_dim_map) { + DimMap lhs_dim_map; + DimMap rhs_dim_map; + absl::flat_hash_set used_mesh_dims; + std::string name; + + // Propagate batch dim sharding + auto it = output_dim_map.find(out_batch_dim_); + if (it != output_dim_map.end() && device_mesh_.dim(it->second) > 1) { + int mesh_dim = it->second; + lhs_dim_map[lhs_batch_dim_] = mesh_dim; + used_mesh_dims.insert(mesh_dim); + absl::StrAppend(&name, "b", mesh_dim); + } else { + absl::StrAppend(&name, "b-1"); + } + + // Propagate out channel dim sharding + it = output_dim_map.find(out_out_channel_dim_); + if (it != output_dim_map.end() && device_mesh_.dim(it->second) > 1) { + int mesh_dim = it->second; + lhs_dim_map[rhs_out_channel_dim_] = mesh_dim; + used_mesh_dims.insert(mesh_dim); + absl::StrAppend(&name, "oc", mesh_dim); + } else { + absl::StrAppend(&name, "oc-1"); + } + + MaybeAppend(name, lhs_dim_map, rhs_dim_map, output_dim_map, device_mesh_); + + // Generate shardings for contraction dimensions + if (used_mesh_dims.size() == device_mesh_.num_dimensions()) { + return; + } + + absl::flat_hash_set unused_mesh_dims; + for (size_t i = 0; i < device_mesh_.num_dimensions(); ++i) { + if (!used_mesh_dims.contains(i) && device_mesh_.dim(i) > 1) { + unused_mesh_dims.insert(i); + } + } + + if (unused_mesh_dims.empty()) { + return; + } + + for (int64_t mesh_dim : unused_mesh_dims) { + DimMap lhs_dim_map_with_contractions = lhs_dim_map; + DimMap rhs_dim_map_with_contractions = rhs_dim_map; + + lhs_dim_map_with_contractions[lhs_in_channel_dim_] = mesh_dim; + rhs_dim_map_with_contractions[rhs_in_channel_dim_] = mesh_dim; + absl::StrAppend(&name, "ic", mesh_dim, "@allreduce"); + + auto communication_cost_fn = [&](const HloSharding& output_sharding) { + return cluster_env_.AllReduceCost( + ByteSizeOfShapeWithSharding(ins_->shape(), output_sharding), + mesh_dim); + }; + + MaybeAppend(name, lhs_dim_map_with_contractions, + rhs_dim_map_with_contractions, output_dim_map, device_mesh_, + /*compute_cost=*/0, communication_cost_fn); + } +} + absl::Status ConvHandler::RegisterStrategies() { // For 1D sharding if ((ins_->feature_group_count() == @@ -895,22 +965,16 @@ absl::Status ConvHandler::RegisterStrategies() { SplitDepthwise(false); } - // SS = SR x RS - // Split lhs batch dim and rhs out_channel dim. - SplitLhsBatchRhsOutchannel(); - - // SR = SS x SR - // Split lhs batch dim and both in_channel dims. - SplitLhsBatchBothInchannel(); - - // RS = RS x SS - // Split rhs out_channel dim and both in_channel dims. - SplitRhsOutchannelBothInchannel(); - - // Add 1d data parallel in multi-dimensional mesh - if (option_.allow_mixed_mesh_shape) { - Add1DDataParallel(); + absl::flat_hash_set all_mesh_dims; + for (int i = 0; i < device_mesh_.num_dimensions(); ++i) { + all_mesh_dims.insert(i); } + EnumerateGeneral( + [&](const DimMap& output_dim_map) { + GenerateConvolutionShardingStrategiesFromOutputSharding(output_dim_map); + }, + 2, /*current_tensor_dim=*/0, all_mesh_dims, + /*current_dim_map=*/{}); // If force_batch_dim_to_mesh_dim is set, filter out invalid strategies // and only keep the data parallel strategies. @@ -920,114 +984,45 @@ absl::Status ConvHandler::RegisterStrategies() { cluster_env_, batch_map_, option_)); } + SortStrategies(); return absl::OkStatus(); } -void ConvHandler::SplitLhsBatchRhsOutchannel() { - auto func = [this](const Enumeration& e) { - const DimMap lhs_dim_map = {{lhs_batch_dim_, e.mesh_dims[0]}}; - const DimMap rhs_dim_map = {{rhs_out_channel_dim_, e.mesh_dims[1]}}; - std::string name = - absl::StrFormat("SS = SR x RS @ {%s}", absl::StrJoin(e.mesh_dims, ",")); - const DimMap out_dim_map = {{out_batch_dim_, e.mesh_dims[0]}, - {out_out_channel_dim_, e.mesh_dims[1]}}; - MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_); - }; - EnumerateHalf(func); -} - -void ConvHandler::SplitLhsBatchBothInchannel() { - auto func = [this](const Enumeration& e) { - if (device_mesh_.dim(e.mesh_dims[0]) <= 1 || - device_mesh_.dim(e.mesh_dims[1]) <= 1) - return; - const DimMap lhs_dim_map = {{lhs_batch_dim_, e.mesh_dims[0]}, - {lhs_in_channel_dim_, e.mesh_dims[1]}}; - const DimMap rhs_dim_map = {{rhs_in_channel_dim_, e.mesh_dims[1]}}; - std::string name = - absl::StrFormat("SR = SS x SR @ {%s} (allreduce @ %d)", - absl::StrJoin(e.mesh_dims, ","), e.mesh_dims[1]); - const DimMap out_dim_map = {{out_batch_dim_, e.mesh_dims[0]}}; - auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { - double memory_cost = - ByteSizeOfShapeWithSharding(ins_->shape(), output_spec); - return cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[1]); - }; - MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_, 0, - communication_cost_fn); - }; - EnumerateHalf(func); -} +void ConvHandler::SplitDepthwise(bool forward) { + std::function split_func = + [&](const DimMap& output_dim_map) { + int out_batch_mesh_dim = -1; + int out_out_channel_mesh_dim = -1; + if (auto it = output_dim_map.find(out_batch_dim_); + it != output_dim_map.end()) { + out_batch_mesh_dim = it->second; + } + if (auto it = output_dim_map.find(out_out_channel_dim_); + it != output_dim_map.end()) { + out_out_channel_mesh_dim = it->second; + } + if (out_batch_mesh_dim == -1 || out_out_channel_mesh_dim == -1) { + return; + } -void ConvHandler::SplitRhsOutchannelBothInchannel() { - auto func = [this](const Enumeration& e) { - if (device_mesh_.dim(e.mesh_dims[0]) <= 1) return; - const DimMap lhs_dim_map = {{lhs_in_channel_dim_, e.mesh_dims[0]}}; - const DimMap rhs_dim_map = {{rhs_in_channel_dim_, e.mesh_dims[0]}, - {rhs_out_channel_dim_, e.mesh_dims[1]}}; - std::string name = - absl::StrFormat("RS = RS x SS @ {%s} (allreduce @ %d)", - absl::StrJoin(e.mesh_dims, ","), e.mesh_dims[0]); - const DimMap out_dim_map = {{out_out_channel_dim_, e.mesh_dims[1]}}; - auto communication_cost_fn = [this, &e](const HloSharding& output_spec) { - double memory_cost = - ByteSizeOfShapeWithSharding(ins_->shape(), output_spec); - return cluster_env_.AllReduceCost(memory_cost, e.mesh_dims[0]); - }; - MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_, 0, - communication_cost_fn); - }; - EnumerateHalf(func); -} + DimMap lhs_dim_map, rhs_dim_map; + lhs_dim_map[lhs_batch_dim_] = + forward ? out_batch_mesh_dim : out_out_channel_mesh_dim; + lhs_dim_map[lhs_in_channel_dim_] = + forward ? out_out_channel_mesh_dim : out_batch_mesh_dim; -void ConvHandler::Add1DDataParallel() { - if (device_mesh_.dim(0) > 1 && - absl::c_count_if(device_mesh_.dimensions(), - [](int64_t size) { return size > 1; }) > 1) { - int mesh_dim = 0; - int64_t num_devices = device_mesh_1d_.dim(mesh_dim); - - // Si = Si x R @ 0 - if (lhs_->shape().dimensions(lhs_batch_dim_) % num_devices == 0) { - const DimMap lhs_dim_map = {{lhs_batch_dim_, mesh_dim}}; - std::string name = absl::StrFormat("Si = Si x R @ 0"); - const DimMap out_dim_map = {{out_batch_dim_, mesh_dim}}; - MaybeAppend(name, lhs_dim_map, {}, out_dim_map, device_mesh_1d_); - } + rhs_dim_map[rhs_out_channel_dim_] = out_out_channel_mesh_dim; - // R = Sk x Sk @ (allreduce @ 0) - if (lhs_->shape().dimensions(lhs_in_channel_dim_) % num_devices == 0 && - rhs_->shape().dimensions(rhs_in_channel_dim_) % num_devices == 0) { - const DimMap lhs_dim_map = {{lhs_in_channel_dim_, mesh_dim}}; - const DimMap rhs_dim_map = {{rhs_in_channel_dim_, mesh_dim}}; - std::string name = absl::StrFormat("R = Sk x Sk @ %d (allreduce @ %d)", - mesh_dim, mesh_dim); - const DimMap out_dim_map = {}; - auto communication_cost_fn = [this](const HloSharding& output_spec) { - double memory_cost = - ByteSizeOfShapeWithSharding(ins_->shape(), output_spec); - return cluster_env_.AllReduceCost(memory_cost, 0) + - cluster_env_.AllReduceCost(memory_cost, 1); + MaybeAppend(absl::StrCat("b", out_batch_mesh_dim, "oc", + out_out_channel_mesh_dim, "@depthwise"), + lhs_dim_map, rhs_dim_map, output_dim_map, device_mesh_); }; - MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_1d_, - 0, communication_cost_fn); - } + absl::flat_hash_set all_mesh_dims; + for (int i = 0; i < device_mesh_.num_dimensions(); ++i) { + all_mesh_dims.insert(i); } -} - -void ConvHandler::SplitDepthwise(bool forward) { - auto func = [this, forward](const Enumeration& e) { - const DimMap lhs_dim_map = { - {lhs_batch_dim_, e.mesh_dims[forward ? 0 : 1]}, - {lhs_in_channel_dim_, e.mesh_dims[forward ? 1 : 0]}}; - const DimMap rhs_dim_map = {{rhs_out_channel_dim_, e.mesh_dims[1]}}; - std::string name = - absl::StrFormat("SS = SS x RS @ {%s}", absl::StrJoin(e.mesh_dims, ",")); - const DimMap out_dim_map = {{out_batch_dim_, e.mesh_dims[0]}, - {out_out_channel_dim_, e.mesh_dims[1]}}; - MaybeAppend(name, lhs_dim_map, rhs_dim_map, out_dim_map, device_mesh_); - }; - EnumerateHalf(func); + EnumerateGeneral(split_func, 2, /*current_tensor_dim=*/0, all_mesh_dims, + /*current_dim_map=*/{}); } } // namespace From 3f3fccc498b17aa681a9fdd080142e55ca8dcbc1 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Fri, 26 Jul 2024 18:29:08 -0700 Subject: [PATCH 215/376] Remove unused header from `tsl/platform/cloud/http_request_fake.h` PiperOrigin-RevId: 656590570 --- third_party/tsl/tsl/platform/cloud/curl_http_request.h | 3 ++- third_party/tsl/tsl/platform/cloud/http_request_fake.h | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/tsl/tsl/platform/cloud/curl_http_request.h b/third_party/tsl/tsl/platform/cloud/curl_http_request.h index 490e762967f182..b5c728520dc693 100644 --- a/third_party/tsl/tsl/platform/cloud/curl_http_request.h +++ b/third_party/tsl/tsl/platform/cloud/curl_http_request.h @@ -16,11 +16,12 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_CLOUD_CURL_HTTP_REQUEST_H_ #define TENSORFLOW_TSL_PLATFORM_CLOUD_CURL_HTTP_REQUEST_H_ +#include + #include #include #include -#include #include "tsl/platform/cloud/http_request.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" diff --git a/third_party/tsl/tsl/platform/cloud/http_request_fake.h b/third_party/tsl/tsl/platform/cloud/http_request_fake.h index c3e0670d6e5a39..ea1f487516795e 100644 --- a/third_party/tsl/tsl/platform/cloud/http_request_fake.h +++ b/third_party/tsl/tsl/platform/cloud/http_request_fake.h @@ -21,7 +21,6 @@ limitations under the License. #include #include -#include #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/cloud/curl_http_request.h" #include "tsl/platform/errors.h" From 2428f121a23bbf6564badc3351a9519f139b64e4 Mon Sep 17 00:00:00 2001 From: Amit Sabne Date: Fri, 26 Jul 2024 21:52:21 -0700 Subject: [PATCH 216/376] Reverts 8bf16b4753fa0d6aae3f809a080b167547bf0074 PiperOrigin-RevId: 656643545 --- xla/service/space_to_batch_converter.cc | 489 +++---------------- xla/service/space_to_batch_converter_test.cc | 121 +---- 2 files changed, 82 insertions(+), 528 deletions(-) diff --git a/xla/service/space_to_batch_converter.cc b/xla/service/space_to_batch_converter.cc index ca5ca589dfc1c0..751b6d11dc979c 100644 --- a/xla/service/space_to_batch_converter.cc +++ b/xla/service/space_to_batch_converter.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -904,16 +903,6 @@ ConvolutionVisitor::ChangeSpatialSizeOnSpaceToBatchedShape( return activations_new; } -bool IsTrivialElementwise(HloInstruction* hlo) { - if (hlo->opcode() == HloOpcode::kFusion || hlo->opcode() == HloOpcode::kRng || - hlo->opcode() == HloOpcode::kCopy || - hlo->opcode() == HloOpcode::kConstant || - hlo->opcode() == HloOpcode::kIota || hlo->opcode() == HloOpcode::kMap) { - return false; - } - return hlo->IsElementwise(); -} - absl::StatusOr ConvolutionVisitor::Run() { for (auto conv : conv_visitor_list_) { // If we expect to see an unpropagatable op, space-to-batch may not be @@ -973,6 +962,16 @@ absl::StatusOr ConvolutionVisitor::Run() { return changed_; } +bool IsTrivialElementwise(HloInstruction* hlo) { + if (hlo->opcode() == HloOpcode::kFusion || hlo->opcode() == HloOpcode::kRng || + hlo->opcode() == HloOpcode::kCopy || + hlo->opcode() == HloOpcode::kConstant || + hlo->opcode() == HloOpcode::kIota || hlo->opcode() == HloOpcode::kMap) { + return false; + } + return hlo->IsElementwise(); +} + bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer, HloInstruction* producer) { if (IsTrivialElementwise(consumer)) { @@ -980,46 +979,13 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer, << consumer->ToString(); HloInstruction* pivot_operand = nullptr; - - std::vector operand_iteration_order(consumer->operand_count()); - absl::c_iota(operand_iteration_order, 0); - int64_t next_operand_to_check = 0; - int64_t last_known_good_operand = 0; - const int64_t operand_count = consumer->operand_count(); - while (next_operand_to_check < operand_count && - last_known_good_operand < operand_count) { - if (consumer->operand(next_operand_to_check)->opcode() == - HloOpcode::kBroadcast || - consumer->operand(next_operand_to_check)->opcode() == - HloOpcode::kConstant) { - while (last_known_good_operand < operand_count && - (consumer->operand(last_known_good_operand)->opcode() == - HloOpcode::kBroadcast || - consumer->operand(last_known_good_operand)->opcode() == - HloOpcode::kConstant)) { - last_known_good_operand++; - } - if (last_known_good_operand < operand_count) { - int64_t temp = operand_iteration_order[last_known_good_operand]; - operand_iteration_order[last_known_good_operand] = - next_operand_to_check; - operand_iteration_order[next_operand_to_check] = temp; - last_known_good_operand++; - } - } - next_operand_to_check++; - if (last_known_good_operand < next_operand_to_check) { - last_known_good_operand = next_operand_to_check; - } - } - for (int64_t i = 0; i < consumer->operand_count(); ++i) { - auto old_producer = consumer->mutable_operand(operand_iteration_order[i]); + auto old_producer = consumer->mutable_operand(i); std::vector to_transform; const bool broadcast_or_constant = (old_producer->opcode() == HloOpcode::kConstant) || - (old_producer->opcode() == HloOpcode::kBroadcast && pivot_operand && - IsBroadcastPropagatable(old_producer, pivot_operand)) || + (old_producer->opcode() == HloOpcode::kBroadcast && + IsBroadcastPropagatable(old_producer, producer)) || (consumer->IsElementwiseBinary() && old_producer->opcode() == HloOpcode::kBroadcast && IsBroadcastTree(old_producer, producer, to_transform)); @@ -1088,47 +1054,6 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer, } } - if (consumer->opcode() == HloOpcode::kBroadcast) { - // Restrict broadcast propagation. Can be extended in future. - if (consumer->user_count() != 1) { - return false; - } - - HloInstruction* user = consumer->users()[0]; - HloInstruction* pivot = nullptr; - for (auto op : user->operands()) { - if (op != consumer) { - if (!old_to_new_instrs_.contains(op)) { - VLOG(3) << "Cannot propagate on broadcast because op wasn't " - "space-to-batched " - << op->ToString(); - return false; - } - - if (pivot == nullptr) { - pivot = op; - continue; - } - - // All operands should have the same shape. - if (old_to_new_instrs_[pivot]->shape() != - old_to_new_instrs_[op]->shape()) { - VLOG(3) << "Cannot propagate on broadcast because pivot shape didn't " - "match " - << pivot->ToString(); - return false; - } - } - } - if (pivot == nullptr) { - return false; - } - VLOG(1) << "Checking if we could propagate on broadcast " - << consumer->ToString() << " pivot " << pivot->ToString(); - - return IsBroadcastPropagatable(consumer, pivot); - } - if (consumer->opcode() == HloOpcode::kConcatenate) { // Make sure all operands have been space-to-batched. for (int64_t i = 0; i < consumer->operand_count(); ++i) { @@ -1518,23 +1443,6 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer, return true; } -// Ensures that broadcast is creating dimensions before/after the operand -// dimensions. -bool BroadcastMovesAllDimensionsLinearly(HloInstruction* broadcast) { - auto broadcast_operand = broadcast->mutable_operand(0); - bool all_dims_broadcasted_linearly = - (broadcast->dimensions().size() == broadcast_operand->shape().rank()); - if (all_dims_broadcasted_linearly) { - for (int i = 1; i < broadcast->dimensions().size(); ++i) { - if (broadcast->dimensions(i) != broadcast->dimensions(i - 1) + 1) { - all_dims_broadcasted_linearly = false; - break; - } - } - } - return all_dims_broadcasted_linearly; -} - void ConvolutionVisitor::PropagateOnBroadcast(HloInstruction* consumer, HloInstruction* producer) { auto new_producer = old_to_new_instrs_[producer]; @@ -1545,15 +1453,11 @@ void ConvolutionVisitor::PropagateOnBroadcast(HloInstruction* consumer, dim_map_val[DimMapper(SpaceToBatchDimMap::kBatch)]; const int64_t old_space_dim = dim_map_val[DimMapper(SpaceToBatchDimMap::kSpace0)]; - const int64_t old_feature_dim = - dim_map_val[DimMapper(SpaceToBatchDimMap::kFeature)]; auto orig_broadcast_dims = consumer->dimensions(); bool batch_is_broadcasted = absl::c_linear_search(orig_broadcast_dims, old_batch_dim); - bool space_is_broadcasted = - absl::c_linear_search(orig_broadcast_dims, old_space_dim); const int64_t new_batch_dim = DimLookUp(permute_dims, old_batch_dim); const int64_t new_space_dim = DimLookUp(permute_dims, old_space_dim); @@ -1568,82 +1472,31 @@ void ConvolutionVisitor::PropagateOnBroadcast(HloInstruction* consumer, } } - HloInstruction* new_broadcast = nullptr; - if (space_is_broadcasted && batch_is_broadcasted) { - CHECK(BroadcastMovesAllDimensionsLinearly(consumer)); - // Note in this case the new_producer was a pivot, not the actual operand. - new_producer = old_to_new_instrs_[consumer->mutable_operand(0)]; - - const auto& dimensions = consumer->dimensions(); - const int64_t output_rank = consumer->shape().rank(); - const int64_t starting_offset = dimensions[0]; - std::vector new_broadcast_dims(dimensions.size()); - for (int i = 0; i < dimensions.size(); ++i) { - new_broadcast_dims[i] = starting_offset + i; - } - - // Start with original broadcast shape. - std::vector final_shape_dims( - consumer->shape().dimensions().begin(), - consumer->shape().dimensions().end()); - - for (int i = 0; i < dimensions.size(); ++i) { - final_shape_dims[starting_offset + i] = - new_producer->shape().dimensions(i); - } - - new_broadcast = MakeBroadcastHlo(new_producer, new_broadcast_dims, - final_shape_dims, &consumer->metadata(), - &consumer->frontend_attributes()); - - std::vector new_permute_dims(output_rank); - absl::c_iota(new_permute_dims, 0); - - for (int i = 0; i < dimensions.size(); ++i) { - new_permute_dims[starting_offset + i] = permute_dims[i] + starting_offset; - } - - std::vector dim_map(kNumMappedDims); - dim_map[DimMapper(SpaceToBatchDimMap::kBatch)] = - old_batch_dim + starting_offset; - dim_map[DimMapper(SpaceToBatchDimMap::kSpace0)] = - old_space_dim + starting_offset; - dim_map[DimMapper(SpaceToBatchDimMap::kFeature)] = - old_feature_dim + starting_offset; - - instr_to_dim_map_[consumer] = dim_map; - - instr_to_dim_permute_map_[new_broadcast] = new_permute_dims; - old_to_new_instrs_[consumer] = new_broadcast; + std::vector final_shape_dims( + new_producer->shape().dimensions().begin(), + new_producer->shape().dimensions().end()); + if (batch_is_broadcasted) { + final_shape_dims[new_batch_dim] = + producer->shape().dimensions(old_batch_dim); + final_shape_dims[new_space_dim] *= ctrl_.number_of_splits; + } - } else { - std::vector final_shape_dims( - new_producer->shape().dimensions().begin(), - new_producer->shape().dimensions().end()); - - if (batch_is_broadcasted) { - final_shape_dims[new_batch_dim] = - producer->shape().dimensions(old_batch_dim); - final_shape_dims[new_space_dim] *= ctrl_.number_of_splits; - } + std::vector broadcast_dims; + const auto& dimensions = consumer->dimensions(); + broadcast_dims.reserve(dimensions.size()); + for (auto j : dimensions) { + broadcast_dims.push_back(DimLookUp(permute_dims, j)); + } + auto new_broadcast = MakeBroadcastHlo( + consumer->mutable_operand(0), broadcast_dims, final_shape_dims, + &consumer->metadata(), &consumer->frontend_attributes()); + VLOG(1) << "Created broadcast " << new_broadcast->ToString(); - std::vector broadcast_dims; - const auto& dimensions = consumer->dimensions(); - broadcast_dims.reserve(dimensions.size()); - for (auto j : dimensions) { - broadcast_dims.push_back(DimLookUp(permute_dims, j)); - } - new_broadcast = MakeBroadcastHlo( - consumer->mutable_operand(0), broadcast_dims, final_shape_dims, - &consumer->metadata(), &consumer->frontend_attributes()); - VLOG(3) << "Created broadcast " << new_broadcast->ToString(); - - if (batch_is_broadcasted) { - new_broadcast = - MakeReshapeHlo(new_producer->shape().dimensions(), new_broadcast) - .value(); - VLOG(3) << "Created reshape of broadcast " << new_broadcast->ToString(); - } + if (batch_is_broadcasted) { + new_broadcast = + MakeReshapeHlo(new_producer->shape().dimensions(), new_broadcast) + .value(); + VLOG(2) << "Created reshape of broadcast " << new_broadcast->ToString(); } if (!map_found) { @@ -1697,64 +1550,15 @@ bool ConvolutionVisitor::IsBroadcastTree( return true; } -// old_other_op is the other operand of the elementwise op, apart from the -// broadcast. bool ConvolutionVisitor::IsBroadcastPropagatable(HloInstruction* broadcast, HloInstruction* old_other_op) { - CHECK_NE(old_other_op, nullptr); CHECK_EQ(broadcast->opcode(), HloOpcode::kBroadcast); - - CHECK(instr_to_dim_map_.contains(old_other_op)) - << "old_other_op " << old_other_op->ToString(); + CHECK(instr_to_dim_map_.contains(old_other_op)); auto result = instr_to_dim_map_[old_other_op]; - const int64_t old_space_dim = result[DimMapper(SpaceToBatchDimMap::kSpace0)]; - const int64_t old_batch_dim = result[DimMapper(SpaceToBatchDimMap::kBatch)]; - - auto broadcast_operand = broadcast->mutable_operand(0); + const int64_t space_dim = result[DimMapper(SpaceToBatchDimMap::kSpace0)]; auto broadcast_dims = broadcast->dimensions(); - - if (!broadcast_dims.empty() && - absl::c_linear_search(broadcast_dims, old_space_dim) && - absl::c_linear_search(broadcast_dims, old_batch_dim) && - old_to_new_instrs_.contains(broadcast_operand) && - old_to_new_instrs_.contains(old_other_op)) { - auto new_broadcast_operand = old_to_new_instrs_[broadcast_operand]; - auto new_other_op = old_to_new_instrs_[old_other_op]; - - auto broadcast_operand_dim_map = instr_to_dim_map_[broadcast_operand]; - const int64_t old_broadcast_operand_space_dim = - broadcast_operand_dim_map[DimMapper(SpaceToBatchDimMap::kSpace0)]; - const int64_t old_broadcast_operand_batch_dim = - broadcast_operand_dim_map[DimMapper(SpaceToBatchDimMap::kBatch)]; - - auto permute_dims_broadcast_operand = - instr_to_dim_permute_map_[new_broadcast_operand]; - auto permute_dims_other_op = instr_to_dim_permute_map_[new_other_op]; - const int64_t new_batch_dim_broadcast_operand = DimLookUp( - permute_dims_broadcast_operand, old_broadcast_operand_batch_dim); - const int64_t new_space_dim_broadcast_operand = DimLookUp( - permute_dims_broadcast_operand, old_broadcast_operand_space_dim); - const int64_t new_batch_dim_other_op = - DimLookUp(permute_dims_other_op, old_batch_dim); - const int64_t new_space_dim_other_op = - DimLookUp(permute_dims_other_op, old_space_dim); - - if (BroadcastMovesAllDimensionsLinearly(broadcast) && - new_broadcast_operand->shape().dimensions( - new_batch_dim_broadcast_operand) == - new_other_op->shape().dimensions(new_batch_dim_other_op) && - new_broadcast_operand->shape().dimensions( - new_space_dim_broadcast_operand) == - new_other_op->shape().dimensions(new_space_dim_other_op)) { - VLOG(3) << "Broadcast on both space and batch dims"; - return true; - } - } - - // Do no allow broadcast with space dim staying intact. This effectively - // catches scalar broadcasts at this point. - return !absl::c_linear_search(broadcast_dims, old_space_dim); + return !absl::c_linear_search(broadcast_dims, space_dim); } bool ConvolutionVisitor::IsOpcodeNonPropagatable(HloInstruction* consumer) { @@ -1821,8 +1625,7 @@ bool ConvolutionVisitor::SupportedOpForPropagation(HloInstruction* consumer, for (int64_t i = 0; i < consumer->operand_count(); ++i) { if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) { if (!IsBroadcastPropagatable(consumer->mutable_operand(i), producer)) { - VLOG(2) << "Could not propagate through broadcast while checking on " - "elementwise op"; + VLOG(2) << "Could not propagate through broadcast"; return false; } } @@ -1834,10 +1637,6 @@ bool ConvolutionVisitor::SupportedOpForPropagation(HloInstruction* consumer, return true; } - if (consumer->opcode() == HloOpcode::kBroadcast) { - return true; - } - if (consumer->opcode() == HloOpcode::kConcatenate) { HloInstruction* pivot_operand = nullptr; for (int64_t i = 0; i < consumer->operand_count(); ++i) { @@ -1935,10 +1734,6 @@ bool ConvolutionVisitor::SupportedOpForPropagation(HloInstruction* consumer, } if (consumer->opcode() == HloOpcode::kReduce) { - // Do not propagate through tuple outputs. - if (consumer->shape().IsTuple()) { - return false; - } // Support only the trivial case where both batch and split spatial dim are // being reduced @@ -1946,13 +1741,8 @@ bool ConvolutionVisitor::SupportedOpForPropagation(HloInstruction* consumer, auto result = instr_to_dim_map_[consumer->mutable_operand(0)]; const int64_t batch_dim = result[DimMapper(SpaceToBatchDimMap::kBatch)]; const int64_t space_dim = result[DimMapper(SpaceToBatchDimMap::kSpace0)]; - // Support the trivial case where none of the batch and split spatial dim - // are being reduced. - return !absl::c_linear_search(reduce_dims, batch_dim) && - !absl::c_linear_search(reduce_dims, space_dim); - - // Support only the trivial case where both batch and split spatial dim are - // being reduced + VLOG(1) << "Checking if reduce is supported batch_dim " << batch_dim + << " space_dim " << space_dim << " reduce " << consumer->ToString(); return absl::c_linear_search(reduce_dims, batch_dim) && absl::c_linear_search(reduce_dims, space_dim); } @@ -2215,11 +2005,6 @@ absl::StatusOr ConvolutionVisitor::Propagate(HloInstruction* consumer, } } - if (consumer->opcode() == HloOpcode::kBroadcast) { - PropagateOnBroadcast(consumer, producer); - return true; - } - if (consumer->opcode() == HloOpcode::kConcatenate) { TF_CHECK_OK(PropagateOnConcat(consumer)); return true; @@ -2287,116 +2072,16 @@ absl::StatusOr ConvolutionVisitor::Propagate(HloInstruction* consumer, } if (consumer->opcode() == HloOpcode::kReduce) { - auto reduce_dims = consumer->dimensions(); - auto dim_map_val = instr_to_dim_map_[consumer->mutable_operand(0)]; + auto new_consumer = computation->AddInstruction(consumer->Clone()); auto first_operand = old_to_new_instrs_[consumer->mutable_operand(0)]; - auto permute_dims = instr_to_dim_permute_map_[first_operand]; + auto dim_map_val = instr_to_dim_map_[consumer->mutable_operand(0)]; const int64_t old_batch_dim = dim_map_val[DimMapper(SpaceToBatchDimMap::kBatch)]; - const int64_t space_dim = - dim_map_val[DimMapper(SpaceToBatchDimMap::kSpace0)]; + auto permute_dims = instr_to_dim_permute_map_[first_operand]; const int64_t new_batch_dim = DimLookUp(permute_dims, old_batch_dim); - const int64_t new_space_dim = DimLookUp(permute_dims, space_dim); - std::vector changed_dims(consumer->dimensions().size()); - - // Support the trivial case where none of the batch and split spatial dim - // are being reduced. - if (!absl::c_linear_search(reduce_dims, old_batch_dim) && - !absl::c_linear_search(reduce_dims, space_dim)) { - for (int64_t i = 0; i < consumer->dimensions().size(); ++i) { - changed_dims[i] = DimLookUp(permute_dims, consumer->dimensions(i)); - } - // Decide where the new batch and space dims are in the output. - int64_t new_output_batch_dim = new_batch_dim; - int64_t new_output_space_dim = new_space_dim; - for (int64_t i = 0; i < consumer->dimensions().size(); ++i) { - if (changed_dims[i] < new_batch_dim) { - new_output_batch_dim--; - } - if (changed_dims[i] < new_space_dim) { - new_output_space_dim--; - } - } - - // Decide where the new batch and space dims are in the original reduce's - // output. - int64_t old_output_batch_dim = old_batch_dim; - int64_t old_output_space_dim = space_dim; - for (int64_t i = 0; i < consumer->dimensions().size(); ++i) { - if (reduce_dims[i] < old_batch_dim) { - old_output_batch_dim--; - } - if (reduce_dims[i] < space_dim) { - old_output_space_dim--; - } - } - - HloInstruction* new_consumer = nullptr; - TF_ASSIGN_OR_RETURN( - new_consumer, - MakeReduceHlo(first_operand, consumer->mutable_operand(1), - changed_dims, consumer->called_computations()[0])); - - VLOG(3) << " new_output_batch_dim " << new_output_batch_dim << " size " - << first_operand->shape().dimensions(new_batch_dim) - << " new_output_space_dim " << new_output_space_dim << " size " - << first_operand->shape().dimensions(new_space_dim); - - std::vector dim_map(kNumMappedDims); - dim_map[DimMapper(SpaceToBatchDimMap::kBatch)] = old_output_batch_dim; - dim_map[DimMapper(SpaceToBatchDimMap::kSpace0)] = old_output_space_dim; - // We don't know where the feature dim is, so set it to -1. - dim_map[DimMapper(SpaceToBatchDimMap::kFeature)] = -1; - - instr_to_dim_map_[consumer] = dim_map; - const int64_t rank = first_operand->shape().rank(); - - const int64_t output_rank = new_consumer->shape().rank(); - - // Make a map of each dim in original reduce output to input. - std::vector old_reduce_output_to_input(output_rank); - int dim_number_to_assign_old = 0; - for (int64_t i = 0; i < rank; ++i) { - if (auto it = absl::c_find(reduce_dims, i); it != reduce_dims.end()) { - continue; - } - old_reduce_output_to_input[i] = dim_number_to_assign_old++; - } - - // Make a map of each dim in new reduce output to the new input. - std::vector new_reduce_output_to_input(output_rank); - int dim_number_to_assign_new = 0; - for (int64_t i = 0; i < rank; ++i) { - if (auto it = absl::c_find(changed_dims, i); it != changed_dims.end()) { - continue; - } - new_reduce_output_to_input[i] = dim_number_to_assign_new++; - } - - std::vector new_permute_dims(output_rank); - // From the output dims to input dims mapping, figure how the old output - // dims are mapped to the new output dims. - for (int64_t i = 0; i < output_rank; ++i) { - new_permute_dims[i] = std::distance( - new_reduce_output_to_input.begin(), - absl::c_find( - new_reduce_output_to_input, - DimLookUp(permute_dims, old_reduce_output_to_input[i]))); - } - - instr_to_dim_permute_map_[new_consumer] = new_permute_dims; - old_to_new_instrs_[consumer] = new_consumer; - - // Because batch and split spatial dims are not reduced, further - // propagation is needed. - return true; - } - - HloInstruction* new_consumer = - computation->AddInstruction(consumer->Clone()); auto retval = GetSpatialDimsToSplit(consumer->mutable_operand(0)); std::vector old_spatial_dims = retval.first; std::vector new_spatial_dims = retval.second; @@ -2407,6 +2092,7 @@ absl::StatusOr ConvolutionVisitor::Propagate(HloInstruction* consumer, consumer->mutable_operand(1), new_batch_dim, new_spatial_dims, old_batch_dim, old_spatial_dims)); + std::vector changed_dims(new_consumer->dimensions().size()); for (int64_t i = 0; i < new_consumer->dimensions().size(); ++i) { changed_dims[i] = DimLookUp(permute_dims, new_consumer->dimensions(i)); } @@ -2876,7 +2562,6 @@ absl::StatusOr ConvolutionVisitor::SelectValidPortion( absl::StatusOr ConvolutionVisitor::BatchToSpace( HloInstruction* old_instr) { - VLOG(1) << "Batch to space on " << old_instr->ToString(); if (batch_to_space_map_.count(old_instr)) { CHECK_NE(batch_to_space_map_[old_instr], nullptr); return batch_to_space_map_[old_instr]; @@ -2952,7 +2637,7 @@ absl::StatusOr ConvolutionVisitor::BatchToSpace( } absl::Status ConvolutionVisitor::PropagateOnUsers(HloInstruction* old_conv) { - std::deque> propagation_worklist; + std::queue> propagation_worklist; if (old_conv->user_count() == 0) { TF_ASSIGN_OR_RETURN(HloInstruction * batch_to_space, @@ -2965,61 +2650,27 @@ absl::Status ConvolutionVisitor::PropagateOnUsers(HloInstruction* old_conv) { } int64_t iteration_count = 0; - propagation_worklist.push_front( + propagation_worklist.push( std::make_pair(old_conv, old_conv->mutable_operand(0))); while (!propagation_worklist.empty()) { auto top = propagation_worklist.front(); auto node = top.first; auto parent = top.second; - VLOG(1) << "Checking for propagation operating on " << node->ToString(); + VLOG(1) << "Traversing for propagation operating on " << node->ToString(); + propagation_worklist.pop(); - propagation_worklist.pop_front(); - HloInstructionSet unsupported_users; // Don't work on the same node again. if (old_to_new_instrs_.count(node) > 0 && iteration_count != 0) { continue; } - bool needs_further_propagation = true; + bool needs_further_propagation = true; if (iteration_count != 0) { - // If the op is unsupported for propagation, we will do batch-to-space on - // its producer. - if (!SupportedOpForPropagation(node, parent)) { - VLOG(1) << "Unsupported op found " << node->ToString() << " producer " - << node->ToString(); - unsupported_users.insert(node); - - if (!unsupported_users.empty()) { - TF_ASSIGN_OR_RETURN(HloInstruction * batch_to_space, - BatchToSpace(parent)); - for (auto user : unsupported_users) { - for (int64_t i = 0; i < user->operand_count(); ++i) { - if (user->operand(i) == parent) { - TF_CHECK_OK(user->ReplaceOperandWith(i, batch_to_space)); - } - } - } - } - continue; - } - - VLOG(3) << "Checking for CanPropagate on " << node->ToString(); - // If the op is not ready, mark as non-propagatable. - if (CanPropagate(node, parent)) { - non_propagatable_instrs_.erase(node); - } else { - VLOG(3) << "Marking user as non-propagatable " << node->ToString(); - non_propagatable_instrs_.insert(node); - continue; - } // Do the space-to-batch propagation on this node. TF_ASSIGN_OR_RETURN(needs_further_propagation, Propagate(node, parent)); } - iteration_count++; - - VLOG(1) << "Traversing for propagation operating on " << node->ToString(); // If this is the root, no room for further propagation. if (node->parent()->root_instruction() == node) { // The below case does not need going back to space. @@ -3042,18 +2693,36 @@ absl::Status ConvolutionVisitor::PropagateOnUsers(HloInstruction* old_conv) { continue; } - if (node->opcode() == HloOpcode::kReduce) { - VLOG(3) << "Module " << node->parent()->ToString(); - } - - // Insert all users into the queue. - std::vector users_to_visit; + HloInstructionSet unsupported_users; + // Insert all users into the queue, as long as the ops are supported and + // the op is ready for propagation. If the op is unsupported, do + // batch-to-space. If not ready, mark as non-propagatable. for (auto user : node->users()) { - users_to_visit.push_back(user); + if (!SupportedOpForPropagation(user, node)) { + VLOG(1) << "Unsupported op found " << user->ToString(); + unsupported_users.insert(user); + continue; + } + // If the instruction is ready for propagation, add it to the queue. + if (CanPropagate(user, node)) { + non_propagatable_instrs_.erase(user); + propagation_worklist.push(std::make_pair(user, node)); + } else { + // Mark it as non-propagatable for now, for later revisiting. + non_propagatable_instrs_.insert(user); + } } - for (auto it = users_to_visit.rbegin(); it != users_to_visit.rend(); - ++it) { - propagation_worklist.push_front(std::make_pair(*it, node)); + + if (!unsupported_users.empty()) { + TF_ASSIGN_OR_RETURN(HloInstruction * batch_to_space, + BatchToSpace(node)); + for (auto user : unsupported_users) { + for (int64_t i = 0; i < user->operand_count(); ++i) { + if (user->operand(i) == node) { + TF_CHECK_OK(user->ReplaceOperandWith(i, batch_to_space)); + } + } + } } } } diff --git a/xla/service/space_to_batch_converter_test.cc b/xla/service/space_to_batch_converter_test.cc index 5287727e5643a4..e2ed3314bc4f6f 100644 --- a/xla/service/space_to_batch_converter_test.cc +++ b/xla/service/space_to_batch_converter_test.cc @@ -115,11 +115,11 @@ TEST_F(SpaceToBatchConverterTest, SimpleBatch1WithReduceWindow) { %convolution = bf16[1,256,256,32] convolution(%p0, %p1), window={size=3x3}, dim_labels=b01f_01io->b01f %constant = bf16[3] constant({1.0, 2.0, 3.0}) - %tuple = (bf16[1,256,256,32], bf16[3]) tuple(%convolution, %constant) - %gte = bf16[1,256,256,32] get-tuple-element(%tuple), index=0 + %tuple = (bf16[1,256,256,32], bf16[3])tuple(%convolution, %constant) + ROOT %gte = bf16[1,256,256,32] get-tuple-element(%tuple), index=0 %gte2 = bf16[3]get-tuple-element(%tuple), index=1 %init = bf16[] constant(1.0) - ROOT %reduce-window = bf16[3] reduce-window(bf16[3] %gte2, bf16[] %init), + %reduce-window = bf16[3] reduce-window(bf16[3] %gte2, bf16[] %init), window={size=1}, to_apply=%adder } @@ -272,120 +272,5 @@ TEST_F(SpaceToBatchConverterTest, PropagateThroughDot) { ASSERT_TRUE(converter.Run(module.get()).value()); } -TEST_F(SpaceToBatchConverterTest, PropagateOnTrivialReduce) { - std::string hlo_string = R"( - HloModule module - - %region_1.37 (Arg_0.38: f32[], Arg_1.39: f32[]) -> f32[] { - %Arg_0.38 = f32[] parameter(0) - %Arg_1.39 = f32[] parameter(1) - ROOT %add.40 = f32[] add(f32[] %Arg_0.38, f32[] %Arg_1.39) - } - - ENTRY computation { - %p0 = bf16[7,320,800,3]{3,2,1,0} parameter(0) - %p1 = bf16[3,3,3,32]{3,2,1,0} parameter(1) - %c = f32[7,160,400,32]{3,2,1,0} convolution( %p0, %p1), - window={size=3x3 stride=2x2 pad=0_1x0_1}, dim_labels=b01f_01io->b01f - %constant.5 = f32[] constant(0) - ROOT %reduce.41 = f32[7,160,400]{2,1,0} reduce(%c, %constant.5), dimensions={3}, to_apply=%region_1.37 - } - )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_string)); - - auto computation = module->entry_computation(); - SpaceToBatchConverter converter( - SpaceToBatchController{true, true, true, true, /*number_of_splits=*/8}); - ASSERT_TRUE(converter.Run(module.get()).value()); - - HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Transpose()); - EXPECT_THAT(root->operand(0)->operand(0)->operand(0)->operand(0), - op::Reduce()); - auto new_reduce = root->operand(0)->operand(0)->operand(0)->operand(0); - // Make sure we propagated on the reduce with the larger batch size. - EXPECT_EQ(new_reduce->shape().dimensions(1), - // batch*number_of_splits - 7 * 8); -} - -TEST_F(SpaceToBatchConverterTest, DoNotPropagateOnTupleReduce) { - std::string hlo_string = R"( - HloModule module - -%minmax_func.2717 { - %lhs_value.2718 = f32[] parameter(0) - %rhs_value.2720 = f32[] parameter(2) - %compare.2722 = pred[] compare(f32[] %lhs_value.2718, f32[] %rhs_value.2720), direction=GE - %select.2723 = f32[] select(pred[] %compare.2722, f32[] %lhs_value.2718, f32[] %rhs_value.2720) - %compare.2725 = pred[] compare(f32[] %lhs_value.2718, f32[] %rhs_value.2720), direction=EQ - %lhs_index.2719 = f32[] parameter(1) - %rhs_index.2721 = f32[] parameter(3) - %minimum.2726 = f32[] minimum(f32[] %lhs_index.2719, f32[] %rhs_index.2721) - %select.2724 = f32[] select(pred[] %compare.2722, f32[] %lhs_index.2719, f32[] %rhs_index.2721) - %select.2727 = f32[] select(pred[] %compare.2725, f32[] %minimum.2726, f32[] %select.2724) - ROOT %tuple.4 = (f32[], f32[]) tuple(f32[] %select.2723, f32[] %select.2727) - } - - ENTRY computation { - %p0 = bf16[7,320,800,3]{3,2,1,0} parameter(0) - %p1 = bf16[3,3,3,32]{3,2,1,0} parameter(1) - %c = f32[7,160,400,32]{3,2,1,0} convolution( %p0, %p1), - window={size=3x3 stride=2x2 pad=0_1x0_1}, dim_labels=b01f_01io->b01f - %constant.5 = f32[] constant(0) - %constant.6 = f32[] constant(1) - ROOT %reduce.36 = (f32[7,160,400]{2,1,0}, f32[7,160,400]{2,1,0}) reduce(%c, %c, - %constant.5, %constant.6), dimensions={3}, to_apply=%minmax_func.2717 - } - )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_string)); - - auto computation = module->entry_computation(); - SpaceToBatchConverter converter( - SpaceToBatchController{true, true, true, true, /*number_of_splits=*/8}); - ASSERT_TRUE(converter.Run(module.get()).value()); - - HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Reduce()); -} - -TEST_F(SpaceToBatchConverterTest, PropagateOnBroadcast) { - std::string hlo_string = R"( - HloModule module - - %region_1.37 (Arg_0.38: f32[], Arg_1.39: f32[]) -> f32[] { - %Arg_0.38 = f32[] parameter(0) - %Arg_1.39 = f32[] parameter(1) - ROOT %add.40 = f32[] add(f32[] %Arg_0.38, f32[] %Arg_1.39) - } - - ENTRY computation { - %p0 = bf16[7,320,800,3]{3,2,1,0} parameter(0) - %p1 = bf16[3,3,3,32]{3,2,1,0} parameter(1) - %c = bf16[7,160,400,32]{3,2,1,0} convolution( %p0, %p1), - window={size=3x3 stride=2x2 pad=0_1x0_1}, dim_labels=b01f_01io->b01f - %constant.5 = f32[] constant(0) - %convert.29 = f32[7,160,400,32]{3,2,1,0} convert(%c) - %reduce.41 = f32[7,160,400]{2,1,0} reduce(%convert.29, %constant.5), dimensions={3}, to_apply=%region_1.37 - %broadcast.51 = f32[7,160,400,32]{3,2,1,0} broadcast(f32[7,160,400]{2,1,0} %reduce.41), dimensions={0,1,2} - ROOT %subtract.52 = f32[7,160,400,32]{3,2,1,0} subtract(%c, %broadcast.51) - } - )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_string)); - - SpaceToBatchConverter converter( - SpaceToBatchController{true, true, true, true, /*number_of_splits=*/8}); - ASSERT_TRUE(converter.Run(module.get()).value()); - auto computation = module->entry_computation(); - HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Transpose()); - // This means we'd propagated through the subtract. - EXPECT_THAT(root->operand(0)->operand(0)->operand(0)->operand(0), - op::Subtract()); -} - } // namespace } // namespace xla From 6512aaed3b2c533d8ae30eeb1d8b8b66700b55ee Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Fri, 26 Jul 2024 21:55:35 -0700 Subject: [PATCH 217/376] [xla:cpu] Do not crash if multiple sort operations share a comparator PiperOrigin-RevId: 656644215 --- xla/service/cpu/ir_emitter2.cc | 6 ++++++ xla/tests/sort_test.cc | 22 ++++++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/xla/service/cpu/ir_emitter2.cc b/xla/service/cpu/ir_emitter2.cc index 70644ed173a15e..eeac5bcc390993 100644 --- a/xla/service/cpu/ir_emitter2.cc +++ b/xla/service/cpu/ir_emitter2.cc @@ -546,6 +546,12 @@ absl::StatusOr IrEmitter2::EmitSortComparator( const HloInstruction* instr) { HloComputation* comparator = instr->to_apply(); + // Find if we already emitted this comparator. + auto info = absl::c_find_if(comparators_, [&](const ComparatorInfo& info) { + return info.name == comparator->name(); + }); + if (info != comparators_.end()) return *info; + // We use simple post-order schedule as we are not emitting a "real" // computation that requires buffer assignment. auto schedule = comparator->MakeInstructionPostOrder(); diff --git a/xla/tests/sort_test.cc b/xla/tests/sort_test.cc index d4e18c7891c23e..b832dbdd0df0d5 100644 --- a/xla/tests/sort_test.cc +++ b/xla/tests/sort_test.cc @@ -63,5 +63,27 @@ XLA_TEST_F(SortTest, SortDim1) { EXPECT_TRUE(RunAndCompare(hlo_text_module, ErrorSpec{0.0, 0.0})); } +XLA_TEST_F(SortTest, SortTwiceWithSameComparator) { + std::string_view hlo_text_module = R"( + HloModule sort + + compare { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT lt = pred[] compare(p0, p1), direction=LT + } + + ENTRY e { + x = f32[32,64] parameter(0) + y = f32[64,32] parameter(1) + sort_x = f32[32,64] sort(x), dimensions={0}, to_apply=compare + sort_y = f32[64,32] sort(y), dimensions={1}, to_apply=compare + ROOT tuple = (f32[32,64], f32[64,32]) tuple(sort_x, sort_y) + } + )"; + + EXPECT_TRUE(RunAndCompare(hlo_text_module, ErrorSpec{0.0, 0.0})); +} + } // namespace } // namespace xla From b4061343151d937f68b7f90ba53f918f08ab8f99 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 27 Jul 2024 03:56:46 -0700 Subject: [PATCH 218/376] Automated Code Change PiperOrigin-RevId: 656707031 --- xla/service/llvm_ir/BUILD | 44 ++++++++++++++----- xla/service/llvm_ir/alias_analysis.cc | 8 +++- xla/service/llvm_ir/alias_analysis.h | 2 + xla/service/llvm_ir/buffer_assignment_util.cc | 6 +++ xla/service/llvm_ir/buffer_assignment_util.h | 3 ++ .../llvm_ir/dynamic_update_slice_util.cc | 17 +++++++ .../llvm_ir/dynamic_update_slice_util.h | 4 ++ xla/service/llvm_ir/fused_ir_emitter.cc | 7 +-- xla/service/llvm_ir/fused_ir_emitter.h | 1 + xla/service/llvm_ir/ir_array.cc | 10 ++++- xla/service/llvm_ir/ir_array.h | 3 ++ xla/service/llvm_ir/ir_array_test.cc | 3 ++ xla/service/llvm_ir/kernel_support_library.cc | 15 +++++++ xla/service/llvm_ir/kernel_support_library.h | 6 +++ xla/service/llvm_ir/llvm_loop.cc | 10 +++-- xla/service/llvm_ir/llvm_loop.h | 3 ++ xla/service/llvm_ir/llvm_util.cc | 3 ++ xla/service/llvm_ir/loop_emitter.cc | 13 ++++-- xla/service/llvm_ir/loop_emitter.h | 4 ++ xla/service/llvm_ir/math_ops.cc | 5 +++ xla/service/llvm_ir/sort_util.cc | 9 ++-- xla/service/llvm_ir/sort_util.h | 2 + xla/service/llvm_ir/tuple_ops.cc | 9 +++- xla/service/llvm_ir/tuple_ops.h | 1 + 24 files changed, 161 insertions(+), 27 deletions(-) diff --git a/xla/service/llvm_ir/BUILD b/xla/service/llvm_ir/BUILD index c39abe9e8f5adc..0c49fbc816eedc 100644 --- a/xla/service/llvm_ir/BUILD +++ b/xla/service/llvm_ir/BUILD @@ -39,9 +39,11 @@ cc_library( deps = [ ":ir_array", ":llvm_type_conversion_util", + "//xla:shape_util", "//xla:types", "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", + "//xla/service:hlo_value", "//xla/service:logical_buffer", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -79,10 +81,12 @@ cc_library( "//xla/service:dump", "//xla/service:hlo_module_config", "//xla/service/cpu:cpu_options", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Core", "@llvm-project//llvm:Support", @@ -130,11 +134,12 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Core", + "@llvm-project//llvm:Support", "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:status", ], ) @@ -164,15 +169,14 @@ cc_library( ":ir_array", ":llvm_loop", "//xla:shape_util", - "//xla:status_macros", - "//xla:types", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:Core", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:protobuf", "@tsl//tsl/platform:statusor", ], ) @@ -184,18 +188,16 @@ cc_library( deps = [ ":ir_array", ":llvm_util", - ":tuple_ops", + ":loop_emitter", "//xla:shape_util", - "//xla:status_macros", "//xla:util", "//xla/hlo/ir:hlo", "//xla/service:elemental_ir_emitter", - "//xla/service:fusion_node_indexing_evaluation", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Core", + "@llvm-project//llvm:Support", "@llvm-project//llvm:TargetParser", - "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:statusor", ], @@ -210,13 +212,22 @@ cc_library( ":ir_array", ":llvm_util", ":loop_emitter", + "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", "//xla/service:elemental_ir_emitter", "//xla/service/cpu:backend_config_proto_cc", "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:parallel_loop_emitter", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:ir_headers", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", ], ) @@ -227,11 +238,11 @@ cc_library( deps = [ ":ir_array", ":kernel_support_library", - ":llvm_loop", ":llvm_util", ":loop_emitter", "//xla:shape_util", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:parallel_loop_emitter", "//xla/service/gpu:target_util", @@ -240,6 +251,7 @@ cc_library( "@com_google_absl//absl/types:span", "@llvm-project//llvm:Core", "@llvm-project//llvm:Support", + "@tsl//tsl/platform:errors", "@tsl//tsl/platform:status", ], ) @@ -253,8 +265,9 @@ cc_library( ":llvm_type_conversion_util", ":llvm_util", "//xla:shape_util", - "//xla:types", "//xla:xla_data_proto_cc", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Core", "@tsl//tsl/platform:logging", @@ -269,9 +282,15 @@ cc_library( ":llvm_loop", ":llvm_type_conversion_util", ":llvm_util", + "//xla/service:hlo_module_config", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:Core", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:status", ], ) @@ -280,8 +299,11 @@ cc_library( srcs = ["buffer_assignment_util.cc"], hdrs = ["buffer_assignment_util.h"], deps = [ + "//xla:literal", "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", + "//xla/service:buffer_value", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", ], ) @@ -311,7 +333,9 @@ xla_cc_test( deps = [ ":ir_array", ":llvm_util", + "//xla:shape_util", "//xla:test", + "//xla:xla_data_proto_cc", "//xla/tests:filecheck", "//xla/tests:xla_internal_test_main", "@llvm-project//llvm:Support", diff --git a/xla/service/llvm_ir/alias_analysis.cc b/xla/service/llvm_ir/alias_analysis.cc index d239c7520bd227..40941eea92e46c 100644 --- a/xla/service/llvm_ir/alias_analysis.cc +++ b/xla/service/llvm_ir/alias_analysis.cc @@ -19,9 +19,15 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "llvm/IR/MDBuilder.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/hlo_value.h" +#include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/llvm_type_conversion_util.h" #include "xla/service/logical_buffer.h" -#include "xla/types.h" +#include "xla/shape.h" +#include "xla/shape_util.h" namespace xla { namespace llvm_ir { diff --git a/xla/service/llvm_ir/alias_analysis.h b/xla/service/llvm_ir/alias_analysis.h index 91d850e61db0e1..7916871778398b 100644 --- a/xla/service/llvm_ir/alias_analysis.h +++ b/xla/service/llvm_ir/alias_analysis.h @@ -18,10 +18,12 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/str_cat.h" +#include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/buffer_assignment.h" #include "xla/service/llvm_ir/ir_array.h" +#include "xla/shape_util.h" #include "xla/types.h" namespace xla { diff --git a/xla/service/llvm_ir/buffer_assignment_util.cc b/xla/service/llvm_ir/buffer_assignment_util.cc index 8f40663f1583da..5e0f85c797df06 100644 --- a/xla/service/llvm_ir/buffer_assignment_util.cc +++ b/xla/service/llvm_ir/buffer_assignment_util.cc @@ -17,8 +17,14 @@ limitations under the License. #include +#include "absl/log/check.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/buffer_value.h" namespace xla { namespace llvm_ir { diff --git a/xla/service/llvm_ir/buffer_assignment_util.h b/xla/service/llvm_ir/buffer_assignment_util.h index c6d4c89b405faa..ca333a0ff2385a 100644 --- a/xla/service/llvm_ir/buffer_assignment_util.h +++ b/xla/service/llvm_ir/buffer_assignment_util.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef XLA_SERVICE_LLVM_IR_BUFFER_ASSIGNMENT_UTIL_H_ #define XLA_SERVICE_LLVM_IR_BUFFER_ASSIGNMENT_UTIL_H_ +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/literal.h" #include "xla/service/buffer_assignment.h" namespace xla { diff --git a/xla/service/llvm_ir/dynamic_update_slice_util.cc b/xla/service/llvm_ir/dynamic_update_slice_util.cc index 44020a857e7acc..30492cf2fbb738 100644 --- a/xla/service/llvm_ir/dynamic_update_slice_util.cc +++ b/xla/service/llvm_ir/dynamic_update_slice_util.cc @@ -17,15 +17,32 @@ limitations under the License. #include +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Value.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/layout_util.h" +#include "xla/service/buffer_assignment.h" #include "xla/service/cpu/backend_config.pb.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/parallel_loop_emitter.h" +#include "xla/service/llvm_ir/fused_ir_emitter.h" +#include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/service/llvm_ir/loop_emitter.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace llvm_ir { diff --git a/xla/service/llvm_ir/dynamic_update_slice_util.h b/xla/service/llvm_ir/dynamic_update_slice_util.h index ea70465d0b5207..dba763926f21f9 100644 --- a/xla/service/llvm_ir/dynamic_update_slice_util.h +++ b/xla/service/llvm_ir/dynamic_update_slice_util.h @@ -17,6 +17,10 @@ limitations under the License. #define XLA_SERVICE_LLVM_IR_DYNAMIC_UPDATE_SLICE_UTIL_H_ #include +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/IR/IRBuilder.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/buffer_assignment.h" #include "xla/service/elemental_ir_emitter.h" diff --git a/xla/service/llvm_ir/fused_ir_emitter.cc b/xla/service/llvm_ir/fused_ir_emitter.cc index 16b10a0eefc977..721bab4082f35c 100644 --- a/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/xla/service/llvm_ir/fused_ir_emitter.cc @@ -21,23 +21,20 @@ limitations under the License. #include "absl/status/statusor.h" #include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" #include "llvm/TargetParser/Triple.h" -#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/elemental_ir_emitter.h" -#include "xla/service/fusion_node_indexing_evaluation.h" #include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/llvm_util.h" -#include "xla/service/llvm_ir/tuple_ops.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/status_macros.h" #include "xla/util.h" -#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" diff --git a/xla/service/llvm_ir/fused_ir_emitter.h b/xla/service/llvm_ir/fused_ir_emitter.h index e3e5c8204f820f..229779fd65fc74 100644 --- a/xla/service/llvm_ir/fused_ir_emitter.h +++ b/xla/service/llvm_ir/fused_ir_emitter.h @@ -23,6 +23,7 @@ limitations under the License. #include "llvm/IR/Value.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/elemental_ir_emitter.h" +#include "xla/service/llvm_ir/loop_emitter.h" namespace xla { diff --git a/xla/service/llvm_ir/ir_array.cc b/xla/service/llvm_ir/ir_array.cc index 38833e64526b57..29a4f4b467ebf4 100644 --- a/xla/service/llvm_ir/ir_array.cc +++ b/xla/service/llvm_ir/ir_array.cc @@ -22,21 +22,29 @@ limitations under the License. #include #include -#include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" +#include "xla/layout.h" #include "xla/layout_util.h" #include "xla/permutation_util.h" #include "xla/primitive_util.h" #include "xla/service/llvm_ir/llvm_type_conversion_util.h" #include "xla/service/llvm_ir/llvm_util.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/logging.h" +#include "tsl/platform/status.h" namespace xla { namespace llvm_ir { diff --git a/xla/service/llvm_ir/ir_array.h b/xla/service/llvm_ir/ir_array.h index 4bb0404d49fbb6..9ec78b09aaac8c 100644 --- a/xla/service/llvm_ir/ir_array.h +++ b/xla/service/llvm_ir/ir_array.h @@ -22,8 +22,11 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" +#include "xla/layout.h" #include "xla/map_util.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/xla/service/llvm_ir/ir_array_test.cc b/xla/service/llvm_ir/ir_array_test.cc index 356a6f03f00d5e..63ca0d8fa30d79 100644 --- a/xla/service/llvm_ir/ir_array_test.cc +++ b/xla/service/llvm_ir/ir_array_test.cc @@ -26,8 +26,11 @@ limitations under the License. #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "xla/service/llvm_ir/llvm_util.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/filecheck.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/xla/service/llvm_ir/kernel_support_library.cc b/xla/service/llvm_ir/kernel_support_library.cc index 253630dd56c8f2..d5e72a7651d379 100644 --- a/xla/service/llvm_ir/kernel_support_library.cc +++ b/xla/service/llvm_ir/kernel_support_library.cc @@ -15,9 +15,24 @@ limitations under the License. #include "xla/service/llvm_ir/kernel_support_library.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "xla/service/hlo_module_config.h" +#include "xla/service/llvm_ir/llvm_loop.h" #include "xla/service/llvm_ir/llvm_type_conversion_util.h" #include "xla/service/llvm_ir/llvm_util.h" +#include "tsl/platform/errors.h" namespace xla { absl::Status KernelSupportLibrary::ForWithStatus( diff --git a/xla/service/llvm_ir/kernel_support_library.h b/xla/service/llvm_ir/kernel_support_library.h index 13c826abd3f3a2..77f463e2acd75a 100644 --- a/xla/service/llvm_ir/kernel_support_library.h +++ b/xla/service/llvm_ir/kernel_support_library.h @@ -18,12 +18,18 @@ limitations under the License. #include +#include "absl/log/check.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/llvm_ir/llvm_loop.h" #include "xla/service/llvm_ir/llvm_util.h" +#include "tsl/platform/status.h" namespace xla { // A thin wrapper around llvm_loop.h to make code generating structured control diff --git a/xla/service/llvm_ir/llvm_loop.cc b/xla/service/llvm_ir/llvm_loop.cc index bf5d49b78a828a..a8f635ef4552c1 100644 --- a/xla/service/llvm_ir/llvm_loop.cc +++ b/xla/service/llvm_ir/llvm_loop.cc @@ -21,12 +21,16 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/strings/str_cat.h" -#include "llvm/IR/Constants.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/Metadata.h" +#include "xla/layout_util.h" +#include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/llvm_util.h" -#include "xla/shape_util.h" -#include "xla/types.h" +#include "xla/shape.h" #include "tsl/platform/logging.h" namespace xla { diff --git a/xla/service/llvm_ir/llvm_loop.h b/xla/service/llvm_ir/llvm_loop.h index 8e5bb5b2a38a42..a8636cde658900 100644 --- a/xla/service/llvm_ir/llvm_loop.h +++ b/xla/service/llvm_ir/llvm_loop.h @@ -23,9 +23,12 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Metadata.h" #include "llvm/IR/Value.h" #include "xla/service/llvm_ir/ir_array.h" +#include "xla/shape.h" #include "xla/types.h" #include "xla/xla_data.pb.h" diff --git a/xla/service/llvm_ir/llvm_util.cc b/xla/service/llvm_ir/llvm_util.cc index 8b3973d8847e48..399c335ff387ca 100644 --- a/xla/service/llvm_ir/llvm_util.cc +++ b/xla/service/llvm_ir/llvm_util.cc @@ -25,13 +25,16 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/base/casts.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" diff --git a/xla/service/llvm_ir/loop_emitter.cc b/xla/service/llvm_ir/loop_emitter.cc index 17ae97c578b0be..07da156d1ffd27 100644 --- a/xla/service/llvm_ir/loop_emitter.cc +++ b/xla/service/llvm_ir/loop_emitter.cc @@ -20,13 +20,20 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "xla/layout_util.h" +#include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/llvm_loop.h" +#include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/status_macros.h" -#include "xla/types.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" -#include "tsl/platform/protobuf.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/xla/service/llvm_ir/loop_emitter.h b/xla/service/llvm_ir/loop_emitter.h index 40c6ee6e8c36f0..c691bbc8e6a57b 100644 --- a/xla/service/llvm_ir/loop_emitter.h +++ b/xla/service/llvm_ir/loop_emitter.h @@ -18,12 +18,16 @@ limitations under the License. #include +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/llvm_loop.h" +#include "xla/shape.h" namespace xla { namespace llvm_ir { diff --git a/xla/service/llvm_ir/math_ops.cc b/xla/service/llvm_ir/math_ops.cc index f33e8ec40bb3b8..fca7e6bd6f946b 100644 --- a/xla/service/llvm_ir/math_ops.cc +++ b/xla/service/llvm_ir/math_ops.cc @@ -15,6 +15,11 @@ limitations under the License. #include "xla/service/llvm_ir/math_ops.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" #include "xla/service/llvm_ir/llvm_util.h" namespace xla { diff --git a/xla/service/llvm_ir/sort_util.cc b/xla/service/llvm_ir/sort_util.cc index c0da4efa1fef84..1e73067eb4ccc0 100644 --- a/xla/service/llvm_ir/sort_util.cc +++ b/xla/service/llvm_ir/sort_util.cc @@ -28,20 +28,23 @@ limitations under the License. #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Value.h" -#include "xla/primitive_util.h" +#include "llvm/Support/Casting.h" +#include "xla/layout_util.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/parallel_loop_emitter.h" #include "xla/service/gpu/target_util.h" #include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/kernel_support_library.h" -#include "xla/service/llvm_ir/llvm_loop.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/service/llvm_ir/loop_emitter.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" -#include "tsl/platform/status.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" namespace xla { namespace llvm_ir { diff --git a/xla/service/llvm_ir/sort_util.h b/xla/service/llvm_ir/sort_util.h index 7cf78f9e7bc425..da0a5ada52b1d4 100644 --- a/xla/service/llvm_ir/sort_util.h +++ b/xla/service/llvm_ir/sort_util.h @@ -18,8 +18,10 @@ limitations under the License. #include +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/llvm_ir/ir_array.h" diff --git a/xla/service/llvm_ir/tuple_ops.cc b/xla/service/llvm_ir/tuple_ops.cc index ce064fefb31881..92093ffe0f0d6c 100644 --- a/xla/service/llvm_ir/tuple_ops.cc +++ b/xla/service/llvm_ir/tuple_ops.cc @@ -20,11 +20,18 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/Module.h" +#include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/llvm_type_conversion_util.h" #include "xla/service/llvm_ir/llvm_util.h" +#include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/types.h" #include "xla/xla_data.pb.h" #include "tsl/platform/logging.h" diff --git a/xla/service/llvm_ir/tuple_ops.h b/xla/service/llvm_ir/tuple_ops.h index 5506e72fd87593..8ee4941e015d05 100644 --- a/xla/service/llvm_ir/tuple_ops.h +++ b/xla/service/llvm_ir/tuple_ops.h @@ -20,6 +20,7 @@ limitations under the License. #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "xla/service/llvm_ir/ir_array.h" +#include "xla/shape.h" // Utilities for emitting LLVM IR related to HLO tuples. From 89089aa569aaff185a834f63d0c83fd59a58e93c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 27 Jul 2024 10:22:42 -0700 Subject: [PATCH 219/376] Use post order when searching for instructions to move in collective_pipeliner, such that those instructions are also cloned in post order to avoid an assert on last_cloned != nullptr. PiperOrigin-RevId: 656757358 --- xla/service/BUILD | 1 + xla/service/collective_pipeliner.cc | 6 +- xla/service/collective_pipeliner_test.cc | 97 ++++++++++++++++++++++++ 3 files changed, 102 insertions(+), 2 deletions(-) diff --git a/xla/service/BUILD b/xla/service/BUILD index 09676869ddab91..b4a2419c0a016a 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -626,6 +626,7 @@ xla_cc_test( ":collective_pipeliner", ":hlo_parser", ":hlo_pass_pipeline", + ":host_memory_offload_annotations_hdr", "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", diff --git a/xla/service/collective_pipeliner.cc b/xla/service/collective_pipeliner.cc index ccfacab79229ce..859b6c9b2540c2 100644 --- a/xla/service/collective_pipeliner.cc +++ b/xla/service/collective_pipeliner.cc @@ -1205,7 +1205,9 @@ void WhileLoopAnalysis::CollectCollectivesToMove( } int64_t count = 0; absl::flat_hash_map instruction_order; - for (auto* instr : while_body->MakeInstructionPostOrder()) { + std::vector instructions_post_order = + while_body->MakeInstructionPostOrder(); + for (auto* instr : instructions_post_order) { if (instr->opcode() == HloOpcode::kGetTupleElement) { if (index_range && instr->tuple_index() == 0) { index_ranges.insert({instr, *index_range}); @@ -1214,7 +1216,7 @@ void WhileLoopAnalysis::CollectCollectivesToMove( instruction_order[instr] = count++; } - for (auto* instr : while_body->instructions()) { + for (auto* instr : instructions_post_order) { if (direction == CollectivePipeliner::PipeliningDirection::kForward && (instr->operand_count() != 1 || instr->shape().dimensions_size() != diff --git a/xla/service/collective_pipeliner_test.cc b/xla/service/collective_pipeliner_test.cc index ab56cc1903cbcf..5492cbc582458d 100644 --- a/xla/service/collective_pipeliner_test.cc +++ b/xla/service/collective_pipeliner_test.cc @@ -39,6 +39,7 @@ limitations under the License. #include "xla/hlo/utils/hlo_matchers.h" #include "xla/service/hlo_parser.h" #include "xla/service/hlo_pass_pipeline.h" +#include "xla/service/host_memory_offload_annotations.h" #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/util.h" @@ -1344,6 +1345,102 @@ ENTRY entry { EXPECT_EQ(add_instr_loop->opcode(), HloOpcode::kAdd); } +TEST_F(CollectivePipelinerTest, + TransformIncrementIndexByOneBackwardsWithTwoDependentClones) { + constexpr absl::string_view hlo_string = R"( +HloModule module + +add { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +while_cond { + param = (s32[], bf16[3,8,128], bf16[3,1,2,128]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + constant.1 = s32[] constant(3) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT +} + +while_body { + param = (s32[], bf16[3,8,128], bf16[3,1,2,128]) parameter(0) + get-tuple-element.394 = s32[] get-tuple-element(param), index=0 + get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1 + get-tuple-element.k = bf16[3,1,2,128] get-tuple-element(param), index=2 + constant.2561 = s32[] constant(0) + constant.2557 = s32[] constant(1) + add.230 = s32[] add(get-tuple-element.394, constant.2557) + constant.2559 = s32[] constant(3) + subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394) + constant.2560 = s32[] constant(-1) + add.231 = s32[] add(subtract.139, constant.2560) + compare.747 = pred[] compare(add.231, constant.2561), direction=LT + constant.2562 = s32[] constant(2) + add.232 = s32[] add(subtract.139, constant.2562) + select.1348 = s32[] select(compare.747, add.232, add.231) + dynamic-slice.k = bf16[1,1,2,128] dynamic-slice(get-tuple-element.k, select.1348, constant.2561, constant.2561, constant.2561), dynamic_slice_sizes={1,1,2,128} + r = bf16[1,2,128] reshape(dynamic-slice.k) + // To be peeled. + custom-call = bf16[1,2,128] custom-call(r), custom_call_target="MoveToDevice" + a = bf16[1,2,128] add(custom-call, custom-call), control-predecessors={constant.2559} + // To be peeled. + ag = bf16[1,8,128] all-gather(a), dimensions={1}, replica_groups={} + dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.395, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128} + mul = bf16[1,8,128] multiply(dynamic-slice.99, ag) + ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=1 + dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, ar.1, select.1348, constant.2561, constant.2561) + ROOT tuple = (s32[], bf16[3,8,128], bf16[3,1,2,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.k), control-predecessors={a} +} + +ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[3,8,128] parameter(0) + p1 = bf16[3,1,2,128] parameter(1) + tuple = (s32[], bf16[3,8,128], bf16[3,1,2,128]) tuple(c0, p0, p1) + while = (s32[], bf16[3,8,128], bf16[3,1,2,128]) while(tuple), condition=while_cond, body=while_body + ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1 +} +)"; + auto module = ParseAndReturnUnverifiedModule(hlo_string, config_).value(); + auto is_all_gather_or_offloading = [](const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kAllGather || + instruction->IsCustomCall( + host_memory_offload_annotations::kMoveToDeviceCustomCallTarget); + }; + EXPECT_TRUE(RunOptimizer(module.get(), /*last_run=*/true, 0, + /*pipeline_use_tree=*/false, + /*process_different_sized_ops=*/false, + CollectivePipeliner::PipeliningDirection::kBackward, + is_all_gather_or_offloading) + .value()); + XLA_VLOG_LINES(1, module->ToString()); + const int64_t while_count = absl::c_count_if( + module->entry_computation()->instructions(), + [](const HloInstruction* instruction) { + return HloPredicateIsOp(instruction); + }); + EXPECT_EQ(while_count, 1); + const HloInstruction* while_instr = + FindInstruction(module.get(), HloOpcode::kWhile); + const HloInstruction* tuple = while_instr->operand(0); + EXPECT_TRUE(tuple->HasControlDependencies()); + EXPECT_EQ(tuple->control_predecessors().size(), 1); + const HloInstruction* add_instr = tuple->control_predecessors()[0]; + EXPECT_EQ(add_instr->opcode(), HloOpcode::kAdd); + const HloComputation* comp = while_instr->while_body(); + const HloInstruction* root_loop = comp->root_instruction(); + EXPECT_TRUE(root_loop->HasControlDependencies()); + EXPECT_EQ(root_loop->control_predecessors().size(), 1); + const HloInstruction* add_instr_loop = root_loop->control_predecessors()[0]; + EXPECT_EQ(add_instr_loop->opcode(), HloOpcode::kAdd); + + EXPECT_NE(FindInstruction(module.get(), "custom-call.1"), nullptr); + EXPECT_NE(FindInstruction(module.get(), "custom-call.2"), nullptr); + EXPECT_NE(FindInstruction(module.get(), "ag.1"), nullptr); + EXPECT_NE(FindInstruction(module.get(), "ag.2"), nullptr); +} + TEST_F(CollectivePipelinerTest, TransformIncrementIndexByOneBackwardsCollectivePermute) { constexpr absl::string_view hlo_string = R"( From 95e3eea8d2aebd55160ed4185a38345ae98ab500 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 28 Jul 2024 08:52:19 -0700 Subject: [PATCH 220/376] Automated Code Change PiperOrigin-RevId: 656949964 --- xla/service/cpu/ir_emitter.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/xla/service/cpu/ir_emitter.cc b/xla/service/cpu/ir_emitter.cc index 8a6b3619e3ab29..13ede90a9192ab 100644 --- a/xla/service/cpu/ir_emitter.cc +++ b/xla/service/cpu/ir_emitter.cc @@ -4159,6 +4159,7 @@ absl::Status IrEmitter::EmitTargetElementLoop( .EmitLoop(IrName(target_op, desc))); std::vector tuple_operand_ptrs; + tuple_operand_ptrs.reserve(output_arrays.size()); for (int64_t i = 0; i < output_arrays.size(); ++i) { tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); } From 109fe83abc088b02fcd7a3df217caa898ecc1bd9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 28 Jul 2024 21:33:10 -0700 Subject: [PATCH 221/376] Automated Code Change PiperOrigin-RevId: 657064543 --- .../tsl/platform/default/human_readable_json.cc | 12 ++++++------ .../tsl/tsl/platform/default/posix_file_system.cc | 14 +++++++------- .../tsl/tsl/platform/default/posix_file_system.h | 2 +- third_party/tsl/tsl/platform/default/resource.cc | 5 +++-- 4 files changed, 17 insertions(+), 16 deletions(-) diff --git a/third_party/tsl/tsl/platform/default/human_readable_json.cc b/third_party/tsl/tsl/platform/default/human_readable_json.cc index 4f70c155dd4ef3..167cdd2b891312 100644 --- a/third_party/tsl/tsl/platform/default/human_readable_json.cc +++ b/third_party/tsl/tsl/platform/default/human_readable_json.cc @@ -40,9 +40,9 @@ absl::StatusOr ProtoToHumanReadableJson( // Convert error_msg google::protobuf::StringPiece to // tsl::StringPiece. auto error_msg = status.message(); - return errors::Internal( - strings::StrCat("Could not convert proto to JSON string: ", - StringPiece(error_msg.data(), error_msg.length()))); + return errors::Internal(strings::StrCat( + "Could not convert proto to JSON string: ", + absl::string_view(error_msg.data(), error_msg.length()))); } return std::move(result); } @@ -60,9 +60,9 @@ absl::Status HumanReadableJsonToProto(const string& str, // Convert error_msg google::protobuf::StringPiece to // tsl::StringPiece. auto error_msg = status.message(); - return errors::Internal( - strings::StrCat("Could not convert JSON string to proto: ", - StringPiece(error_msg.data(), error_msg.length()))); + return errors::Internal(strings::StrCat( + "Could not convert JSON string to proto: ", + absl::string_view(error_msg.data(), error_msg.length()))); } return absl::OkStatus(); } diff --git a/third_party/tsl/tsl/platform/default/posix_file_system.cc b/third_party/tsl/tsl/platform/default/posix_file_system.cc index c87ba180197449..d1b2109823f35e 100644 --- a/third_party/tsl/tsl/platform/default/posix_file_system.cc +++ b/third_party/tsl/tsl/platform/default/posix_file_system.cc @@ -60,12 +60,12 @@ class PosixRandomAccessFile : public RandomAccessFile { } } - absl::Status Name(StringPiece* result) const override { + absl::Status Name(absl::string_view* result) const override { *result = filename_; return absl::OkStatus(); } - absl::Status Read(uint64 offset, size_t n, StringPiece* result, + absl::Status Read(uint64 offset, size_t n, absl::string_view* result, char* scratch) const override { absl::Status s; char* dst = scratch; @@ -93,7 +93,7 @@ class PosixRandomAccessFile : public RandomAccessFile { s = IOError(filename_, errno); } } - *result = StringPiece(scratch, dst - scratch); + *result = absl::string_view(scratch, dst - scratch); return s; } @@ -114,7 +114,7 @@ class PosixRandomAccessFile : public RandomAccessFile { " bytes for file reading."); } - StringPiece tmp; + absl::string_view tmp; absl::Status s = Read(offset, n, &tmp, scratch); absl::Cord tmp_cord = absl::MakeCordFromExternal( @@ -142,7 +142,7 @@ class PosixWritableFile : public WritableFile { } } - absl::Status Append(StringPiece data) override { + absl::Status Append(absl::string_view data) override { size_t r = fwrite(data.data(), 1, data.size(), file_); if (r != data.size()) { return IOError(filename_, errno); @@ -182,7 +182,7 @@ class PosixWritableFile : public WritableFile { return absl::OkStatus(); } - absl::Status Name(StringPiece* result) const override { + absl::Status Name(absl::string_view* result) const override { *result = filename_; return absl::OkStatus(); } @@ -308,7 +308,7 @@ absl::Status PosixFileSystem::GetChildren(const string& dir, } struct dirent* entry; while ((entry = readdir(d)) != nullptr) { - StringPiece basename = entry->d_name; + absl::string_view basename = entry->d_name; if ((basename != ".") && (basename != "..")) { result->push_back(entry->d_name); } diff --git a/third_party/tsl/tsl/platform/default/posix_file_system.h b/third_party/tsl/tsl/platform/default/posix_file_system.h index 877a65319c6df7..f22d4c9957197f 100644 --- a/third_party/tsl/tsl/platform/default/posix_file_system.h +++ b/third_party/tsl/tsl/platform/default/posix_file_system.h @@ -76,7 +76,7 @@ class PosixFileSystem : public FileSystem { class LocalPosixFileSystem : public PosixFileSystem { public: string TranslateName(const string& name) const override { - StringPiece scheme, host, path; + absl::string_view scheme, host, path; io::ParseURI(name, &scheme, &host, &path); return string(path); } diff --git a/third_party/tsl/tsl/platform/default/resource.cc b/third_party/tsl/tsl/platform/default/resource.cc index b57f7774963cfb..f2747160790c93 100644 --- a/third_party/tsl/tsl/platform/default/resource.cc +++ b/third_party/tsl/tsl/platform/default/resource.cc @@ -21,10 +21,11 @@ namespace tsl { class ResourceTagger::ResourceTaggerImpl { public: - ResourceTaggerImpl(StringPiece key, StringPiece value) {} + ResourceTaggerImpl(absl::string_view key, absl::string_view value) {} }; -ResourceTagger::ResourceTagger(StringPiece key, StringPiece value) {} +ResourceTagger::ResourceTagger(absl::string_view key, absl::string_view value) { +} ResourceTagger::~ResourceTagger() {} From 46bef5f8392e8867352896b6702a52517b323e44 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 28 Jul 2024 21:34:14 -0700 Subject: [PATCH 222/376] Automated Code Change PiperOrigin-RevId: 657064665 --- third_party/tsl/tsl/platform/env.cc | 12 ++-- third_party/tsl/tsl/platform/env.h | 2 +- third_party/tsl/tsl/platform/errors.h | 40 +++++------ third_party/tsl/tsl/platform/file_system.cc | 72 ++++++++++--------- third_party/tsl/tsl/platform/file_system.h | 35 ++++----- .../tsl/tsl/platform/file_system_helper.cc | 4 +- third_party/tsl/tsl/platform/fingerprint.h | 6 +- 7 files changed, 87 insertions(+), 84 deletions(-) diff --git a/third_party/tsl/tsl/platform/env.cc b/third_party/tsl/tsl/platform/env.cc index 789725e8856b94..169bd0a2695b26 100644 --- a/third_party/tsl/tsl/platform/env.cc +++ b/third_party/tsl/tsl/platform/env.cc @@ -111,7 +111,7 @@ Env::Env() : file_system_registry_(new FileSystemRegistryImpl) {} absl::Status Env::GetFileSystemForFile(const std::string& fname, FileSystem** result) { - StringPiece scheme, host, path; + absl::string_view scheme, host, path; io::ParseURI(fname, &scheme, &host, &path); FileSystem* file_system = file_system_registry_->Lookup(std::string(scheme)); if (!file_system) { @@ -231,7 +231,7 @@ bool Env::FilesExist(const std::vector& files, std::vector* status) { std::unordered_map> files_per_fs; for (const auto& file : files) { - StringPiece scheme, host, path; + absl::string_view scheme, host, path; io::ParseURI(file, &scheme, &host, &path); files_per_fs[string(scheme)].push_back(file); } @@ -486,7 +486,7 @@ absl::Status ReadFileToString(Env* env, const string& fname, string* data) { } data->resize(file_size); char* p = &*data->begin(); - StringPiece result; + absl::string_view result; s = file->Read(0, file_size, &result, p); if (!s.ok()) { data->clear(); @@ -503,7 +503,7 @@ absl::Status ReadFileToString(Env* env, const string& fname, string* data) { } absl::Status WriteStringToFile(Env* env, const string& fname, - const StringPiece& data) { + const absl::string_view& data) { std::unique_ptr file; absl::Status s = env->NewWritableFile(fname, &file); if (!s.ok()) { @@ -536,7 +536,7 @@ absl::Status FileSystemCopyFile(FileSystem* src_fs, const string& src, std::unique_ptr scratch(new char[kCopyFileBufferSize]); absl::Status s = absl::OkStatus(); while (s.ok()) { - StringPiece result; + absl::string_view result; s = src_file->Read(offset, kCopyFileBufferSize, &result, scratch.get()); if (!(s.ok() || s.code() == error::OUT_OF_RANGE)) { return s; @@ -562,7 +562,7 @@ class FileStream : public protobuf::io::ZeroCopyInputStream { absl::Status status() const { return status_; } bool Next(const void** data, int* size) override { - StringPiece result; + absl::string_view result; absl::Status s = file_->Read(pos_, kBufSize, &result, scratch_); if (result.empty()) { status_ = s; diff --git a/third_party/tsl/tsl/platform/env.h b/third_party/tsl/tsl/platform/env.h index 0952517f9b7f8c..f814e39339ecc8 100644 --- a/third_party/tsl/tsl/platform/env.h +++ b/third_party/tsl/tsl/platform/env.h @@ -640,7 +640,7 @@ absl::Status ReadFileToString(Env* env, const std::string& fname, /// A utility routine: write contents of `data` to file named `fname` /// (overwriting existing contents, if any). absl::Status WriteStringToFile(Env* env, const std::string& fname, - const StringPiece& data); + const absl::string_view& data); /// Write binary representation of "proto" to the named file. absl::Status WriteBinaryProto(Env* env, const std::string& fname, diff --git a/third_party/tsl/tsl/platform/errors.h b/third_party/tsl/tsl/platform/errors.h index 1d86af35cd3efd..9be69959661e8a 100644 --- a/third_party/tsl/tsl/platform/errors.h +++ b/third_party/tsl/tsl/platform/errors.h @@ -95,7 +95,7 @@ inline std::unordered_map GetPayloads( const absl::Status& status) { std::unordered_map payloads; status.ForEachPayload( - [&payloads](::tsl::StringPiece key, const absl::Cord& value) { + [&payloads](absl::string_view key, const absl::Cord& value) { payloads[std::string(key)] = std::string(value); }); return payloads; @@ -114,7 +114,7 @@ inline void InsertPayloads( // Copies all payloads from one Status to another. Will overwrite existing // payloads in the destination if they exist with the same key. inline void CopyPayloads(const absl::Status& from, absl::Status& to) { - from.ForEachPayload([&to](::tsl::StringPiece key, const absl::Cord& value) { + from.ForEachPayload([&to](absl::string_view key, const absl::Cord& value) { to.SetPayload(key, value); }); } @@ -122,7 +122,7 @@ inline void CopyPayloads(const absl::Status& from, absl::Status& to) { #if defined(PLATFORM_GOOGLE) // Creates a new status with the given code, message and payloads. inline absl::Status Create( - absl::StatusCode code, ::tsl::StringPiece message, + absl::StatusCode code, absl::string_view message, const std::unordered_map& payloads, absl::SourceLocation loc = absl::SourceLocation::current()) { absl::Status status(code, message, loc); @@ -131,7 +131,7 @@ inline absl::Status Create( } // Returns a new Status, replacing its message with the given. inline absl::Status CreateWithUpdatedMessage(const absl::Status& status, - ::tsl::StringPiece message) { + absl::string_view message) { auto locations = status.GetSourceLocations(); auto initial_loc = locations.empty() ? absl::SourceLocation::current() : locations[0]; @@ -206,7 +206,7 @@ absl::Status Cancelled(Args... args) { } template absl::Status CancelledWithPayloads( - const ::tsl::StringPiece& message, + const absl::string_view& message, const std::unordered_map& payloads) { return errors::Create(absl::StatusCode::kCancelled, message, payloads); } @@ -264,7 +264,7 @@ ::absl::Status InvalidArgument( } template ::absl::Status InvalidArgumentWithPayloads( - const ::tsl::StringPiece& message, + const absl::string_view& message, const std::unordered_map& payloads, absl::SourceLocation loc = absl::SourceLocation::current()) { return errors::Create(absl::StatusCode::kInvalidArgument, message, payloads, @@ -340,7 +340,7 @@ ::absl::Status NotFound( } template ::absl::Status NotFoundWithPayloads( - const ::tsl::StringPiece& message, + const absl::string_view& message, const std::unordered_map& payloads, absl::SourceLocation loc = absl::SourceLocation::current()) { return errors::Create(absl::StatusCode::kNotFound, message, payloads, loc); @@ -384,7 +384,7 @@ absl::Status AlreadyExists(Args... args) { } template absl::Status AlreadyExistsWithPayloads( - const ::tsl::StringPiece& message, + const absl::string_view& message, const std::unordered_map& payloads) { return errors::Create(absl::StatusCode::kAlreadyExists, message, payloads); } @@ -398,7 +398,7 @@ absl::Status ResourceExhausted(Args... args) { } template absl::Status ResourceExhaustedWithPayloads( - const ::tsl::StringPiece& message, + const absl::string_view& message, const std::unordered_map& payloads) { return errors::Create(absl::StatusCode::kResourceExhausted, message, payloads); @@ -413,7 +413,7 @@ absl::Status Unavailable(Args... args) { } template absl::Status UnavailableWithPayloads( - const ::tsl::StringPiece& message, + const absl::string_view& message, const std::unordered_map& payloads) { return errors::Create(absl::StatusCode::kUnavailable, message, payloads); } @@ -427,7 +427,7 @@ absl::Status FailedPrecondition(Args... args) { } template absl::Status FailedPreconditionWithPayloads( - const ::tsl::StringPiece& message, + const absl::string_view& message, const std::unordered_map& payloads) { return errors::Create(absl::StatusCode::kFailedPrecondition, message, payloads); @@ -442,7 +442,7 @@ absl::Status OutOfRange(Args... args) { } template absl::Status OutOfRangeWithPayloads( - const ::tsl::StringPiece& message, + const absl::string_view& message, const std::unordered_map& payloads) { return errors::Create(absl::StatusCode::kOutOfRange, message, payloads); } @@ -456,7 +456,7 @@ absl::Status Unimplemented(Args... args) { } template absl::Status UnimplementedWithPayloads( - const ::tsl::StringPiece& message, + const absl::string_view& message, const std::unordered_map& payloads) { return errors::Create(absl::StatusCode::kUnimplemented, message, payloads); } @@ -470,7 +470,7 @@ absl::Status Internal(Args... args) { } template absl::Status InternalWithPayloads( - const ::tsl::StringPiece& message, + const absl::string_view& message, const std::unordered_map& payloads) { return errors::Create(absl::StatusCode::kInternal, message, payloads); } @@ -484,7 +484,7 @@ absl::Status Aborted(Args... args) { } template absl::Status AbortedWithPayloads( - const ::tsl::StringPiece& message, + const absl::string_view& message, const std::unordered_map& payloads) { return errors::Create(absl::StatusCode::kAborted, message, payloads); } @@ -498,7 +498,7 @@ absl::Status DeadlineExceeded(Args... args) { } template absl::Status DeadlineExceededWithPayloads( - const ::tsl::StringPiece& message, + const absl::string_view& message, const std::unordered_map& payloads) { return errors::Create(absl::StatusCode::kDeadlineExceeded, message, payloads); } @@ -512,7 +512,7 @@ absl::Status DataLoss(Args... args) { } template absl::Status DataLossWithPayloads( - const ::tsl::StringPiece& message, + const absl::string_view& message, const std::unordered_map& payloads) { return errors::Create(absl::StatusCode::kDataLoss, message, payloads); } @@ -526,7 +526,7 @@ absl::Status Unknown(Args... args) { } template absl::Status UnknownPayloads( - const ::tsl::StringPiece& message, + const absl::string_view& message, const std::unordered_map& payloads) { return errors::Create(absl::StatusCode::kUnknown, message, payloads); } @@ -539,7 +539,7 @@ absl::Status PermissionDenied(Args... args) { } template absl::Status PermissionDeniedWithPayloads( - const ::tsl::StringPiece& message, + const absl::string_view& message, const std::unordered_map& payloads) { return errors::Create(absl::StatusCode::kPermissionDenied, message, payloads); } @@ -553,7 +553,7 @@ absl::Status Unauthenticated(Args... args) { } template absl::Status UnauthenticatedWithPayloads( - const ::tsl::StringPiece& message, + const absl::string_view& message, const std::unordered_map& payloads) { return errors::Create(absl::StatusCode::kUnauthenticated, message, payloads); } diff --git a/third_party/tsl/tsl/platform/file_system.cc b/third_party/tsl/tsl/platform/file_system.cc index ee385af7354074..453e04b3942e8a 100644 --- a/third_party/tsl/tsl/platform/file_system.cc +++ b/third_party/tsl/tsl/platform/file_system.cc @@ -67,7 +67,7 @@ string FileSystem::TranslateName(const string& name) const { if (name.empty()) return name; // Otherwise, properly separate the URI components and clean the path one - StringPiece scheme, host, path; + absl::string_view scheme, host, path; this->ParseURI(name, &scheme, &host, &path); // If `path` becomes empty, return `/` (`file://` should be `/`), not `.`. @@ -195,9 +195,9 @@ absl::Status FileSystem::DeleteRecursively(const string& dirname, absl::Status FileSystem::RecursivelyCreateDir(const string& dirname, TransactionToken* token) { - StringPiece scheme, host, remaining_dir; + absl::string_view scheme, host, remaining_dir; this->ParseURI(dirname, &scheme, &host, &remaining_dir); - std::vector sub_dirs; + std::vector sub_dirs; while (!remaining_dir.empty()) { std::string current_entry = this->CreateURI(scheme, host, remaining_dir); absl::Status exists_status = FileExists(current_entry); @@ -218,7 +218,7 @@ absl::Status FileSystem::RecursivelyCreateDir(const string& dirname, return exists_status; } // Basename returns "" for / ending dirs. - if (!str_util::EndsWith(remaining_dir, "/")) { + if (!absl::EndsWith(remaining_dir, "/")) { sub_dirs.push_back(this->Basename(remaining_dir)); } remaining_dir = this->Dirname(remaining_dir); @@ -229,7 +229,7 @@ absl::Status FileSystem::RecursivelyCreateDir(const string& dirname, // Now create the directories. string built_path(remaining_dir); - for (const StringPiece sub_dir : sub_dirs) { + for (const absl::string_view sub_dir : sub_dirs) { built_path = this->JoinPath(built_path, sub_dir); absl::Status status = CreateDir(this->CreateURI(scheme, host, built_path)); if (!status.ok() && status.code() != absl::StatusCode::kAlreadyExists) { @@ -246,10 +246,11 @@ absl::Status FileSystem::CopyFile(const string& src, const string& target, char FileSystem::Separator() const { return '/'; } -string FileSystem::JoinPathImpl(std::initializer_list paths) { +string FileSystem::JoinPathImpl( + std::initializer_list paths) { string result; - for (StringPiece path : paths) { + for (absl::string_view path : paths) { if (path.empty()) continue; if (result.empty()) { @@ -275,9 +276,9 @@ string FileSystem::JoinPathImpl(std::initializer_list paths) { return result; } -std::pair FileSystem::SplitPath( - StringPiece uri) const { - StringPiece scheme, host, path; +std::pair FileSystem::SplitPath( + absl::string_view uri) const { + absl::string_view scheme, host, path; ParseURI(uri, &scheme, &host, &path); // We have 3 cases of results from `ParseURI`: @@ -305,7 +306,7 @@ std::pair FileSystem::SplitPath( // Case 1 above if (path.empty()) { - return std::make_pair(StringPiece(), StringPiece()); + return std::make_pair(absl::string_view(), absl::string_view()); } size_t pos = path.rfind(this->Separator()); @@ -325,54 +326,54 @@ std::pair FileSystem::SplitPath( #endif // Handle the case with no SEP in 'path'. - if (pos == StringPiece::npos) { + if (pos == absl::string_view::npos) { if (host.empty()) { // Case 3 above, `uri` and `path` point to the same thing // We are returning all of the `path` as basename here. - return std::make_pair(StringPiece(), path); + return std::make_pair(absl::string_view(), path); } // Safe to do this arithmetic here, we are in case 2 above - return std::make_pair(StringPiece(uri.data(), host.end() - uri.begin()), - path); + return std::make_pair( + absl::string_view(uri.data(), host.end() - uri.begin()), path); } // Handle the case with a single leading '/' in 'path'. if (pos == 0) { return std::make_pair( - StringPiece(uri.data(), path.begin() + 1 - uri.begin()), - StringPiece(path.data() + 1, path.size() - 1)); + absl::string_view(uri.data(), path.begin() + 1 - uri.begin()), + absl::string_view(path.data() + 1, path.size() - 1)); } return std::make_pair( - StringPiece(uri.data(), path.begin() + pos - uri.begin()), - StringPiece(path.data() + pos + 1, path.size() - (pos + 1))); + absl::string_view(uri.data(), path.begin() + pos - uri.begin()), + absl::string_view(path.data() + pos + 1, path.size() - (pos + 1))); } -bool FileSystem::IsAbsolutePath(StringPiece path) const { +bool FileSystem::IsAbsolutePath(absl::string_view path) const { return !path.empty() && path[0] == '/'; } -StringPiece FileSystem::Dirname(StringPiece path) const { +absl::string_view FileSystem::Dirname(absl::string_view path) const { return this->SplitPath(path).first; } -StringPiece FileSystem::Basename(StringPiece path) const { +absl::string_view FileSystem::Basename(absl::string_view path) const { return this->SplitPath(path).second; } -StringPiece FileSystem::Extension(StringPiece path) const { - StringPiece basename = this->Basename(path); +absl::string_view FileSystem::Extension(absl::string_view path) const { + absl::string_view basename = this->Basename(path); size_t pos = basename.rfind('.'); - if (pos == StringPiece::npos) { - return StringPiece(path.data() + path.size(), 0); + if (pos == absl::string_view::npos) { + return absl::string_view(path.data() + path.size(), 0); } else { - return StringPiece(path.data() + pos + 1, path.size() - (pos + 1)); + return absl::string_view(path.data() + pos + 1, path.size() - (pos + 1)); } } -string FileSystem::CleanPath(StringPiece unclean_path) const { +string FileSystem::CleanPath(absl::string_view unclean_path) const { string path(unclean_path); const char* src = path.c_str(); string::iterator dst = path.begin(); @@ -453,8 +454,9 @@ string FileSystem::CleanPath(StringPiece unclean_path) const { return path; } -void FileSystem::ParseURI(StringPiece remaining, StringPiece* scheme, - StringPiece* host, StringPiece* path) const { +void FileSystem::ParseURI(absl::string_view remaining, + absl::string_view* scheme, absl::string_view* host, + absl::string_view* path) const { // 0. Parse scheme // Make sure scheme matches [a-zA-Z][0-9a-zA-Z.]* // TODO(keveman): Allow "+" and "-" in the scheme. @@ -466,8 +468,8 @@ void FileSystem::ParseURI(StringPiece remaining, StringPiece* scheme, .OneLiteral("://") .GetResult(&remaining, scheme)) { // If there's no scheme, assume the entire string is a path. - *scheme = StringPiece(); - *host = StringPiece(); + *scheme = absl::string_view(); + *host = absl::string_view(); *path = remaining; return; } @@ -476,7 +478,7 @@ void FileSystem::ParseURI(StringPiece remaining, StringPiece* scheme, if (!strings::Scanner(remaining).ScanUntil('/').GetResult(&remaining, host)) { // No path, so the rest of the URI is the host. *host = remaining; - *path = StringPiece(); + *path = absl::string_view(); return; } @@ -484,8 +486,8 @@ void FileSystem::ParseURI(StringPiece remaining, StringPiece* scheme, *path = remaining; } -string FileSystem::CreateURI(StringPiece scheme, StringPiece host, - StringPiece path) const { +string FileSystem::CreateURI(absl::string_view scheme, absl::string_view host, + absl::string_view path) const { if (scheme.empty()) { return string(path); } diff --git a/third_party/tsl/tsl/platform/file_system.h b/third_party/tsl/tsl/platform/file_system.h index 67209ed491055f..8b48788261368e 100644 --- a/third_party/tsl/tsl/platform/file_system.h +++ b/third_party/tsl/tsl/platform/file_system.h @@ -410,25 +410,26 @@ class FileSystem { /// \brief Split a path to its basename and dirname. /// /// Helper function for Basename and Dirname. - std::pair SplitPath(StringPiece uri) const; + std::pair SplitPath( + absl::string_view uri) const; /// \brief returns the final file name in the given path. /// /// Returns the part of the path after the final "/". If there is no /// "/" in the path, the result is the same as the input. - virtual StringPiece Basename(StringPiece path) const; + virtual absl::string_view Basename(absl::string_view path) const; /// \brief Returns the part of the path before the final "/". /// /// If there is a single leading "/" in the path, the result will be the /// leading "/". If there is no "/" in the path, the result is the empty /// prefix of the input. - StringPiece Dirname(StringPiece path) const; + absl::string_view Dirname(absl::string_view path) const; /// \brief Returns the part of the basename of path after the final ".". /// /// If there is no "." in the basename, the result is empty. - StringPiece Extension(StringPiece path) const; + absl::string_view Extension(absl::string_view path) const; /// \brief Clean duplicate and trailing, "/"s, and resolve ".." and ".". /// @@ -436,16 +437,16 @@ class FileSystem { /// invoke any system calls (getcwd(2)) in order to resolve relative /// paths with respect to the actual working directory. That is, this is /// purely string manipulation, completely independent of process state. - std::string CleanPath(StringPiece path) const; + std::string CleanPath(absl::string_view path) const; /// \brief Creates a URI from a scheme, host, and path. /// /// If the scheme is empty, we just return the path. - std::string CreateURI(StringPiece scheme, StringPiece host, - StringPiece path) const; + std::string CreateURI(absl::string_view scheme, absl::string_view host, + absl::string_view path) const; /// \brief Return true if path is absolute. - bool IsAbsolutePath(tsl::StringPiece path) const; + bool IsAbsolutePath(absl::string_view path) const; #ifndef SWIG // variadic templates /// \brief Join multiple paths together. @@ -469,7 +470,7 @@ class FileSystem { } #endif /* SWIG */ - std::string JoinPathImpl(std::initializer_list paths); + std::string JoinPathImpl(std::initializer_list paths); /// \brief Populates the scheme, host, and path from a URI. /// @@ -481,8 +482,8 @@ class FileSystem { /// passed string is assumed to be a path /// - If the URI omits the path (e.g. file://host), then the path is left /// empty. - void ParseURI(StringPiece remaining, StringPiece* scheme, StringPiece* host, - StringPiece* path) const; + void ParseURI(absl::string_view remaining, absl::string_view* scheme, + absl::string_view* host, absl::string_view* path) const; // Transaction related API @@ -710,7 +711,7 @@ class WrappedFileSystem : public FileSystem { char Separator() const override { return fs_->Separator(); } - StringPiece Basename(StringPiece path) const override { + absl::string_view Basename(absl::string_view path) const override { return fs_->Basename(path); } @@ -761,7 +762,7 @@ class RandomAccessFile { /// /// This is an optional operation that may not be implemented by every /// filesystem. - virtual absl::Status Name(StringPiece* result) const { + virtual absl::Status Name(absl::string_view* result) const { return errors::Unimplemented("This filesystem does not support Name()"); } @@ -780,7 +781,7 @@ class RandomAccessFile { /// because of EOF. /// /// Safe for concurrent use by multiple threads. - virtual absl::Status Read(uint64 offset, size_t n, StringPiece* result, + virtual absl::Status Read(uint64 offset, size_t n, absl::string_view* result, char* scratch) const = 0; #if defined(TF_CORD_SUPPORT) @@ -807,12 +808,12 @@ class WritableFile { virtual ~WritableFile() = default; /// \brief Append 'data' to the file. - virtual absl::Status Append(StringPiece data) = 0; + virtual absl::Status Append(absl::string_view data) = 0; #if defined(TF_CORD_SUPPORT) // \brief Append 'data' to the file. virtual absl::Status Append(const absl::Cord& cord) { - for (StringPiece chunk : cord.Chunks()) { + for (absl::string_view chunk : cord.Chunks()) { TF_RETURN_IF_ERROR(Append(chunk)); } return absl::OkStatus(); @@ -844,7 +845,7 @@ class WritableFile { /// /// This is an optional operation that may not be implemented by every /// filesystem. - virtual absl::Status Name(StringPiece* result) const { + virtual absl::Status Name(absl::string_view* result) const { return errors::Unimplemented("This filesystem does not support Name()"); } diff --git a/third_party/tsl/tsl/platform/file_system_helper.cc b/third_party/tsl/tsl/platform/file_system_helper.cc index 04dc6a0420516c..bfbea9808675e2 100644 --- a/third_party/tsl/tsl/platform/file_system_helper.cc +++ b/third_party/tsl/tsl/platform/file_system_helper.cc @@ -79,7 +79,7 @@ static std::string PatchPattern(const std::string& pattern) { static std::vector AllDirectoryPrefixes(const std::string& d) { std::vector dirs; const std::string patched = PatchPattern(d); - StringPiece dir(patched); + absl::string_view dir(patched); // If the pattern ends with a `/` (or `\\` on Windows), we need to strip it // otherwise we would have one additional matching step and the result set @@ -94,7 +94,7 @@ static std::vector AllDirectoryPrefixes(const std::string& d) { while (!dir.empty()) { dirs.emplace_back(dir); - StringPiece new_dir(io::Dirname(dir)); + absl::string_view new_dir(io::Dirname(dir)); // io::Dirname("/") returns "/" so we need to break the loop. // On Windows, io::Dirname("C:\\") would return "C:\\", so we check for // identity of the result instead of checking for dir[0] == `/`. diff --git a/third_party/tsl/tsl/platform/fingerprint.h b/third_party/tsl/tsl/platform/fingerprint.h index bb961fd89c1742..b5be7200332e41 100644 --- a/third_party/tsl/tsl/platform/fingerprint.h +++ b/third_party/tsl/tsl/platform/fingerprint.h @@ -78,7 +78,7 @@ inline uint64_t FingerprintCat64(const uint64_t fp1, const uint64_t fp2) { // This is a portable fingerprint interface for strings that will never change. // However, it is not suitable for cryptography. -inline uint64_t Fingerprint64(const tsl::StringPiece s) { +inline uint64_t Fingerprint64(const absl::string_view s) { #ifdef USE_OSS_FARMHASH return ::util::Fingerprint64(s.data(), s.size()); #else @@ -92,7 +92,7 @@ inline uint64_t Fingerprint64(const tsl::StringPiece s) { } // 32-bit variant of Fingerprint64 above (same properties and caveats apply). -inline uint32_t Fingerprint32(const tsl::StringPiece s) { +inline uint32_t Fingerprint32(const absl::string_view s) { #ifdef USE_OSS_FARMHASH return ::util::Fingerprint32(s.data(), s.size()); #else @@ -101,7 +101,7 @@ inline uint32_t Fingerprint32(const tsl::StringPiece s) { } // 128-bit variant of Fingerprint64 above (same properties and caveats apply). -inline Fprint128 Fingerprint128(const tsl::StringPiece s) { +inline Fprint128 Fingerprint128(const absl::string_view s) { #ifdef USE_OSS_FARMHASH const auto fingerprint = ::util::Fingerprint128(s.data(), s.size()); return {::util::Uint128Low64(fingerprint), From 167344095b39a039eb09742efd2da36a9e02d89d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 28 Jul 2024 21:35:44 -0700 Subject: [PATCH 223/376] Automated Code Change PiperOrigin-RevId: 657064902 --- third_party/tsl/tsl/platform/hash.h | 6 +- third_party/tsl/tsl/platform/hash_test.cc | 16 +- third_party/tsl/tsl/platform/numbers.cc | 22 +-- third_party/tsl/tsl/platform/numbers.h | 30 ++-- third_party/tsl/tsl/platform/numbers_test.cc | 26 ++-- third_party/tsl/tsl/platform/path.cc | 70 +++++---- third_party/tsl/tsl/platform/path.h | 24 +-- third_party/tsl/tsl/platform/path_test.cc | 14 +- .../tsl/tsl/platform/ram_file_system.h | 8 +- third_party/tsl/tsl/platform/resource.h | 2 +- .../tsl/tsl/platform/retrying_file_system.h | 8 +- .../tsl/platform/retrying_file_system_test.cc | 18 +-- third_party/tsl/tsl/platform/scanner.cc | 5 +- third_party/tsl/tsl/platform/scanner.h | 18 ++- third_party/tsl/tsl/platform/scanner_test.cc | 30 ++-- third_party/tsl/tsl/platform/status_test.cc | 4 +- third_party/tsl/tsl/platform/str_util.cc | 24 +-- third_party/tsl/tsl/platform/str_util.h | 34 +++-- third_party/tsl/tsl/platform/str_util_test.cc | 143 +++++++++--------- third_party/tsl/tsl/platform/strcat.cc | 15 +- third_party/tsl/tsl/platform/strcat.h | 14 +- third_party/tsl/tsl/platform/strcat_test.cc | 4 +- .../tsl/tsl/platform/stringpiece_test.cc | 14 +- third_party/tsl/tsl/platform/tracing.cc | 2 +- third_party/tsl/tsl/platform/tracing.h | 4 +- third_party/tsl/tsl/platform/tstring.h | 20 +-- third_party/tsl/tsl/platform/tstring_test.cc | 10 +- 27 files changed, 298 insertions(+), 287 deletions(-) diff --git a/third_party/tsl/tsl/platform/hash.h b/third_party/tsl/tsl/platform/hash.h index d8d676a72d3b04..2e18b440a263d3 100644 --- a/third_party/tsl/tsl/platform/hash.h +++ b/third_party/tsl/tsl/platform/hash.h @@ -107,12 +107,12 @@ struct hash { }; template <> -struct hash { - size_t operator()(StringPiece sp) const { +struct hash { + size_t operator()(absl::string_view sp) const { return static_cast(Hash64(sp.data(), sp.size())); } }; -using StringPieceHasher = ::tsl::hash; +using StringPieceHasher = ::tsl::hash; template struct hash> { diff --git a/third_party/tsl/tsl/platform/hash_test.cc b/third_party/tsl/tsl/platform/hash_test.cc index 665bd5ec7e9ec7..7b4752e729107c 100644 --- a/third_party/tsl/tsl/platform/hash_test.cc +++ b/third_party/tsl/tsl/platform/hash_test.cc @@ -87,10 +87,10 @@ BENCHMARK(BM_Hash32)->Range(1, 1024); TEST(StringPieceHasher, Equality) { StringPieceHasher hasher; - StringPiece s1("foo"); - StringPiece s2("bar"); - StringPiece s3("baz"); - StringPiece s4("zot"); + absl::string_view s1("foo"); + absl::string_view s2("bar"); + absl::string_view s3("baz"); + absl::string_view s4("zot"); EXPECT_TRUE(hasher(s1) != hasher(s2)); EXPECT_TRUE(hasher(s1) != hasher(s3)); @@ -110,11 +110,11 @@ TEST(StringPieceHasher, HashMap) { string s2("bar"); string s3("baz"); - StringPiece p1(s1); - StringPiece p2(s2); - StringPiece p3(s3); + absl::string_view p1(s1); + absl::string_view p2(s2); + absl::string_view p3(s3); - std::unordered_map map; + std::unordered_map map; map.insert(std::make_pair(p1, 0)); map.insert(std::make_pair(p2, 1)); diff --git a/third_party/tsl/tsl/platform/numbers.cc b/third_party/tsl/tsl/platform/numbers.cc index c8eb0ab0a441b7..7239e6fff7a51d 100644 --- a/third_party/tsl/tsl/platform/numbers.cc +++ b/third_party/tsl/tsl/platform/numbers.cc @@ -220,16 +220,16 @@ size_t DoubleToBuffer(double value, char* buffer) { } namespace { -char SafeFirstChar(StringPiece str) { +char SafeFirstChar(absl::string_view str) { if (str.empty()) return '\0'; return str[0]; } -void SkipSpaces(StringPiece* str) { +void SkipSpaces(absl::string_view* str) { while (isspace(SafeFirstChar(*str))) str->remove_prefix(1); } } // namespace -bool safe_strto64(StringPiece str, int64_t* value) { +bool safe_strto64(absl::string_view str, int64_t* value) { SkipSpaces(&str); int64_t vlimit = kint64max; @@ -270,7 +270,7 @@ bool safe_strto64(StringPiece str, int64_t* value) { return true; } -bool safe_strtou64(StringPiece str, uint64_t* value) { +bool safe_strtou64(absl::string_view str, uint64_t* value) { SkipSpaces(&str); if (!isdigit(SafeFirstChar(str))) return false; @@ -291,7 +291,7 @@ bool safe_strtou64(StringPiece str, uint64_t* value) { return true; } -bool safe_strto32(StringPiece str, int32_t* value) { +bool safe_strto32(absl::string_view str, int32_t* value) { SkipSpaces(&str); int64_t vmax = kint32max; @@ -321,7 +321,7 @@ bool safe_strto32(StringPiece str, int32_t* value) { return true; } -bool safe_strtou32(StringPiece str, uint32_t* value) { +bool safe_strtou32(absl::string_view str, uint32_t* value) { SkipSpaces(&str); if (!isdigit(SafeFirstChar(str))) return false; @@ -341,7 +341,7 @@ bool safe_strtou32(StringPiece str, uint32_t* value) { return true; } -bool safe_strtof(StringPiece str, float* value) { +bool safe_strtof(absl::string_view str, float* value) { int processed_characters_count = -1; auto len = str.size(); @@ -354,7 +354,7 @@ bool safe_strtof(StringPiece str, float* value) { return processed_characters_count > 0; } -bool safe_strtod(StringPiece str, double* value) { +bool safe_strtod(absl::string_view str, double* value) { int processed_characters_count = -1; auto len = str.size(); @@ -417,7 +417,7 @@ bool StringToFp(const std::string& s, Fprint* fp) { } } -StringPiece Uint64ToHexString(uint64_t v, char* buf) { +absl::string_view Uint64ToHexString(uint64_t v, char* buf) { static const char* hexdigits = "0123456789abcdef"; const int num_byte = 16; buf[num_byte] = '\0'; @@ -425,10 +425,10 @@ StringPiece Uint64ToHexString(uint64_t v, char* buf) { buf[i] = hexdigits[v & 0xf]; v >>= 4; } - return StringPiece(buf, num_byte); + return absl::string_view(buf, num_byte); } -bool HexStringToUint64(const StringPiece& s, uint64_t* result) { +bool HexStringToUint64(const absl::string_view& s, uint64_t* result) { uint64_t v = 0; if (s.empty()) { return false; diff --git a/third_party/tsl/tsl/platform/numbers.h b/third_party/tsl/tsl/platform/numbers.h index ca480a04e0d5a9..0d62f425361927 100644 --- a/third_party/tsl/tsl/platform/numbers.h +++ b/third_party/tsl/tsl/platform/numbers.h @@ -85,66 +85,66 @@ bool StringToFp(const std::string& s, Fprint* fp); // Convert a 64-bit fingerprint value to an ASCII representation that // is terminated by a '\0'. // Buf must point to an array of at least kFastToBufferSize characters -StringPiece Uint64ToHexString(uint64_t v, char* buf); +absl::string_view Uint64ToHexString(uint64_t v, char* buf); // Attempt to parse a uint64 in the form encoded by FastUint64ToHexString. If // successful, stores the value in *v and returns true. Otherwise, // returns false. -bool HexStringToUint64(const StringPiece& s, uint64_t* result); +bool HexStringToUint64(const absl::string_view& s, uint64_t* result); // Convert strings to 32bit integer values. // Leading and trailing spaces are allowed. // Return false with overflow or invalid input. -bool safe_strto32(StringPiece str, int32_t* value); +bool safe_strto32(absl::string_view str, int32_t* value); // Convert strings to unsigned 32bit integer values. // Leading and trailing spaces are allowed. // Return false with overflow or invalid input. -bool safe_strtou32(StringPiece str, uint32_t* value); +bool safe_strtou32(absl::string_view str, uint32_t* value); // Convert strings to 64bit integer values. // Leading and trailing spaces are allowed. // Return false with overflow or invalid input. -bool safe_strto64(StringPiece str, int64_t* value); +bool safe_strto64(absl::string_view str, int64_t* value); // Convert strings to unsigned 64bit integer values. // Leading and trailing spaces are allowed. // Return false with overflow or invalid input. -bool safe_strtou64(StringPiece str, uint64_t* value); +bool safe_strtou64(absl::string_view str, uint64_t* value); // Convert strings to floating point values. // Leading and trailing spaces are allowed. // Values may be rounded on over- and underflow. // Returns false on invalid input or if `strlen(value) >= kFastToBufferSize`. -bool safe_strtof(StringPiece str, float* value); +bool safe_strtof(absl::string_view str, float* value); // Convert strings to double precision floating point values. // Leading and trailing spaces are allowed. // Values may be rounded on over- and underflow. // Returns false on invalid input or if `strlen(value) >= kFastToBufferSize`. -bool safe_strtod(StringPiece str, double* value); +bool safe_strtod(absl::string_view str, double* value); -inline bool ProtoParseNumeric(StringPiece s, int32_t* value) { +inline bool ProtoParseNumeric(absl::string_view s, int32_t* value) { return safe_strto32(s, value); } -inline bool ProtoParseNumeric(StringPiece s, uint32_t* value) { +inline bool ProtoParseNumeric(absl::string_view s, uint32_t* value) { return safe_strtou32(s, value); } -inline bool ProtoParseNumeric(StringPiece s, int64_t* value) { +inline bool ProtoParseNumeric(absl::string_view s, int64_t* value) { return safe_strto64(s, value); } -inline bool ProtoParseNumeric(StringPiece s, uint64_t* value) { +inline bool ProtoParseNumeric(absl::string_view s, uint64_t* value) { return safe_strtou64(s, value); } -inline bool ProtoParseNumeric(StringPiece s, float* value) { +inline bool ProtoParseNumeric(absl::string_view s, float* value) { return safe_strtof(s, value); } -inline bool ProtoParseNumeric(StringPiece s, double* value) { +inline bool ProtoParseNumeric(absl::string_view s, double* value) { return safe_strtod(s, value); } @@ -152,7 +152,7 @@ inline bool ProtoParseNumeric(StringPiece s, double* value) { // Leading and trailing spaces are allowed. // Values may be rounded on over- and underflow. template -bool SafeStringToNumeric(StringPiece s, T* value) { +bool SafeStringToNumeric(absl::string_view s, T* value) { return ProtoParseNumeric(s, value); } diff --git a/third_party/tsl/tsl/platform/numbers_test.cc b/third_party/tsl/tsl/platform/numbers_test.cc index c4386d5d01f8f1..0ce574e597dea9 100644 --- a/third_party/tsl/tsl/platform/numbers_test.cc +++ b/third_party/tsl/tsl/platform/numbers_test.cc @@ -48,7 +48,7 @@ TEST(Uint64ToHexString, Ints) { for (int delta = -1; delta <= 1; delta++) { uint64 fp = (1ull << s) + delta; char buf[kFastToBufferSize]; - StringPiece s = Uint64ToHexString(fp, buf); + absl::string_view s = Uint64ToHexString(fp, buf); uint64 fp2; EXPECT_TRUE(HexStringToUint64(s, &fp2)); EXPECT_EQ(fp, fp2) << s; @@ -145,11 +145,11 @@ TEST(safe_strto32, Int32s) { EXPECT_EQ(false, safe_strto32("-2147483649", &result)); // Check that the StringPiece's length is respected. - EXPECT_EQ(true, safe_strto32(StringPiece("123", 1), &result)); + EXPECT_EQ(true, safe_strto32(absl::string_view("123", 1), &result)); EXPECT_EQ(1, result); - EXPECT_EQ(true, safe_strto32(StringPiece(" -123", 4), &result)); + EXPECT_EQ(true, safe_strto32(absl::string_view(" -123", 4), &result)); EXPECT_EQ(-12, result); - EXPECT_EQ(false, safe_strto32(StringPiece(nullptr, 0), &result)); + EXPECT_EQ(false, safe_strto32(absl::string_view(nullptr, 0), &result)); } TEST(safe_strtou32, UInt32s) { @@ -178,11 +178,11 @@ TEST(safe_strtou32, UInt32s) { EXPECT_FALSE(safe_strtou32("-1", &result)); // Check that the StringPiece's length is respected. - EXPECT_TRUE(safe_strtou32(StringPiece("123", 1), &result)); + EXPECT_TRUE(safe_strtou32(absl::string_view("123", 1), &result)); EXPECT_EQ(1, result); - EXPECT_TRUE(safe_strtou32(StringPiece(" 123", 3), &result)); + EXPECT_TRUE(safe_strtou32(absl::string_view(" 123", 3), &result)); EXPECT_EQ(12, result); - EXPECT_FALSE(safe_strtou32(StringPiece(nullptr, 0), &result)); + EXPECT_FALSE(safe_strtou32(absl::string_view(nullptr, 0), &result)); } TEST(safe_strto64, Int64s) { @@ -214,11 +214,11 @@ TEST(safe_strto64, Int64s) { EXPECT_EQ(false, safe_strto64("-9223372036854775809", &result)); // Check that the StringPiece's length is respected. - EXPECT_EQ(true, safe_strto64(StringPiece("123", 1), &result)); + EXPECT_EQ(true, safe_strto64(absl::string_view("123", 1), &result)); EXPECT_EQ(1, result); - EXPECT_EQ(true, safe_strto64(StringPiece(" -123", 4), &result)); + EXPECT_EQ(true, safe_strto64(absl::string_view(" -123", 4), &result)); EXPECT_EQ(-12, result); - EXPECT_EQ(false, safe_strto64(StringPiece(nullptr, 0), &result)); + EXPECT_EQ(false, safe_strto64(absl::string_view(nullptr, 0), &result)); } TEST(safe_strtou64, UInt64s) { @@ -249,11 +249,11 @@ TEST(safe_strtou64, UInt64s) { EXPECT_FALSE(safe_strtou64("-1", &result)); // Check that the StringPiece's length is respected. - EXPECT_TRUE(safe_strtou64(StringPiece("123", 1), &result)); + EXPECT_TRUE(safe_strtou64(absl::string_view("123", 1), &result)); EXPECT_EQ(1, result); - EXPECT_TRUE(safe_strtou64(StringPiece(" 123", 3), &result)); + EXPECT_TRUE(safe_strtou64(absl::string_view(" 123", 3), &result)); EXPECT_EQ(12, result); - EXPECT_FALSE(safe_strtou64(StringPiece(nullptr, 0), &result)); + EXPECT_FALSE(safe_strtou64(absl::string_view(nullptr, 0), &result)); } TEST(safe_strtof, Float) { diff --git a/third_party/tsl/tsl/platform/path.cc b/third_party/tsl/tsl/platform/path.cc index 580aacde900c1a..1d808f122eee76 100644 --- a/third_party/tsl/tsl/platform/path.cc +++ b/third_party/tsl/tsl/platform/path.cc @@ -46,10 +46,10 @@ namespace { const char kPathSep[] = "/"; } // namespace -string JoinPathImpl(std::initializer_list paths) { +string JoinPathImpl(std::initializer_list paths) { string result; - for (StringPiece path : paths) { + for (absl::string_view path : paths) { if (path.empty()) continue; if (result.empty()) { @@ -73,8 +73,9 @@ string JoinPathImpl(std::initializer_list paths) { // no "/" in the path, the first part of the output is the scheme and host, and // the second is the path. If the only "/" in the path is the first character, // it is included in the first part of the output. -std::pair SplitPath(StringPiece uri) { - StringPiece scheme, host, path; +std::pair SplitPath( + absl::string_view uri) { + absl::string_view scheme, host, path; ParseURI(uri, &scheme, &host, &path); auto pos = path.rfind('/'); @@ -82,58 +83,60 @@ std::pair SplitPath(StringPiece uri) { if (pos == StringPiece::npos) pos = path.rfind('\\'); #endif // Handle the case with no '/' in 'path'. - if (pos == StringPiece::npos) - return std::make_pair(StringPiece(uri.data(), host.end() - uri.begin()), - path); + if (pos == absl::string_view::npos) + return std::make_pair( + absl::string_view(uri.data(), host.end() - uri.begin()), path); // Handle the case with a single leading '/' in 'path'. if (pos == 0) return std::make_pair( - StringPiece(uri.data(), path.begin() + 1 - uri.begin()), - StringPiece(path.data() + 1, path.size() - 1)); + absl::string_view(uri.data(), path.begin() + 1 - uri.begin()), + absl::string_view(path.data() + 1, path.size() - 1)); return std::make_pair( - StringPiece(uri.data(), path.begin() + pos - uri.begin()), - StringPiece(path.data() + pos + 1, path.size() - (pos + 1))); + absl::string_view(uri.data(), path.begin() + pos - uri.begin()), + absl::string_view(path.data() + pos + 1, path.size() - (pos + 1))); } // Return the parts of the basename of path, split on the final ".". // If there is no "." in the basename or "." is the final character in the // basename, the second value will be empty. -std::pair SplitBasename(StringPiece path) { +std::pair SplitBasename( + absl::string_view path) { path = Basename(path); auto pos = path.rfind('.'); - if (pos == StringPiece::npos) - return std::make_pair(path, StringPiece(path.data() + path.size(), 0)); + if (pos == absl::string_view::npos) + return std::make_pair(path, + absl::string_view(path.data() + path.size(), 0)); return std::make_pair( - StringPiece(path.data(), pos), - StringPiece(path.data() + pos + 1, path.size() - (pos + 1))); + absl::string_view(path.data(), pos), + absl::string_view(path.data() + pos + 1, path.size() - (pos + 1))); } } // namespace internal -bool IsAbsolutePath(StringPiece path) { +bool IsAbsolutePath(absl::string_view path) { return !path.empty() && path[0] == '/'; } -StringPiece Dirname(StringPiece path) { +absl::string_view Dirname(absl::string_view path) { return internal::SplitPath(path).first; } -StringPiece Basename(StringPiece path) { +absl::string_view Basename(absl::string_view path) { return internal::SplitPath(path).second; } -StringPiece Extension(StringPiece path) { +absl::string_view Extension(absl::string_view path) { return internal::SplitBasename(path).second; } -StringPiece BasenamePrefix(StringPiece path) { +absl::string_view BasenamePrefix(absl::string_view path) { return internal::SplitBasename(path).first; } -string CleanPath(StringPiece unclean_path) { +string CleanPath(absl::string_view unclean_path) { string path(unclean_path); const char* src = path.c_str(); string::iterator dst = path.begin(); @@ -214,8 +217,8 @@ string CleanPath(StringPiece unclean_path) { return path; } -void ParseURI(StringPiece uri, StringPiece* scheme, StringPiece* host, - StringPiece* path) { +void ParseURI(absl::string_view uri, absl::string_view* scheme, + absl::string_view* host, absl::string_view* path) { // 0. Parse scheme // Make sure scheme matches [a-zA-Z][0-9a-zA-Z.]* // TODO(keveman): Allow "+" and "-" in the scheme. @@ -228,8 +231,8 @@ void ParseURI(StringPiece uri, StringPiece* scheme, StringPiece* host, .OneLiteral("://") .GetResult(&uri, scheme)) { // If there's no scheme, assume the entire string is a path. - *scheme = StringPiece(uri.data(), 0); - *host = StringPiece(uri.data(), 0); + *scheme = absl::string_view(uri.data(), 0); + *host = absl::string_view(uri.data(), 0); *path = uri; return; } @@ -238,7 +241,7 @@ void ParseURI(StringPiece uri, StringPiece* scheme, StringPiece* host, if (!strings::Scanner(uri).ScanUntil('/').GetResult(&uri, host)) { // No path, so the rest of the URI is the host. *host = uri; - *path = StringPiece(); // empty path + *path = absl::string_view(); // empty path return; } @@ -246,7 +249,8 @@ void ParseURI(StringPiece uri, StringPiece* scheme, StringPiece* host, *path = uri; } -string CreateURI(StringPiece scheme, StringPiece host, StringPiece path) { +string CreateURI(absl::string_view scheme, absl::string_view host, + absl::string_view path) { if (scheme.empty()) { return string(path); } @@ -352,8 +356,8 @@ string GetTempFilename(const string& extension) { namespace { // This is private to the file, because it's possibly too limited to be useful // externally. -bool StartsWithSegment(tsl::StringPiece path, tsl::StringPiece segment) { - return tsl::str_util::StartsWith(path, segment) && +bool StartsWithSegment(absl::string_view path, absl::string_view segment) { + return absl::StartsWith(path, segment) && (path.size() == segment.size() || path.at(segment.size()) == internal::kPathSep[0]); } @@ -385,9 +389,9 @@ bool GetTestUndeclaredOutputsDir(string* dir) { return true; } -bool ResolveTestPrefixes(tsl::StringPiece path, string& resolved_path) { - constexpr tsl::StringPiece kTestWorkspaceSegment = "TEST_WORKSPACE"; - constexpr tsl::StringPiece kOutputDirSegment = "TEST_UNDECLARED_OUTPUTS_DIR"; +bool ResolveTestPrefixes(absl::string_view path, string& resolved_path) { + constexpr absl::string_view kTestWorkspaceSegment = "TEST_WORKSPACE"; + constexpr absl::string_view kOutputDirSegment = "TEST_UNDECLARED_OUTPUTS_DIR"; if (StartsWithSegment(path, kTestWorkspaceSegment)) { if (!GetTestWorkspaceDir(&resolved_path)) { diff --git a/third_party/tsl/tsl/platform/path.h b/third_party/tsl/tsl/platform/path.h index f0a5b87d135c2a..dd5567a3792e6c 100644 --- a/third_party/tsl/tsl/platform/path.h +++ b/third_party/tsl/tsl/platform/path.h @@ -24,7 +24,7 @@ limitations under the License. namespace tsl { namespace io { namespace internal { -std::string JoinPathImpl(std::initializer_list paths); +std::string JoinPathImpl(std::initializer_list paths); } // Utility routines for processing filenames @@ -51,24 +51,24 @@ std::string JoinPath(const T&... args) { #endif /* SWIG */ // Return true if path is absolute. -bool IsAbsolutePath(tsl::StringPiece path); +bool IsAbsolutePath(absl::string_view path); // Returns the part of the path before the final "/". If there is a single // leading "/" in the path, the result will be the leading "/". If there is // no "/" in the path, the result is the empty prefix of the input. -tsl::StringPiece Dirname(tsl::StringPiece path); +absl::string_view Dirname(absl::string_view path); // Returns the part of the path after the final "/". If there is no // "/" in the path, the result is the same as the input. -tsl::StringPiece Basename(tsl::StringPiece path); +absl::string_view Basename(absl::string_view path); // Returns the part of the basename of path after the final ".". If // there is no "." in the basename, the result is empty. -tsl::StringPiece Extension(tsl::StringPiece path); +absl::string_view Extension(absl::string_view path); // Returns the part of the basename of path before the final ".". If // there is no "." in the basename, the result is empty. -tsl::StringPiece BasenamePrefix(tsl::StringPiece path); +absl::string_view BasenamePrefix(absl::string_view path); // Returns the largest common subpath of `paths`. // @@ -86,7 +86,7 @@ std::string CommonPathPrefix(absl::Span paths); // invoke any system calls (getcwd(2)) in order to resolve relative // paths with respect to the actual working directory. That is, this is purely // string manipulation, completely independent of process state. -std::string CleanPath(tsl::StringPiece path); +std::string CleanPath(absl::string_view path); // Populates the scheme, host, and path from a URI. scheme, host, and path are // guaranteed by this function to point into the contents of uri, even if @@ -96,13 +96,13 @@ std::string CleanPath(tsl::StringPiece path); // - If the URI is invalid, scheme and host are set to empty strings and the // passed string is assumed to be a path // - If the URI omits the path (e.g. file://host), then the path is left empty. -void ParseURI(tsl::StringPiece uri, tsl::StringPiece* scheme, - tsl::StringPiece* host, tsl::StringPiece* path); +void ParseURI(absl::string_view uri, absl::string_view* scheme, + absl::string_view* host, absl::string_view* path); // Creates a URI from a scheme, host, and path. If the scheme is empty, we just // return the path. -std::string CreateURI(tsl::StringPiece scheme, tsl::StringPiece host, - tsl::StringPiece path); +std::string CreateURI(absl::string_view scheme, absl::string_view host, + absl::string_view path); // Creates a temporary file name with an extension. std::string GetTempFilename(const std::string& extension); @@ -124,7 +124,7 @@ bool GetTestUndeclaredOutputsDir(std::string* dir); // // Currently the TEST_WORKSPACE and the TEST_UNDECLARED_OUTPUTS_DIR prefixes can // be resolved. -bool ResolveTestPrefixes(tsl::StringPiece path, std::string& resolved_path); +bool ResolveTestPrefixes(absl::string_view path, std::string& resolved_path); // Appends `.exe` if `PLATFORM_WINDOWS` is defined. [[maybe_unused]] std::string& AppendDotExeIfWindows(std::string& path); diff --git a/third_party/tsl/tsl/platform/path_test.cc b/third_party/tsl/tsl/platform/path_test.cc index 306470e0141c55..ec43b631cf61cb 100644 --- a/third_party/tsl/tsl/platform/path_test.cc +++ b/third_party/tsl/tsl/platform/path_test.cc @@ -160,7 +160,7 @@ TEST(PathTest, CommonPathPrefix) { } TEST(PathTest, GetTestWorkspaceDir) { - constexpr tsl::StringPiece kOriginalValue = "original value"; + constexpr absl::string_view kOriginalValue = "original value"; std::string dir; dir = kOriginalValue; @@ -193,7 +193,7 @@ TEST(PathTest, GetTestWorkspaceDir) { } TEST(PathTest, GetTestUndeclaredOutputsDir) { - constexpr tsl::StringPiece kOriginalValue = "original value"; + constexpr absl::string_view kOriginalValue = "original value"; std::string dir; dir = kOriginalValue; @@ -211,7 +211,7 @@ TEST(PathTest, GetTestUndeclaredOutputsDir) { } TEST(PathTest, ResolveTestPrefixesKeepsThePathUnchanged) { - constexpr tsl::StringPiece kOriginalValue = "original value"; + constexpr absl::string_view kOriginalValue = "original value"; std::string resolved_path; resolved_path = kOriginalValue; @@ -232,7 +232,7 @@ TEST(PathTest, ResolveTestPrefixesKeepsThePathUnchanged) { } TEST(PathTest, ResolveTestPrefixesCanResolveTestWorkspace) { - constexpr tsl::StringPiece kOriginalValue = "original value"; + constexpr absl::string_view kOriginalValue = "original value"; std::string resolved_path; tsl::setenv("TEST_SRCDIR", "/repo/src", /*overwrite=*/true); @@ -260,7 +260,7 @@ TEST(PathTest, ResolveTestPrefixesCanResolveTestWorkspace) { } TEST(PathTest, ResolveTestPrefixesCannotResolveTestWorkspace) { - constexpr tsl::StringPiece kOriginalValue = "original value"; + constexpr absl::string_view kOriginalValue = "original value"; std::string resolved_path; tsl::unsetenv("TEST_SRCDIR"); @@ -272,7 +272,7 @@ TEST(PathTest, ResolveTestPrefixesCannotResolveTestWorkspace) { } TEST(PathTest, ResolveTestPrefixesCanResolveTestUndeclaredOutputsDir) { - constexpr tsl::StringPiece kOriginalValue = "original value"; + constexpr absl::string_view kOriginalValue = "original value"; std::string resolved_path; tsl::setenv("TEST_UNDECLARED_OUTPUTS_DIR", "/test/outputs", @@ -305,7 +305,7 @@ TEST(PathTest, ResolveTestPrefixesCanResolveTestUndeclaredOutputsDir) { } TEST(PathTest, ResolveTestPrefixesCannotResolveTestUndeclaredOutputsDir) { - constexpr tsl::StringPiece kOriginalValue = "original value"; + constexpr absl::string_view kOriginalValue = "original value"; std::string resolved_path; tsl::unsetenv("TEST_UNDECLARED_OUTPUTS_DIR"); diff --git a/third_party/tsl/tsl/platform/ram_file_system.h b/third_party/tsl/tsl/platform/ram_file_system.h index 245eacfe465daa..861b0666648266 100644 --- a/third_party/tsl/tsl/platform/ram_file_system.h +++ b/third_party/tsl/tsl/platform/ram_file_system.h @@ -49,12 +49,12 @@ class RamRandomAccessFile : public RandomAccessFile, public WritableFile { : name_(name), data_(cord) {} ~RamRandomAccessFile() override {} - absl::Status Name(StringPiece* result) const override { + absl::Status Name(absl::string_view* result) const override { *result = name_; return absl::OkStatus(); } - absl::Status Read(uint64 offset, size_t n, StringPiece* result, + absl::Status Read(uint64 offset, size_t n, absl::string_view* result, char* scratch) const override { if (offset >= data_->size()) { return errors::OutOfRange(""); @@ -65,7 +65,7 @@ class RamRandomAccessFile : public RandomAccessFile, public WritableFile { auto end = data_->begin() + offset + left; std::copy(start, end, scratch); - *result = StringPiece(scratch, left); + *result = absl::string_view(scratch, left); // In case of a partial read, we must still fill `result`, but also return // OutOfRange. @@ -75,7 +75,7 @@ class RamRandomAccessFile : public RandomAccessFile, public WritableFile { return absl::OkStatus(); } - absl::Status Append(StringPiece data) override { + absl::Status Append(absl::string_view data) override { data_->append(data.data(), data.size()); return absl::OkStatus(); } diff --git a/third_party/tsl/tsl/platform/resource.h b/third_party/tsl/tsl/platform/resource.h index 8d96cf336007b2..19a567076f5dce 100644 --- a/third_party/tsl/tsl/platform/resource.h +++ b/third_party/tsl/tsl/platform/resource.h @@ -25,7 +25,7 @@ namespace tsl { // ResourceTagger objects should only be allocated on the stack. class ResourceTagger { public: - ResourceTagger(StringPiece key, StringPiece value); + ResourceTagger(absl::string_view key, absl::string_view value); ~ResourceTagger(); // Do not allow copying or moving ResourceTagger diff --git a/third_party/tsl/tsl/platform/retrying_file_system.h b/third_party/tsl/tsl/platform/retrying_file_system.h index a64ecc20e960ff..1eb8da393d3eb5 100644 --- a/third_party/tsl/tsl/platform/retrying_file_system.h +++ b/third_party/tsl/tsl/platform/retrying_file_system.h @@ -190,11 +190,11 @@ class RetryingRandomAccessFile : public RandomAccessFile { const RetryConfig& retry_config) : base_file_(std::move(base_file)), retry_config_(retry_config) {} - absl::Status Name(StringPiece* result) const override { + absl::Status Name(absl::string_view* result) const override { return base_file_->Name(result); } - absl::Status Read(uint64 offset, size_t n, StringPiece* result, + absl::Status Read(uint64 offset, size_t n, absl::string_view* result, char* scratch) const override { return RetryingUtils::CallWithRetries( [this, offset, n, result, scratch]() { @@ -219,7 +219,7 @@ class RetryingWritableFile : public WritableFile { Close().IgnoreError(); } - absl::Status Append(StringPiece data) override { + absl::Status Append(absl::string_view data) override { return RetryingUtils::CallWithRetries( [this, &data]() { return base_file_->Append(data); }, retry_config_); } @@ -231,7 +231,7 @@ class RetryingWritableFile : public WritableFile { return RetryingUtils::CallWithRetries( [this]() { return base_file_->Flush(); }, retry_config_); } - absl::Status Name(StringPiece* result) const override { + absl::Status Name(absl::string_view* result) const override { return base_file_->Name(result); } absl::Status Sync() override { diff --git a/third_party/tsl/tsl/platform/retrying_file_system_test.cc b/third_party/tsl/tsl/platform/retrying_file_system_test.cc index 522c59f565e0b1..8477cdb353e21f 100644 --- a/third_party/tsl/tsl/platform/retrying_file_system_test.cc +++ b/third_party/tsl/tsl/platform/retrying_file_system_test.cc @@ -62,10 +62,10 @@ class MockCallSequence { class MockRandomAccessFile : public RandomAccessFile { public: explicit MockRandomAccessFile(const ExpectedCalls& calls) : calls_(calls) {} - absl::Status Name(StringPiece* result) const override { + absl::Status Name(absl::string_view* result) const override { return calls_.ConsumeNextCall("Name"); } - absl::Status Read(uint64 offset, size_t n, StringPiece* result, + absl::Status Read(uint64 offset, size_t n, absl::string_view* result, char* scratch) const override { return calls_.ConsumeNextCall("Read"); } @@ -77,12 +77,12 @@ class MockRandomAccessFile : public RandomAccessFile { class MockWritableFile : public WritableFile { public: explicit MockWritableFile(const ExpectedCalls& calls) : calls_(calls) {} - absl::Status Append(StringPiece data) override { + absl::Status Append(absl::string_view data) override { return calls_.ConsumeNextCall("Append"); } absl::Status Close() override { return calls_.ConsumeNextCall("Close"); } absl::Status Flush() override { return calls_.ConsumeNextCall("Flush"); } - absl::Status Name(StringPiece* result) const override { + absl::Status Name(absl::string_view* result) const override { return calls_.ConsumeNextCall("Name"); } absl::Status Sync() override { return calls_.ConsumeNextCall("Sync"); } @@ -220,7 +220,7 @@ TEST(RetryingFileSystemTest, NewRandomAccessFile_ImmediateSuccess) { fs.NewRandomAccessFile("filename.txt", nullptr, &random_access_file)); // Use it and check the results. - StringPiece result; + absl::string_view result; TF_EXPECT_OK(random_access_file->Name(&result)); EXPECT_EQ(result, ""); @@ -252,7 +252,7 @@ TEST(RetryingFileSystemTest, NewRandomAccessFile_SuccessWith3rdTry) { fs.NewRandomAccessFile("filename.txt", nullptr, &random_access_file)); // Use it and check the results. - StringPiece result; + absl::string_view result; char scratch[10]; TF_EXPECT_OK(random_access_file->Read(0, 10, &result, scratch)); } @@ -278,7 +278,7 @@ TEST(RetryingFileSystemTest, NewRandomAccessFile_AllRetriesFailed) { fs.NewRandomAccessFile("filename.txt", nullptr, &random_access_file)); // Use it and check the results. - StringPiece result; + absl::string_view result; char scratch[10]; const auto& status = random_access_file->Read(0, 10, &result, scratch); EXPECT_TRUE(absl::StrContains(status.message(), "Retriable error #10")) @@ -309,7 +309,7 @@ TEST(RetryingFileSystemTest, NewRandomAccessFile_NoRetriesForSomeErrors) { fs.NewRandomAccessFile("filename.txt", nullptr, &random_access_file)); // Use it and check the results. - StringPiece result; + absl::string_view result; char scratch[10]; EXPECT_EQ("Failed precondition", random_access_file->Read(0, 10, &result, scratch).message()); @@ -337,7 +337,7 @@ TEST(RetryingFileSystemTest, NewWritableFile_ImmediateSuccess) { std::unique_ptr writable_file; TF_EXPECT_OK(fs.NewWritableFile("filename.txt", nullptr, &writable_file)); - StringPiece result; + absl::string_view result; TF_EXPECT_OK(writable_file->Name(&result)); EXPECT_EQ(result, ""); diff --git a/third_party/tsl/tsl/platform/scanner.cc b/third_party/tsl/tsl/platform/scanner.cc index e5e74032270d9b..fe208678dcefd4 100644 --- a/third_party/tsl/tsl/platform/scanner.cc +++ b/third_party/tsl/tsl/platform/scanner.cc @@ -41,7 +41,8 @@ void Scanner::ScanUntilImpl(char end_ch, bool escaped) { } } -bool Scanner::GetResult(StringPiece* remaining, StringPiece* capture) { +bool Scanner::GetResult(absl::string_view* remaining, + absl::string_view* capture) { if (error_) { return false; } @@ -50,7 +51,7 @@ bool Scanner::GetResult(StringPiece* remaining, StringPiece* capture) { } if (capture != nullptr) { const char* end = capture_end_ == nullptr ? cur_.data() : capture_end_; - *capture = StringPiece(capture_start_, end - capture_start_); + *capture = absl::string_view(capture_start_, end - capture_start_); } return true; } diff --git a/third_party/tsl/tsl/platform/scanner.h b/third_party/tsl/tsl/platform/scanner.h index 2a53d57320cbe5..d8be6caade08c3 100644 --- a/third_party/tsl/tsl/platform/scanner.h +++ b/third_party/tsl/tsl/platform/scanner.h @@ -63,7 +63,9 @@ class Scanner { RANGLE, }; - explicit Scanner(StringPiece source) : cur_(source) { RestartCapture(); } + explicit Scanner(absl::string_view source) : cur_(source) { + RestartCapture(); + } // Consume the next character of the given class from input. If the next // character is not in the class, then GetResult will ultimately return false. @@ -77,15 +79,15 @@ class Scanner { // Consume the next s.size() characters of the input, if they match . If // they don't match , this is a no-op. - Scanner& ZeroOrOneLiteral(StringPiece s) { - str_util::ConsumePrefix(&cur_, s); + Scanner& ZeroOrOneLiteral(absl::string_view s) { + absl::ConsumePrefix(&cur_, s); return *this; } // Consume the next s.size() characters of the input, if they match . If // they don't match , then GetResult will ultimately return false. - Scanner& OneLiteral(StringPiece s) { - if (!str_util::ConsumePrefix(&cur_, s)) { + Scanner& OneLiteral(absl::string_view s) { + if (!absl::ConsumePrefix(&cur_, s)) { error_ = true; } return *this; @@ -161,8 +163,8 @@ class Scanner { // Returns true if the input string successfully matched. When true is // returned, the remaining string is returned in and the captured // string returned in , if non-NULL. - bool GetResult(StringPiece* remaining = nullptr, - StringPiece* capture = nullptr); + bool GetResult(absl::string_view* remaining = nullptr, + absl::string_view* capture = nullptr); private: void ScanUntilImpl(char end_ch, bool escaped); @@ -230,7 +232,7 @@ class Scanner { return false; } - StringPiece cur_; + absl::string_view cur_; const char* capture_start_ = nullptr; const char* capture_end_ = nullptr; bool error_ = false; diff --git a/third_party/tsl/tsl/platform/scanner_test.cc b/third_party/tsl/tsl/platform/scanner_test.cc index e05ad7121b524b..36681fa0496ff5 100644 --- a/third_party/tsl/tsl/platform/scanner_test.cc +++ b/third_party/tsl/tsl/platform/scanner_test.cc @@ -36,7 +36,7 @@ class ScannerTest : public ::testing::Test { }; TEST_F(ScannerTest, Any) { - StringPiece remaining, match; + absl::string_view remaining, match; EXPECT_TRUE(Scanner(" horse0123") .Any(Scanner::SPACE) .Any(Scanner::DIGIT) @@ -63,7 +63,7 @@ TEST_F(ScannerTest, Any) { } TEST_F(ScannerTest, AnySpace) { - StringPiece remaining, match; + absl::string_view remaining, match; EXPECT_TRUE(Scanner(" a b ") .AnySpace() .One(Scanner::LETTER) @@ -74,7 +74,7 @@ TEST_F(ScannerTest, AnySpace) { } TEST_F(ScannerTest, AnyEscapedNewline) { - StringPiece remaining, match; + absl::string_view remaining, match; EXPECT_TRUE(Scanner("\\\n") .Any(Scanner::LETTER_DIGIT_UNDERSCORE) .GetResult(&remaining, &match)); @@ -83,7 +83,7 @@ TEST_F(ScannerTest, AnyEscapedNewline) { } TEST_F(ScannerTest, AnyEmptyString) { - StringPiece remaining, match; + absl::string_view remaining, match; EXPECT_TRUE(Scanner("") .Any(Scanner::LETTER_DIGIT_UNDERSCORE) .GetResult(&remaining, &match)); @@ -99,7 +99,7 @@ TEST_F(ScannerTest, Eos) { } TEST_F(ScannerTest, Many) { - StringPiece remaining, match; + absl::string_view remaining, match; EXPECT_TRUE(Scanner("abc").Many(Scanner::LETTER).GetResult()); EXPECT_FALSE(Scanner("0").Many(Scanner::LETTER).GetResult()); EXPECT_FALSE(Scanner("").Many(Scanner::LETTER).GetResult()); @@ -115,7 +115,7 @@ TEST_F(ScannerTest, Many) { } TEST_F(ScannerTest, One) { - StringPiece remaining, match; + absl::string_view remaining, match; EXPECT_TRUE(Scanner("abc").One(Scanner::LETTER).GetResult()); EXPECT_FALSE(Scanner("0").One(Scanner::LETTER).GetResult()); EXPECT_FALSE(Scanner("").One(Scanner::LETTER).GetResult()); @@ -137,7 +137,7 @@ TEST_F(ScannerTest, OneLiteral) { } TEST_F(ScannerTest, ScanUntil) { - StringPiece remaining, match; + absl::string_view remaining, match; EXPECT_TRUE(Scanner(R"(' \1 \2 \3 \' \\'rest)") .OneLiteral("'") .ScanUntil('\'') @@ -164,7 +164,7 @@ TEST_F(ScannerTest, ScanUntil) { } TEST_F(ScannerTest, ScanEscapedUntil) { - StringPiece remaining, match; + absl::string_view remaining, match; EXPECT_TRUE(Scanner(R"(' \1 \2 \3 \' \\'rest)") .OneLiteral("'") .ScanEscapedUntil('\'') @@ -184,7 +184,7 @@ TEST_F(ScannerTest, ScanEscapedUntil) { } TEST_F(ScannerTest, ZeroOrOneLiteral) { - StringPiece remaining, match; + absl::string_view remaining, match; EXPECT_TRUE( Scanner("abc").ZeroOrOneLiteral("abC").GetResult(&remaining, &match)); EXPECT_EQ("abc", remaining); @@ -205,7 +205,7 @@ TEST_F(ScannerTest, ZeroOrOneLiteral) { // Test output of GetResult (including the forms with optional params), // and that it can be called multiple times. TEST_F(ScannerTest, CaptureAndGetResult) { - StringPiece remaining, match; + absl::string_view remaining, match; Scanner scan(" first second"); EXPECT_TRUE(scan.Any(Scanner::SPACE) @@ -238,7 +238,7 @@ TEST_F(ScannerTest, CaptureAndGetResult) { // Tests that if StopCapture is not called, then calling GetResult, then // scanning more, then GetResult again will update the capture. TEST_F(ScannerTest, MultipleGetResultExtendsCapture) { - StringPiece remaining, match; + absl::string_view remaining, match; Scanner scan("one2three"); EXPECT_TRUE(scan.Many(Scanner::LETTER).GetResult(&remaining, &match)); @@ -255,8 +255,8 @@ TEST_F(ScannerTest, MultipleGetResultExtendsCapture) { TEST_F(ScannerTest, FailedMatchDoesntChangeResult) { // A failed match doesn't change pointers passed to GetResult. Scanner scan("name"); - StringPiece remaining = "rem"; - StringPiece match = "match"; + absl::string_view remaining = "rem"; + absl::string_view match = "match"; EXPECT_FALSE(scan.One(Scanner::SPACE).GetResult(&remaining, &match)); EXPECT_EQ("rem", remaining); EXPECT_EQ("match", match); @@ -265,8 +265,8 @@ TEST_F(ScannerTest, FailedMatchDoesntChangeResult) { TEST_F(ScannerTest, DefaultCapturesAll) { // If RestartCapture() is not called, the whole string is used. Scanner scan("a b"); - StringPiece remaining = "rem"; - StringPiece match = "match"; + absl::string_view remaining = "rem"; + absl::string_view match = "match"; EXPECT_TRUE(scan.Any(Scanner::LETTER) .AnySpace() .Any(Scanner::LETTER) diff --git a/third_party/tsl/tsl/platform/status_test.cc b/third_party/tsl/tsl/platform/status_test.cc index b95de35e181be1..6d9948fa68d99b 100644 --- a/third_party/tsl/tsl/platform/status_test.cc +++ b/third_party/tsl/tsl/platform/status_test.cc @@ -157,7 +157,7 @@ TEST(Status, ErrorStatusForEachPayloadIteratesOverAll) { s.SetPayload("key3", absl::Cord("value3")); std::unordered_map payloads; - s.ForEachPayload([&payloads](StringPiece key, const absl::Cord& value) { + s.ForEachPayload([&payloads](absl::string_view key, const absl::Cord& value) { payloads[std::string(key)] = value; }); @@ -174,7 +174,7 @@ TEST(Status, OkStatusForEachPayloadNoIteration) { s.SetPayload("key3", absl::Cord("value3")); std::unordered_map payloads; - s.ForEachPayload([&payloads](StringPiece key, const absl::Cord& value) { + s.ForEachPayload([&payloads](absl::string_view key, const absl::Cord& value) { payloads[std::string(key)] = value; }); diff --git a/third_party/tsl/tsl/platform/str_util.cc b/third_party/tsl/tsl/platform/str_util.cc index 23de45139c2dcf..19dfb640cb375e 100644 --- a/third_party/tsl/tsl/platform/str_util.cc +++ b/third_party/tsl/tsl/platform/str_util.cc @@ -26,28 +26,28 @@ limitations under the License. namespace tsl { namespace str_util { -size_t RemoveLeadingWhitespace(StringPiece* text) { +size_t RemoveLeadingWhitespace(absl::string_view* text) { absl::string_view new_text = absl::StripLeadingAsciiWhitespace(*text); size_t count = text->size() - new_text.size(); *text = new_text; return count; } -size_t RemoveTrailingWhitespace(StringPiece* text) { +size_t RemoveTrailingWhitespace(absl::string_view* text) { absl::string_view new_text = absl::StripTrailingAsciiWhitespace(*text); size_t count = text->size() - new_text.size(); *text = new_text; return count; } -size_t RemoveWhitespaceContext(StringPiece* text) { +size_t RemoveWhitespaceContext(absl::string_view* text) { absl::string_view new_text = absl::StripAsciiWhitespace(*text); size_t count = text->size() - new_text.size(); *text = new_text; return count; } -bool ConsumeLeadingDigits(StringPiece* s, uint64_t* val) { +bool ConsumeLeadingDigits(absl::string_view* s, uint64_t* val) { const char* p = s->data(); const char* limit = p + s->size(); uint64_t v = 0; @@ -72,7 +72,7 @@ bool ConsumeLeadingDigits(StringPiece* s, uint64_t* val) { } } -bool ConsumeNonWhitespace(StringPiece* s, StringPiece* val) { +bool ConsumeNonWhitespace(absl::string_view* s, absl::string_view* val) { const char* p = s->data(); const char* limit = p + s->size(); while (p < limit) { @@ -82,27 +82,27 @@ bool ConsumeNonWhitespace(StringPiece* s, StringPiece* val) { } const size_t n = p - s->data(); if (n > 0) { - *val = StringPiece(s->data(), n); + *val = absl::string_view(s->data(), n); s->remove_prefix(n); return true; } else { - *val = StringPiece(); + *val = absl::string_view(); return false; } } -void TitlecaseString(string* s, StringPiece delimiters) { +void TitlecaseString(string* s, absl::string_view delimiters) { bool upper = true; for (string::iterator ss = s->begin(); ss != s->end(); ++ss) { if (upper) { *ss = toupper(*ss); } - upper = (delimiters.find(*ss) != StringPiece::npos); + upper = (delimiters.find(*ss) != absl::string_view::npos); } } -string StringReplace(StringPiece s, StringPiece oldsub, StringPiece newsub, - bool replace_all) { +string StringReplace(absl::string_view s, absl::string_view oldsub, + absl::string_view newsub, bool replace_all) { // TODO(jlebar): We could avoid having to shift data around in the string if // we had a StringPiece::find() overload that searched for a StringPiece. string res(s); @@ -128,7 +128,7 @@ size_t Strnlen(const char* str, const size_t string_max_len) { return len; } -string ArgDefCase(StringPiece s) { +string ArgDefCase(absl::string_view s) { const size_t n = s.size(); // Compute the size of resulting string. diff --git a/third_party/tsl/tsl/platform/str_util.h b/third_party/tsl/tsl/platform/str_util.h index 9d1ba40009398c..685583faeb9670 100644 --- a/third_party/tsl/tsl/platform/str_util.h +++ b/third_party/tsl/tsl/platform/str_util.h @@ -67,26 +67,26 @@ inline void StripTrailingWhitespace(std::string* s) { // Removes leading ascii_isspace() characters. // Returns number of characters removed. -size_t RemoveLeadingWhitespace(StringPiece* text); +size_t RemoveLeadingWhitespace(absl::string_view* text); // Removes trailing ascii_isspace() characters. // Returns number of characters removed. -size_t RemoveTrailingWhitespace(StringPiece* text); +size_t RemoveTrailingWhitespace(absl::string_view* text); // Removes leading and trailing ascii_isspace() chars. // Returns number of chars removed. -size_t RemoveWhitespaceContext(StringPiece* text); +size_t RemoveWhitespaceContext(absl::string_view* text); // Consume a leading positive integer value. If any digits were // found, store the value of the leading unsigned number in "*val", // advance "*s" past the consumed number, and return true. If // overflow occurred, returns false. Otherwise, returns false. -bool ConsumeLeadingDigits(StringPiece* s, uint64_t* val); +bool ConsumeLeadingDigits(absl::string_view* s, uint64_t* val); // Consume a leading token composed of non-whitespace characters only. // If *s starts with a non-zero number of non-whitespace characters, store // them in *val, advance *s past them, and return true. Else return false. -bool ConsumeNonWhitespace(StringPiece* s, StringPiece* val); +bool ConsumeNonWhitespace(absl::string_view* s, absl::string_view* val); // If "*s" starts with "expected", consume it and return true. // Otherwise, return false. @@ -132,12 +132,12 @@ ABSL_DEPRECATE_AND_INLINE() inline std::string Uppercase(absl::string_view s) { // Capitalize first character of each word in "*s". "delimiters" is a // set of characters that can be used as word boundaries. -void TitlecaseString(std::string* s, StringPiece delimiters); +void TitlecaseString(std::string* s, absl::string_view delimiters); // Replaces the first occurrence (if replace_all is false) or all occurrences // (if replace_all is true) of oldsub in s with newsub. -std::string StringReplace(StringPiece s, StringPiece oldsub, StringPiece newsub, - bool replace_all); +std::string StringReplace(absl::string_view s, absl::string_view oldsub, + absl::string_view newsub, bool replace_all); // Join functionality template @@ -156,36 +156,38 @@ std::string Join(const T& s, const char* sep, Formatter f) { } struct AllowEmpty { - bool operator()(StringPiece sp) const { return true; } + bool operator()(absl::string_view sp) const { return true; } }; struct SkipEmpty { - bool operator()(StringPiece sp) const { return !sp.empty(); } + bool operator()(absl::string_view sp) const { return !sp.empty(); } }; struct SkipWhitespace { - bool operator()(StringPiece sp) const { + bool operator()(absl::string_view sp) const { return !absl::StripTrailingAsciiWhitespace(sp).empty(); } }; // Split strings using any of the supplied delimiters. For example: // Split("a,b.c,d", ".,") would return {"a", "b", "c", "d"}. -inline std::vector Split(StringPiece text, StringPiece delims) { +inline std::vector Split(absl::string_view text, + absl::string_view delims) { return text.empty() ? std::vector() : absl::StrSplit(text, absl::ByAnyChar(delims)); } template -std::vector Split(StringPiece text, StringPiece delims, Predicate p) { +std::vector Split(absl::string_view text, absl::string_view delims, + Predicate p) { return text.empty() ? std::vector() : absl::StrSplit(text, absl::ByAnyChar(delims), p); } -inline std::vector Split(StringPiece text, char delim) { +inline std::vector Split(absl::string_view text, char delim) { return text.empty() ? std::vector() : absl::StrSplit(text, delim); } template -std::vector Split(StringPiece text, char delim, Predicate p) { +std::vector Split(absl::string_view text, char delim, Predicate p) { return text.empty() ? std::vector() : absl::StrSplit(text, delim, p); } @@ -228,7 +230,7 @@ size_t Strnlen(const char* str, const size_t string_max_len); // This method is useful for producing strings matching "[a-z][a-z0-9_]*" // as required by OpDef.ArgDef.name. The resulting string is either empty or // matches this regex. -std::string ArgDefCase(StringPiece s); +std::string ArgDefCase(absl::string_view s); } // namespace str_util } // namespace tsl diff --git a/third_party/tsl/tsl/platform/str_util_test.cc b/third_party/tsl/tsl/platform/str_util_test.cc index 5d78cb961d336e..ce52193109f721 100644 --- a/third_party/tsl/tsl/platform/str_util_test.cc +++ b/third_party/tsl/tsl/platform/str_util_test.cc @@ -22,17 +22,17 @@ limitations under the License. namespace tsl { TEST(CEscape, Basic) { - EXPECT_EQ(str_util::CEscape("hello"), "hello"); - EXPECT_EQ(str_util::CEscape("hello\n"), "hello\\n"); - EXPECT_EQ(str_util::CEscape("hello\r"), "hello\\r"); - EXPECT_EQ(str_util::CEscape("\t\r\"'"), "\\t\\r\\\"\\'"); - EXPECT_EQ(str_util::CEscape("\320hi\200"), "\\320hi\\200"); + EXPECT_EQ(absl::CEscape("hello"), "hello"); + EXPECT_EQ(absl::CEscape("hello\n"), "hello\\n"); + EXPECT_EQ(absl::CEscape("hello\r"), "hello\\r"); + EXPECT_EQ(absl::CEscape("\t\r\"'"), "\\t\\r\\\"\\'"); + EXPECT_EQ(absl::CEscape("\320hi\200"), "\\320hi\\200"); } -string ExpectCUnescapeSuccess(StringPiece source) { +string ExpectCUnescapeSuccess(absl::string_view source) { string dest; string error; - EXPECT_TRUE(str_util::CUnescape(source, &dest, &error)) << error; + EXPECT_TRUE(absl::CUnescape(source, &dest, &error)) << error; return dest; } @@ -50,103 +50,103 @@ TEST(CUnescape, HandlesCopyOnWriteStrings) { // For std::string, read and dest now share the same buffer. string error; - StringPiece source = "llohe"; + absl::string_view source = "llohe"; // CUnescape is going to write "llohe" to dest, so dest's buffer will be // reallocated, and read's buffer remains untouched. - EXPECT_TRUE(str_util::CUnescape(source, &dest, &error)); + EXPECT_TRUE(absl::CUnescape(source, &dest, &error)); EXPECT_EQ("hello", read); } TEST(StripTrailingWhitespace, Basic) { string test; test = "hello"; - str_util::StripTrailingWhitespace(&test); + absl::StripTrailingAsciiWhitespace(&test); EXPECT_EQ(test, "hello"); test = "foo "; - str_util::StripTrailingWhitespace(&test); + absl::StripTrailingAsciiWhitespace(&test); EXPECT_EQ(test, "foo"); test = " "; - str_util::StripTrailingWhitespace(&test); + absl::StripTrailingAsciiWhitespace(&test); EXPECT_EQ(test, ""); test = ""; - str_util::StripTrailingWhitespace(&test); + absl::StripTrailingAsciiWhitespace(&test); EXPECT_EQ(test, ""); test = " abc\t"; - str_util::StripTrailingWhitespace(&test); + absl::StripTrailingAsciiWhitespace(&test); EXPECT_EQ(test, " abc"); } TEST(RemoveLeadingWhitespace, Basic) { string text = " \t \n \r Quick\t"; - StringPiece data(text); + absl::string_view data(text); // check that all whitespace is removed EXPECT_EQ(str_util::RemoveLeadingWhitespace(&data), 11); - EXPECT_EQ(data, StringPiece("Quick\t")); + EXPECT_EQ(data, absl::string_view("Quick\t")); // check that non-whitespace is not removed EXPECT_EQ(str_util::RemoveLeadingWhitespace(&data), 0); - EXPECT_EQ(data, StringPiece("Quick\t")); + EXPECT_EQ(data, absl::string_view("Quick\t")); } TEST(RemoveLeadingWhitespace, TerminationHandling) { // check termination handling string text = "\t"; - StringPiece data(text); + absl::string_view data(text); EXPECT_EQ(str_util::RemoveLeadingWhitespace(&data), 1); - EXPECT_EQ(data, StringPiece("")); + EXPECT_EQ(data, absl::string_view("")); // check termination handling again EXPECT_EQ(str_util::RemoveLeadingWhitespace(&data), 0); - EXPECT_EQ(data, StringPiece("")); + EXPECT_EQ(data, absl::string_view("")); } TEST(RemoveTrailingWhitespace, Basic) { string text = " \t \n \r Quick \t"; - StringPiece data(text); + absl::string_view data(text); // check that all whitespace is removed EXPECT_EQ(str_util::RemoveTrailingWhitespace(&data), 2); - EXPECT_EQ(data, StringPiece(" \t \n \r Quick")); + EXPECT_EQ(data, absl::string_view(" \t \n \r Quick")); // check that non-whitespace is not removed EXPECT_EQ(str_util::RemoveTrailingWhitespace(&data), 0); - EXPECT_EQ(data, StringPiece(" \t \n \r Quick")); + EXPECT_EQ(data, absl::string_view(" \t \n \r Quick")); } TEST(RemoveTrailingWhitespace, TerminationHandling) { // check termination handling string text = "\t"; - StringPiece data(text); + absl::string_view data(text); EXPECT_EQ(str_util::RemoveTrailingWhitespace(&data), 1); - EXPECT_EQ(data, StringPiece("")); + EXPECT_EQ(data, absl::string_view("")); // check termination handling again EXPECT_EQ(str_util::RemoveTrailingWhitespace(&data), 0); - EXPECT_EQ(data, StringPiece("")); + EXPECT_EQ(data, absl::string_view("")); } TEST(RemoveWhitespaceContext, Basic) { string text = " \t \n \r Quick \t"; - StringPiece data(text); + absl::string_view data(text); // check that all whitespace is removed EXPECT_EQ(str_util::RemoveWhitespaceContext(&data), 13); - EXPECT_EQ(data, StringPiece("Quick")); + EXPECT_EQ(data, absl::string_view("Quick")); // check that non-whitespace is not removed EXPECT_EQ(str_util::RemoveWhitespaceContext(&data), 0); - EXPECT_EQ(data, StringPiece("Quick")); + EXPECT_EQ(data, absl::string_view("Quick")); // Test empty string text = ""; data = text; EXPECT_EQ(str_util::RemoveWhitespaceContext(&data), 0); - EXPECT_EQ(data, StringPiece("")); + EXPECT_EQ(data, absl::string_view("")); } -void TestConsumeLeadingDigits(StringPiece s, int64_t expected, - StringPiece remaining) { +void TestConsumeLeadingDigits(absl::string_view s, int64_t expected, + absl::string_view remaining) { uint64 v; - StringPiece input(s); + absl::string_view input(s); if (str_util::ConsumeLeadingDigits(&input, &v)) { EXPECT_EQ(v, static_cast(expected)); EXPECT_EQ(input, remaining); @@ -179,10 +179,10 @@ TEST(ConsumeLeadingDigits, Basic) { "184467440737095516159yz"); } -void TestConsumeNonWhitespace(StringPiece s, StringPiece expected, - StringPiece remaining) { - StringPiece v; - StringPiece input(s); +void TestConsumeNonWhitespace(absl::string_view s, absl::string_view expected, + absl::string_view remaining) { + absl::string_view v; + absl::string_view input(s); if (str_util::ConsumeNonWhitespace(&input, &v)) { EXPECT_EQ(v, expected); EXPECT_EQ(input, remaining); @@ -201,48 +201,48 @@ TEST(ConsumeNonWhitespace, Basic) { TEST(ConsumePrefix, Basic) { string s("abcdef"); - StringPiece input(s); - EXPECT_FALSE(str_util::ConsumePrefix(&input, "abcdefg")); + absl::string_view input(s); + EXPECT_FALSE(absl::ConsumePrefix(&input, "abcdefg")); EXPECT_EQ(input, "abcdef"); - EXPECT_FALSE(str_util::ConsumePrefix(&input, "abce")); + EXPECT_FALSE(absl::ConsumePrefix(&input, "abce")); EXPECT_EQ(input, "abcdef"); - EXPECT_TRUE(str_util::ConsumePrefix(&input, "")); + EXPECT_TRUE(absl::ConsumePrefix(&input, "")); EXPECT_EQ(input, "abcdef"); - EXPECT_FALSE(str_util::ConsumePrefix(&input, "abcdeg")); + EXPECT_FALSE(absl::ConsumePrefix(&input, "abcdeg")); EXPECT_EQ(input, "abcdef"); - EXPECT_TRUE(str_util::ConsumePrefix(&input, "abcdef")); + EXPECT_TRUE(absl::ConsumePrefix(&input, "abcdef")); EXPECT_EQ(input, ""); input = s; - EXPECT_TRUE(str_util::ConsumePrefix(&input, "abcde")); + EXPECT_TRUE(absl::ConsumePrefix(&input, "abcde")); EXPECT_EQ(input, "f"); } TEST(StripPrefix, Basic) { - EXPECT_EQ(str_util::StripPrefix("abcdef", "abcdefg"), "abcdef"); - EXPECT_EQ(str_util::StripPrefix("abcdef", "abce"), "abcdef"); - EXPECT_EQ(str_util::StripPrefix("abcdef", ""), "abcdef"); - EXPECT_EQ(str_util::StripPrefix("abcdef", "abcdeg"), "abcdef"); - EXPECT_EQ(str_util::StripPrefix("abcdef", "abcdef"), ""); - EXPECT_EQ(str_util::StripPrefix("abcdef", "abcde"), "f"); + EXPECT_EQ(absl::StripPrefix("abcdef", "abcdefg"), "abcdef"); + EXPECT_EQ(absl::StripPrefix("abcdef", "abce"), "abcdef"); + EXPECT_EQ(absl::StripPrefix("abcdef", ""), "abcdef"); + EXPECT_EQ(absl::StripPrefix("abcdef", "abcdeg"), "abcdef"); + EXPECT_EQ(absl::StripPrefix("abcdef", "abcdef"), ""); + EXPECT_EQ(absl::StripPrefix("abcdef", "abcde"), "f"); } TEST(JoinStrings, Basic) { std::vector s; s = {"hi"}; - EXPECT_EQ(str_util::Join(s, " "), "hi"); + EXPECT_EQ(absl::StrJoin(s, " "), "hi"); s = {"hi", "there", "strings"}; - EXPECT_EQ(str_util::Join(s, " "), "hi there strings"); + EXPECT_EQ(absl::StrJoin(s, " "), "hi there strings"); - std::vector sp; + std::vector sp; sp = {"hi"}; - EXPECT_EQ(str_util::Join(sp, ",,"), "hi"); + EXPECT_EQ(absl::StrJoin(sp, ",,"), "hi"); sp = {"hi", "there", "strings"}; - EXPECT_EQ(str_util::Join(sp, "--"), "hi--there--strings"); + EXPECT_EQ(absl::StrJoin(sp, "--"), "hi--there--strings"); } TEST(JoinStrings, Join3) { @@ -257,36 +257,35 @@ TEST(JoinStrings, Join3) { TEST(Split, Basic) { EXPECT_TRUE(str_util::Split("", ',').empty()); - EXPECT_EQ(str_util::Join(str_util::Split("a", ','), "|"), "a"); - EXPECT_EQ(str_util::Join(str_util::Split(",", ','), "|"), "|"); - EXPECT_EQ(str_util::Join(str_util::Split("a,b,c", ','), "|"), "a|b|c"); - EXPECT_EQ(str_util::Join(str_util::Split("a,,,b,,c,", ','), "|"), + EXPECT_EQ(absl::StrJoin(str_util::Split("a", ','), "|"), "a"); + EXPECT_EQ(absl::StrJoin(str_util::Split(",", ','), "|"), "|"); + EXPECT_EQ(absl::StrJoin(str_util::Split("a,b,c", ','), "|"), "a|b|c"); + EXPECT_EQ(absl::StrJoin(str_util::Split("a,,,b,,c,", ','), "|"), "a|||b||c|"); + EXPECT_EQ(absl::StrJoin(str_util::Split("a!,!b,!c,", ",!"), "|"), "a|||b||c|"); - EXPECT_EQ(str_util::Join(str_util::Split("a!,!b,!c,", ",!"), "|"), - "a|||b||c|"); - EXPECT_EQ(str_util::Join( + EXPECT_EQ(absl::StrJoin( str_util::Split("a,,,b,,c,", ',', str_util::SkipEmpty()), "|"), "a|b|c"); EXPECT_EQ( - str_util::Join( + absl::StrJoin( str_util::Split("a, ,b,,c,", ',', str_util::SkipWhitespace()), "|"), "a|b|c"); - EXPECT_EQ(str_util::Join(str_util::Split("a. !b,;c,", ".,;!", - str_util::SkipWhitespace()), - "|"), + EXPECT_EQ(absl::StrJoin(str_util::Split("a. !b,;c,", ".,;!", + str_util::SkipWhitespace()), + "|"), "a|b|c"); } TEST(Lowercase, Basic) { - EXPECT_EQ("", str_util::Lowercase("")); - EXPECT_EQ("hello", str_util::Lowercase("hello")); - EXPECT_EQ("hello world", str_util::Lowercase("Hello World")); + EXPECT_EQ("", absl::AsciiStrToLower("")); + EXPECT_EQ("hello", absl::AsciiStrToLower("hello")); + EXPECT_EQ("hello world", absl::AsciiStrToLower("Hello World")); } TEST(Uppercase, Basic) { - EXPECT_EQ("", str_util::Uppercase("")); - EXPECT_EQ("HELLO", str_util::Uppercase("hello")); - EXPECT_EQ("HELLO WORLD", str_util::Uppercase("Hello World")); + EXPECT_EQ("", absl::AsciiStrToUpper("")); + EXPECT_EQ("HELLO", absl::AsciiStrToUpper("hello")); + EXPECT_EQ("HELLO WORLD", absl::AsciiStrToUpper("Hello World")); } TEST(SnakeCase, Basic) { diff --git a/third_party/tsl/tsl/platform/strcat.cc b/third_party/tsl/tsl/platform/strcat.cc index ffdb8802e0f7c6..afa4fd5e2630fe 100644 --- a/third_party/tsl/tsl/platform/strcat.cc +++ b/third_party/tsl/tsl/platform/strcat.cc @@ -43,7 +43,7 @@ AlphaNum::AlphaNum(Hex hex) { value >>= 4; mask >>= 4; } while (mask != 0); - piece_ = StringPiece(writer, end - writer); + piece_ = absl::string_view(writer, end - writer); } // ---------------------------------------------------------------------- @@ -180,14 +180,14 @@ void STLStringResizeUninitializedAmortized(string_type *s, size_t new_size) { namespace internal { // Do not call directly - these are not part of the public API. -string CatPieces(std::initializer_list pieces) { +string CatPieces(std::initializer_list pieces) { size_t total_size = 0; - for (const StringPiece piece : pieces) total_size += piece.size(); + for (const absl::string_view piece : pieces) total_size += piece.size(); string result(total_size, '\0'); char *const begin = &*result.begin(); char *out = begin; - for (const StringPiece piece : pieces) { + for (const absl::string_view piece : pieces) { const size_t this_size = piece.size(); memcpy(out, piece.data(), this_size); out += this_size; @@ -203,10 +203,11 @@ string CatPieces(std::initializer_list pieces) { #define DCHECK_NO_OVERLAP(dest, src) \ DCHECK_GE(uintptr_t((src).data() - (dest).data()), uintptr_t((dest).size())) -void AppendPieces(string *result, std::initializer_list pieces) { +void AppendPieces(string *result, + std::initializer_list pieces) { size_t old_size = result->size(); size_t total_size = old_size; - for (const StringPiece piece : pieces) { + for (const absl::string_view piece : pieces) { DCHECK_NO_OVERLAP(*result, piece); total_size += piece.size(); } @@ -214,7 +215,7 @@ void AppendPieces(string *result, std::initializer_list pieces) { char *const begin = &*result->begin(); char *out = begin + old_size; - for (const StringPiece piece : pieces) { + for (const absl::string_view piece : pieces) { const size_t this_size = piece.size(); memcpy(out, piece.data(), this_size); out += this_size; diff --git a/third_party/tsl/tsl/platform/strcat.h b/third_party/tsl/tsl/platform/strcat.h index 198465b4cc0515..d552a8a8977baf 100644 --- a/third_party/tsl/tsl/platform/strcat.h +++ b/third_party/tsl/tsl/platform/strcat.h @@ -122,7 +122,8 @@ class AlphaNum { AlphaNum(Hex hex); // NOLINT(runtime/explicit) AlphaNum(const char *c_str) : piece_(c_str) {} // NOLINT(runtime/explicit) - AlphaNum(const StringPiece &pc) : piece_(pc) {} // NOLINT(runtime/explicit) + AlphaNum(const absl::string_view &pc) + : piece_(pc) {} // NOLINT(runtime/explicit) AlphaNum(const std::string &str) // NOLINT(runtime/explicit) : piece_(str) {} AlphaNum(const tstring &str) // NOLINT(runtime/explicit) @@ -131,12 +132,12 @@ class AlphaNum { AlphaNum(const std::basic_string, A> &str) : piece_(str) {} // NOLINT(runtime/explicit) - StringPiece::size_type size() const { return piece_.size(); } + absl::string_view::size_type size() const { return piece_.size(); } const char *data() const { return piece_.data(); } - StringPiece Piece() const { return piece_; } + absl::string_view Piece() const { return piece_; } private: - StringPiece piece_; + absl::string_view piece_; char digits_[kFastToBufferSize]; // Use ":" not ':' @@ -181,8 +182,9 @@ std::string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c, namespace internal { // Do not call directly - this is not part of the public API. -std::string CatPieces(std::initializer_list pieces); -void AppendPieces(std::string *dest, std::initializer_list pieces); +std::string CatPieces(std::initializer_list pieces); +void AppendPieces(std::string *dest, + std::initializer_list pieces); } // namespace internal diff --git a/third_party/tsl/tsl/platform/strcat_test.cc b/third_party/tsl/tsl/platform/strcat_test.cc index 752b5fd554cbcf..d62fdb60361e9a 100644 --- a/third_party/tsl/tsl/platform/strcat_test.cc +++ b/third_party/tsl/tsl/platform/strcat_test.cc @@ -120,7 +120,7 @@ TEST(StrCat, Basics) { string strs[] = {"Hello", "Cruel", "World"}; - StringPiece pieces[] = {"Hello", "Cruel", "World"}; + absl::string_view pieces[] = {"Hello", "Cruel", "World"}; const char *c_strs[] = {"Hello", "Cruel", "World"}; @@ -244,7 +244,7 @@ TEST(StrAppend, Basics) { string strs[] = {"Hello", "Cruel", "World"}; - StringPiece pieces[] = {"Hello", "Cruel", "World"}; + absl::string_view pieces[] = {"Hello", "Cruel", "World"}; const char *c_strs[] = {"Hello", "Cruel", "World"}; diff --git a/third_party/tsl/tsl/platform/stringpiece_test.cc b/third_party/tsl/tsl/platform/stringpiece_test.cc index fafed5a3276370..b7a46ed5d7b149 100644 --- a/third_party/tsl/tsl/platform/stringpiece_test.cc +++ b/third_party/tsl/tsl/platform/stringpiece_test.cc @@ -25,24 +25,24 @@ TEST(StringPiece, Ctor) { { // const char* without size. const char* hello = "hello"; - StringPiece s20(hello); + absl::string_view s20(hello); EXPECT_TRUE(s20.data() == hello); EXPECT_EQ(5, s20.size()); // const char* with size. - StringPiece s21(hello, 4); + absl::string_view s21(hello, 4); EXPECT_TRUE(s21.data() == hello); EXPECT_EQ(4, s21.size()); // Not recommended, but valid C++ - StringPiece s22(hello, 6); + absl::string_view s22(hello, 6); EXPECT_TRUE(s22.data() == hello); EXPECT_EQ(6, s22.size()); } { string hola = "hola"; - StringPiece s30(hola); + absl::string_view s30(hola); EXPECT_TRUE(s30.data() == hola.data()); EXPECT_EQ(4, s30.size()); @@ -50,15 +50,15 @@ TEST(StringPiece, Ctor) { hola.push_back('\0'); hola.append("h2"); hola.push_back('\0'); - StringPiece s31(hola); + absl::string_view s31(hola); EXPECT_TRUE(s31.data() == hola.data()); EXPECT_EQ(8, s31.size()); } } TEST(StringPiece, ConversionToString) { - EXPECT_EQ("", string(StringPiece(""))); - EXPECT_EQ("foo", string(StringPiece("foo"))); + EXPECT_EQ("", string(absl::string_view(""))); + EXPECT_EQ("foo", string(absl::string_view("foo"))); } } // namespace tsl diff --git a/third_party/tsl/tsl/platform/tracing.cc b/third_party/tsl/tsl/platform/tracing.cc index e818e32ac47af4..a6cfcd036077fd 100644 --- a/third_party/tsl/tsl/platform/tracing.cc +++ b/third_party/tsl/tsl/platform/tracing.cc @@ -51,7 +51,7 @@ uint64 GetUniqueArg() { return unique_arg.fetch_add(1, std::memory_order_relaxed); } -uint64 GetArgForName(StringPiece name) { +uint64 GetArgForName(absl::string_view name) { return Hash64(name.data(), name.size()); } diff --git a/third_party/tsl/tsl/platform/tracing.h b/third_party/tsl/tsl/platform/tracing.h index a90474099f67c3..51268eb7134cea 100644 --- a/third_party/tsl/tsl/platform/tracing.h +++ b/third_party/tsl/tsl/platform/tracing.h @@ -77,7 +77,7 @@ inline const EventCollector* GetEventCollector(EventCategory category) { uint64 GetUniqueArg(); // Returns an id for name to pass to RecordEvent/ScopedRegion. -uint64 GetArgForName(StringPiece name); +uint64 GetArgForName(absl::string_view name); // Records an atomic event through the currently registered EventCollector. inline void RecordEvent(EventCategory category, uint64 arg) { @@ -113,7 +113,7 @@ class ScopedRegion { // Same as ScopedRegion(category, GetArgForName(name)), but faster if // EventCollector::IsEnabled() returns false. - ScopedRegion(EventCategory category, StringPiece name) + ScopedRegion(EventCategory category, absl::string_view name) : collector_(GetEventCollector(category)) { if (collector_) { collector_->StartRegion(GetArgForName(name)); diff --git a/third_party/tsl/tsl/platform/tstring.h b/third_party/tsl/tsl/platform/tstring.h index 0bdd9f52a76cc6..97028f6ca261b8 100644 --- a/third_party/tsl/tsl/platform/tstring.h +++ b/third_party/tsl/tsl/platform/tstring.h @@ -100,7 +100,7 @@ class tstring { tstring(const char* str, size_t len); tstring(const char* str); // NOLINT TODO(b/147740521): Make explicit. tstring(size_t n, char c); - explicit tstring(const StringPiece str); + explicit tstring(const absl::string_view str); #ifdef PLATFORM_GOOGLE explicit tstring(const absl::Cord& cord); #endif // PLATFORM_GOOGLE @@ -119,7 +119,7 @@ class tstring { tstring& operator=(const std::string& str); tstring& operator=(const char* str); tstring& operator=(char ch); - tstring& operator=(const StringPiece str); + tstring& operator=(const absl::string_view str); #ifdef PLATFORM_GOOGLE tstring& operator=(const absl::Cord& cord); #endif // PLATFORM_GOOGLE @@ -143,7 +143,7 @@ class tstring { // TODO(b/147740521): Make explicit. operator std::string() const; // NOLINT // TODO(b/147740521): Make explicit. - operator StringPiece() const; // NOLINT + operator absl::string_view() const; // NOLINT #ifdef PLATFORM_GOOGLE template ::value, @@ -191,7 +191,7 @@ class tstring { // View Assignment tstring& assign_as_view(const tstring& str); tstring& assign_as_view(const std::string& str); - tstring& assign_as_view(const StringPiece str); + tstring& assign_as_view(const absl::string_view str); tstring& assign_as_view(const char* str, size_t len); tstring& assign_as_view(const char* str); @@ -245,7 +245,7 @@ inline tstring::tstring(size_t n, char c) { inline tstring::tstring(const std::string& str) : tstring(str.data(), str.size()) {} -inline tstring::tstring(const StringPiece str) +inline tstring::tstring(const absl::string_view str) : tstring(str.data(), str.size()) {} #ifdef PLATFORM_GOOGLE @@ -301,7 +301,7 @@ inline tstring& tstring::operator=(char c) { return *this; } -inline tstring& tstring::operator=(const StringPiece str) { +inline tstring& tstring::operator=(const absl::string_view str) { TF_TString_Copy(&tstr_, str.data(), str.size()); return *this; @@ -377,15 +377,15 @@ inline tstring::operator std::string() const { return std::string(data(), size()); } -inline tstring::operator StringPiece() const { - return StringPiece(data(), size()); +inline tstring::operator absl::string_view() const { + return absl::string_view(data(), size()); } #ifdef PLATFORM_GOOGLE template ::value, T>::type*> inline tstring::operator T() const { - return T(StringPiece(*this)); + return T(absl::string_view(*this)); } #endif // PLATFORM_GOOGLE @@ -477,7 +477,7 @@ inline tstring& tstring::assign_as_view(const std::string& str) { return *this; } -inline tstring& tstring::assign_as_view(const StringPiece str) { +inline tstring& tstring::assign_as_view(const absl::string_view str) { assign_as_view(str.data(), str.size()); return *this; diff --git a/third_party/tsl/tsl/platform/tstring_test.cc b/third_party/tsl/tsl/platform/tstring_test.cc index 633bcb275f9ecc..78263471b61073 100644 --- a/third_party/tsl/tsl/platform/tstring_test.cc +++ b/third_party/tsl/tsl/platform/tstring_test.cc @@ -38,7 +38,7 @@ TEST(TF_TStringTest, Construction) { tstring s11("a\0a", 3); tstring s12(kLongString); tstring s13(3, 'b'); - tstring s14(tsl::StringPiece("hi")); + tstring s14(absl::string_view("hi")); tstring s15(std::string("bye")); EXPECT_EQ("", s10); @@ -126,7 +126,7 @@ TEST(TF_TStringTest, Assignment) { EXPECT_EQ(tstring::Type::SMALL, s33.type()); EXPECT_EQ(1, s33.size()); - s32 = tsl::StringPiece(kLongString); + s32 = absl::string_view(kLongString); EXPECT_EQ(kLongString, s32); EXPECT_EQ(tstring::Type::LARGE, s32.type()); @@ -135,7 +135,7 @@ TEST(TF_TStringTest, Assignment) { // LARGE -> SMALL but still LARGE s32.resize(TF_TString_SmallCapacity * 2); - EXPECT_EQ(tsl::StringPiece(kLongString, TF_TString_SmallCapacity * 2), s32); + EXPECT_EQ(absl::string_view(kLongString, TF_TString_SmallCapacity * 2), s32); EXPECT_EQ(tstring::Type::LARGE, s32.type()); EXPECT_EQ(TF_TString_SmallCapacity * 2, s32.size()); @@ -174,7 +174,7 @@ TEST(TF_TStringTest, Assignment) { EXPECT_EQ(2, s33.size()); - s32.assign_as_view(tsl::StringPiece(kLongString)); + s32.assign_as_view(absl::string_view(kLongString)); EXPECT_EQ(tstring::Type::VIEW, s32.type()); EXPECT_EQ(kLongString, s32.c_str()); @@ -255,7 +255,7 @@ TEST(TF_TStringTest, Comparison) { TEST(TF_TStringTest, Conversion) { tstring s50(kLongString); std::string s51(s50); - tsl::StringPiece s52(s50); + absl::string_view s52(s50); EXPECT_EQ(kLongString, s51); EXPECT_EQ(kLongStringLen, s51.size()); EXPECT_EQ(kLongString, s52); From 7beebdb87b8aba1206657fc1b18d0e0ab8d77a2e Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Mon, 29 Jul 2024 01:57:13 -0700 Subject: [PATCH 224/376] [XLA:GPU] Adjust HloInstructionAdaptor::GetUsers() We already enforce that GetOperands() only considers operands that belong to the parent HloFusionAdaptor. To be consistent, we should also do this for users. PiperOrigin-RevId: 657115232 --- xla/service/gpu/hlo_traversal.cc | 14 ++++++++------ xla/service/gpu/hlo_traversal_test.cc | 8 +++----- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/xla/service/gpu/hlo_traversal.cc b/xla/service/gpu/hlo_traversal.cc index b529997f15f962..4394226dfadc0b 100644 --- a/xla/service/gpu/hlo_traversal.cc +++ b/xla/service/gpu/hlo_traversal.cc @@ -40,18 +40,20 @@ namespace { template void ResolveUsers(const HloInstruction* value, const HloInstruction* user, - const HloFusionAdaptor& fusion_adaptor, F&& fn) { + const HloFusionAdaptor& fusion_adaptor, F&& add_user) { if (user->opcode() == HloOpcode::kTuple && user->IsRoot()) { if (auto* fusion = user->parent()->FusionInstruction()) { // Skip through the tuple -> get-tuple-element ops and directly go to the // "real" users. for (const auto* gte : fusion->users()) { if (gte->opcode() != HloOpcode::kGetTupleElement) { - fn(gte); + if (fusion_adaptor.ContainsInstruction(value)) { + add_user(gte); + } continue; } for (const auto* gte_user : gte->users()) { - ResolveUsers(gte, gte_user, fusion_adaptor, fn); + ResolveUsers(gte, gte_user, fusion_adaptor, add_user); } } } @@ -59,10 +61,10 @@ void ResolveUsers(const HloInstruction* value, const HloInstruction* user, user->opcode() == HloOpcode::kFusion) { auto* param = user->fused_parameter(user->operand_index(value)); for (const auto* param_user : param->users()) { - fn(param_user); + add_user(param_user); } - } else { - fn(user); + } else if (fusion_adaptor.ContainsInstruction(user)) { + add_user(user); } } diff --git a/xla/service/gpu/hlo_traversal_test.cc b/xla/service/gpu/hlo_traversal_test.cc index 43c7a9e75dc04d..ee3a4b7ad1239f 100644 --- a/xla/service/gpu/hlo_traversal_test.cc +++ b/xla/service/gpu/hlo_traversal_test.cc @@ -128,8 +128,7 @@ TEST_F(HloTraversalTest, AdaptorUsers) { HloInstructionAdaptor add{*module->GetComputationWithName("fused_computation") ->GetInstructionWithName("add"), fusion_adaptor1.get()}; - EXPECT_THAT(add.GetUsers(), ElementsAre(InstructionAdaptorName("add.1"), - InstructionAdaptorName("mul"), + EXPECT_THAT(add.GetUsers(), ElementsAre(InstructionAdaptorName("mul"), InstructionAdaptorName("res"))); auto fusion_adaptor2 = HloFusionAdaptor::ForInstruction( @@ -145,7 +144,7 @@ TEST_F(HloTraversalTest, AdaptorUsers) { *module->GetComputationWithName("fused_computation_1") ->GetInstructionWithName("neg.1"), fusion_adaptor2.get()}; - EXPECT_THAT(neg.GetUsers(), ElementsAre(InstructionAdaptorName("exp.1"))); + EXPECT_TRUE(neg.GetUsers().empty()); } TEST_F(HloTraversalTest, TraverseFusionConsumerFirst) { @@ -377,8 +376,7 @@ TEST_F(HloTraversalTest, FuseFusionConsumer) { ->GetInstructionWithName("reduce.1"), fusion.get()); - EXPECT_THAT(reduce_1.GetUsers(), - ElementsAre(InstructionAdaptorName("fusion.2"))); + EXPECT_TRUE(reduce_1.GetUsers().empty()); std::vector nodes; std::vector params; From f230ce20f0447b893d9365888d46e517df39ac16 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Mon, 29 Jul 2024 02:25:51 -0700 Subject: [PATCH 225/376] Reverts 2c1812d5958aa285e69fa0e54502eb103d4374eb PiperOrigin-RevId: 657122248 --- xla/debug_options_flags.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc index 750fb560cfa7fc..c35ea757728c8c 100644 --- a/xla/debug_options_flags.cc +++ b/xla/debug_options_flags.cc @@ -273,7 +273,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_use_shardy(false); - opts.set_xla_gpu_shard_autotuning(true); + opts.set_xla_gpu_shard_autotuning(false); opts.set_xla_syntax_sugar_async_ops(false); From 8197129176e611720e3817e2d9b9f4dbf853c21b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 29 Jul 2024 02:48:27 -0700 Subject: [PATCH 226/376] Automated Code Change PiperOrigin-RevId: 657127496 --- xla/service/gpu/command_buffer_scheduling.cc | 2 +- xla/service/gpu/gpu_conv_padding_legalization.cc | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/xla/service/gpu/command_buffer_scheduling.cc b/xla/service/gpu/command_buffer_scheduling.cc index d81046b9534331..d113a6b5d01ec8 100644 --- a/xla/service/gpu/command_buffer_scheduling.cc +++ b/xla/service/gpu/command_buffer_scheduling.cc @@ -719,7 +719,7 @@ absl::StatusOr CommandBufferScheduling::Run( } absl::flat_hash_set legacy_custom_call_targets; - for (auto target : + for (const auto& target : debug_options.legacy_command_buffer_custom_call_targets()) { legacy_custom_call_targets.insert(target); } diff --git a/xla/service/gpu/gpu_conv_padding_legalization.cc b/xla/service/gpu/gpu_conv_padding_legalization.cc index bbe037280494b6..0b55f7d264ff00 100644 --- a/xla/service/gpu/gpu_conv_padding_legalization.cc +++ b/xla/service/gpu/gpu_conv_padding_legalization.cc @@ -146,6 +146,8 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window, // Compute the shape and padding config of the pad to be inserted. PaddingConfig padding_config; + padding_config.mutable_dimensions()->Reserve( + kernel->shape().dimensions_size()); for (size_t i = 0; i < kernel->shape().dimensions_size(); ++i) { padding_config.add_dimensions(); } From 7c3e32bee40380a9709b5c17dfa2eaa1cfba30e6 Mon Sep 17 00:00:00 2001 From: Tori Baker Date: Mon, 29 Jul 2024 04:13:43 -0700 Subject: [PATCH 227/376] Create new builder & verifier for IndexingMapAttr PiperOrigin-RevId: 657147509 --- .../gpu/fusions/mlir/ir/xla_gpu_attrs.cc | 27 +++++++++++++++++++ .../gpu/fusions/mlir/ir/xla_gpu_attrs.td | 4 +++ 2 files changed, 31 insertions(+) diff --git a/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc b/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc index ad31f42c64bc84..d3829056de5dc3 100644 --- a/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc +++ b/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/Types.h" #include "mlir/Support/LLVM.h" @@ -163,5 +164,31 @@ void IndexingMapAttr::print(mlir::AsmPrinter& printer) const { printer << ">"; } +IndexingMapAttr IndexingMapAttr::get(mlir::MLIRContext* context, + const IndexingMap& indexing_map) { + llvm::SmallVector> constraints; + for (auto& constraint : indexing_map.GetConstraints()) { + constraints.push_back({constraint.first, constraint.second}); + } + return get(context, indexing_map.GetAffineMap(), indexing_map.GetDimVars(), + indexing_map.GetRangeVars(), constraints); +} + +mlir::LogicalResult IndexingMapAttr::verify( + mlir::function_ref emitError, + mlir::AffineMap map, ArrayRef dim_vars, + ArrayRef range_vars, + ArrayRef> constraints) { + if (map.getNumDims() != dim_vars.size()) { + return emitError() + << "dim size must match the number of dimensions in the affine map"; + } + if (map.getNumSymbols() != range_vars.size()) { + return emitError() + << "range size must match the number of symbols in the affine map"; + } + return mlir::success(); +} + } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.td b/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.td index cf137686b23e4e..8c8f98c05e2737 100644 --- a/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.td +++ b/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.td @@ -51,6 +51,10 @@ def XLAGPU_IndexingMapAttr : XLAGPU_Attr<"IndexingMap"> { XLAGPU_RangeVarsParameter:$range_vars, XLAGPU_ConstraintsParameter:$constraints); let hasCustomAssemblyFormat = 1; + let builders = [ + AttrBuilder<(ins "const ::xla::gpu::IndexingMap&":$indexing_map)>, + ]; + let genVerifyDecl = 1; } #endif // MLIR_HLO_DIALECT_MHLO_IR_HLO_ATTRS From c6866a1e0f1afd87857c83a5a4f3f786301f3e70 Mon Sep 17 00:00:00 2001 From: Dmitri Gribenko Date: Mon, 29 Jul 2024 04:38:23 -0700 Subject: [PATCH 228/376] Integrate LLVM at llvm/llvm-project@99bb9a719cec Updates LLVM usage to match [99bb9a719cec](https://github.com/llvm/llvm-project/commit/99bb9a719cec) PiperOrigin-RevId: 657153082 --- third_party/llvm/generated.patch | 37 +++++++++++++++++++ third_party/llvm/workspace.bzl | 4 +- third_party/shardy/workspace.bzl | 4 +- .../tsl/third_party/llvm/generated.patch | 37 +++++++++++++++++++ .../tsl/third_party/llvm/workspace.bzl | 4 +- 5 files changed, 80 insertions(+), 6 deletions(-) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 506f5632703a41..21d431fd26ed47 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,4 +1,41 @@ Auto generated patch. Do not edit or delete it, even if empty. +diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaTemplateDeduction.cpp b/clang/lib/Sema/SemaTemplateDeduction.cpp +--- a/clang/lib/Sema/SemaTemplateDeduction.cpp ++++ b/clang/lib/Sema/SemaTemplateDeduction.cpp +@@ -951,9 +951,11 @@ + + // Skip over the pack elements that were expanded into separate arguments. + // If we partially expanded, this is the number of partial arguments. ++ // FIXME: `&& FixedNumExpansions` is a workaround for UB described in ++ // https://github.com/llvm/llvm-project/issues/100095 + if (IsPartiallyExpanded) + PackElements += NumPartialPackArgs; +- else if (IsExpanded) ++ else if (IsExpanded && FixedNumExpansions) + PackElements += *FixedNumExpansions; + + for (auto &Pack : Packs) { +diff -ruN --strip-trailing-cr a/clang/test/SemaCXX/pr100095.cpp b/clang/test/SemaCXX/pr100095.cpp +--- a/clang/test/SemaCXX/pr100095.cpp ++++ b/clang/test/SemaCXX/pr100095.cpp +@@ -0,0 +1,17 @@ ++// RUN: %clang_cc1 -fsyntax-only -std=c++11 %s ++// XFAIL: asserts ++ ++template struct Pair; ++template struct Tuple { ++ template Tuple(_Up); ++}; ++template struct StatusOr; ++template using ElementType = int; ++template ++using Key = Tuple...>; ++template ++StatusOr>> Parser(); ++struct Helper { Helper(Tuple<>, Tuple<>, int, int); }; ++struct D : Helper { ++ D(Key<> f, int n, int e) : Helper(f, Parser<>, n, e) {} ++}; diff -ruN --strip-trailing-cr a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h --- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 45f9cf544dc10c..072de9d9c6420f 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "51d4980a133db12888207698e39c469cb7055cac" - LLVM_SHA256 = "ee34426de8adf8408a610d0072e82b50bad0adac2c009f1f20072d626c0b876e" + LLVM_COMMIT = "99bb9a719cec9513e72ad275c1c0302b76b6c408" + LLVM_SHA256 = "af060bd4edd9340fd0b90ddd246c78e87dd374d5998a4c154f31d11f8888a076" tf_http_archive( name = name, diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index 3c78b846facc61..c82f3275766f90 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "effc9ac0716b25861f7deaea91aafaa93515a1aa" - SHARDY_SHA256 = "cce9c625b2ce107c2ab19e811059bf1d3da0160fdbe418778658a8f19fef211a" + SHARDY_COMMIT = "8f92b38a2400ce5dc72f97067b02c635ed4f3d00" + SHARDY_SHA256 = "3d91370627e81ce5285e5a6ec0d6dbefc786ae32f6d1ebcb4aa61fd247378b91" tf_http_archive( name = "shardy", diff --git a/third_party/tsl/third_party/llvm/generated.patch b/third_party/tsl/third_party/llvm/generated.patch index 506f5632703a41..21d431fd26ed47 100644 --- a/third_party/tsl/third_party/llvm/generated.patch +++ b/third_party/tsl/third_party/llvm/generated.patch @@ -1,4 +1,41 @@ Auto generated patch. Do not edit or delete it, even if empty. +diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaTemplateDeduction.cpp b/clang/lib/Sema/SemaTemplateDeduction.cpp +--- a/clang/lib/Sema/SemaTemplateDeduction.cpp ++++ b/clang/lib/Sema/SemaTemplateDeduction.cpp +@@ -951,9 +951,11 @@ + + // Skip over the pack elements that were expanded into separate arguments. + // If we partially expanded, this is the number of partial arguments. ++ // FIXME: `&& FixedNumExpansions` is a workaround for UB described in ++ // https://github.com/llvm/llvm-project/issues/100095 + if (IsPartiallyExpanded) + PackElements += NumPartialPackArgs; +- else if (IsExpanded) ++ else if (IsExpanded && FixedNumExpansions) + PackElements += *FixedNumExpansions; + + for (auto &Pack : Packs) { +diff -ruN --strip-trailing-cr a/clang/test/SemaCXX/pr100095.cpp b/clang/test/SemaCXX/pr100095.cpp +--- a/clang/test/SemaCXX/pr100095.cpp ++++ b/clang/test/SemaCXX/pr100095.cpp +@@ -0,0 +1,17 @@ ++// RUN: %clang_cc1 -fsyntax-only -std=c++11 %s ++// XFAIL: asserts ++ ++template struct Pair; ++template struct Tuple { ++ template Tuple(_Up); ++}; ++template struct StatusOr; ++template using ElementType = int; ++template ++using Key = Tuple...>; ++template ++StatusOr>> Parser(); ++struct Helper { Helper(Tuple<>, Tuple<>, int, int); }; ++struct D : Helper { ++ D(Key<> f, int n, int e) : Helper(f, Parser<>, n, e) {} ++}; diff -ruN --strip-trailing-cr a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h --- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h diff --git a/third_party/tsl/third_party/llvm/workspace.bzl b/third_party/tsl/third_party/llvm/workspace.bzl index 45f9cf544dc10c..072de9d9c6420f 100644 --- a/third_party/tsl/third_party/llvm/workspace.bzl +++ b/third_party/tsl/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "51d4980a133db12888207698e39c469cb7055cac" - LLVM_SHA256 = "ee34426de8adf8408a610d0072e82b50bad0adac2c009f1f20072d626c0b876e" + LLVM_COMMIT = "99bb9a719cec9513e72ad275c1c0302b76b6c408" + LLVM_SHA256 = "af060bd4edd9340fd0b90ddd246c78e87dd374d5998a4c154f31d11f8888a076" tf_http_archive( name = name, From c8080173233a00e584449af7e03ccd6d535c36e9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 29 Jul 2024 05:15:40 -0700 Subject: [PATCH 229/376] Integrate Triton up to fd691c67 (https://github.com/openai/triton/commits/fd691c67ac20958a67693358186d877790f5f48f) PiperOrigin-RevId: 657162423 --- third_party/triton/llvm_integration/series.bzl | 1 - third_party/triton/workspace.bzl | 4 ++-- xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc | 3 +++ 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/third_party/triton/llvm_integration/series.bzl b/third_party/triton/llvm_integration/series.bzl index 9d0e1204ba527f..656b9c894904d8 100644 --- a/third_party/triton/llvm_integration/series.bzl +++ b/third_party/triton/llvm_integration/series.bzl @@ -8,6 +8,5 @@ LLVM nor MLIR integrator, please do not add any patches to this list. """ llvm_patch_list = [ - "//third_party/triton/llvm_integration:cl656020169.patch", # Add new patches just above this line ] diff --git a/third_party/triton/workspace.bzl b/third_party/triton/workspace.bzl index 29bee9bb79295b..d1a4940f567dd9 100644 --- a/third_party/triton/workspace.bzl +++ b/third_party/triton/workspace.bzl @@ -8,8 +8,8 @@ load("//third_party/triton/xla_extensions:series.bzl", "extensions_files_patch_l def repo(): """Imports Triton.""" - TRITON_COMMIT = "cl652414579" - TRITON_SHA256 = "06afd4b310b0f8e48432564917f730f20f1f0c69bc15c8f114d7a1b1cb7215af" + TRITON_COMMIT = "cl655158651" + TRITON_SHA256 = "ac136693d2aeae327896d33e1a4de4852f25c1c2cdca49f85a2b9ac8b6d03b44" tf_http_archive( name = "triton", sha256 = TRITON_SHA256, diff --git a/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc b/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc index 93d3d8b76480be..2c12aafb9ac536 100644 --- a/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc +++ b/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc @@ -106,6 +106,9 @@ absl::Status CreateTritonPipeline( // @triton//:third_party/amd/backend/compiler.py pm.addPass(mlir::triton::AMD::createDecomposeUnsupportedConversionsPass( ccRocm.gfx_version())); + const int custom_lds_size = 0; + pm.addPass(mlir::triton::AMD::createOptimizeLdsUsagePass(ccRocm.gfx_version(), + custom_lds_size)); pm.addPass(mlir::createConvertSCFToCFPass()); pm.addPass(mlir::createConvertIndexToLLVMPass()); pm.addPass(mt::gpu::createAllocateSharedMemoryPass()); From 134fcafdd6c0d98c6316ede09e8f4482b67a881a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 29 Jul 2024 05:22:59 -0700 Subject: [PATCH 230/376] Automated Code Change PiperOrigin-RevId: 657164404 --- third_party/tsl/tsl/platform/base64.cc | 7 ++++--- third_party/tsl/tsl/platform/base64.h | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/third_party/tsl/tsl/platform/base64.cc b/third_party/tsl/tsl/platform/base64.cc index 7cf8f2d606887f..592c473757cb07 100644 --- a/third_party/tsl/tsl/platform/base64.cc +++ b/third_party/tsl/tsl/platform/base64.cc @@ -79,7 +79,7 @@ absl::Status DecodeThreeChars(const char* codes, char* result) { } // namespace template -absl::Status Base64Decode(StringPiece data, T* decoded) { +absl::Status Base64Decode(absl::string_view data, T* decoded) { if (decoded == nullptr) { return errors::Internal("'decoded' cannot be nullptr."); } @@ -142,12 +142,13 @@ absl::Status Base64Decode(StringPiece data, T* decoded) { } template -absl::Status Base64Encode(StringPiece source, T* encoded) { +absl::Status Base64Encode(absl::string_view source, T* encoded) { return Base64Encode(source, false, encoded); } template -absl::Status Base64Encode(StringPiece source, bool with_padding, T* encoded) { +absl::Status Base64Encode(absl::string_view source, bool with_padding, + T* encoded) { const char* const base64_chars = kBase64UrlSafeChars; if (encoded == nullptr) { return errors::Internal("'encoded' cannot be nullptr."); diff --git a/third_party/tsl/tsl/platform/base64.h b/third_party/tsl/tsl/platform/base64.h index fa2ad0ad40d618..d1e213b81139db 100644 --- a/third_party/tsl/tsl/platform/base64.h +++ b/third_party/tsl/tsl/platform/base64.h @@ -27,16 +27,17 @@ namespace tsl { /// /// See https://en.wikipedia.org/wiki/Base64 template -absl::Status Base64Encode(StringPiece source, bool with_padding, T* encoded); +absl::Status Base64Encode(absl::string_view source, bool with_padding, + T* encoded); template -absl::Status Base64Encode(StringPiece source, +absl::Status Base64Encode(absl::string_view source, T* encoded); // with_padding=false. /// \brief Converts data from web-safe base64 encoding. /// /// See https://en.wikipedia.org/wiki/Base64 template -absl::Status Base64Decode(StringPiece data, T* decoded); +absl::Status Base64Decode(absl::string_view data, T* decoded); // Explicit instantiations defined in base64.cc. extern template Status Base64Decode(StringPiece data, From e291d59c3bb3b6ef007943f28c5aaf4638c0ee08 Mon Sep 17 00:00:00 2001 From: Chao Date: Mon, 29 Jul 2024 05:40:31 -0700 Subject: [PATCH 231/376] PR #15311: [ROCm] GPU/CPU unified memory for rocm Imported from GitHub PR https://github.com/openxla/xla/pull/15311 @xla-rotation Copybara import of the project: -- 2c4cee2bc335c72538261c41f485d49f1eb7c08f by Chao Chen : unified memory for rocm Merging this change closes #15311 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/15311 from ROCm:ci_rocm_unify_mem 2c4cee2bc335c72538261c41f485d49f1eb7c08f PiperOrigin-RevId: 657168704 --- xla/stream_executor/rocm/rocm_driver.cc | 28 +++++++++++++++---- .../rocm/rocm_driver_wrapper.h | 1 + 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/xla/stream_executor/rocm/rocm_driver.cc b/xla/stream_executor/rocm/rocm_driver.cc index bb8982fc3d654c..465dbbe84b2a00 100644 --- a/xla/stream_executor/rocm/rocm_driver.cc +++ b/xla/stream_executor/rocm/rocm_driver.cc @@ -1367,16 +1367,32 @@ struct BitPatternToValue { /* static */ void* GpuDriver::UnifiedMemoryAllocate(GpuContext* context, uint64_t bytes) { ScopedActivateContext activated{context}; - - LOG(ERROR) - << "Feature not supported on ROCm platform (UnifiedMemoryAllocate)"; - return nullptr; + hipDeviceptr_t result = 0; + // "managed" memory is visible to both CPU and GPU. + hipError_t res = wrap::hipMallocManaged(&result, bytes, hipMemAttachGlobal); + if (res != hipSuccess) { + LOG(ERROR) << "failed to alloc " << bytes + << " bytes unified memory; result: " << ToString(res); + return nullptr; + } + void* ptr = reinterpret_cast(result); + VLOG(2) << "allocated " << ptr << " for context " << context->context() + << " of " << bytes << " bytes in unified memory"; + return ptr; } /* static */ void GpuDriver::UnifiedMemoryDeallocate(GpuContext* context, void* location) { - LOG(ERROR) - << "Feature not supported on ROCm platform (UnifiedMemoryDeallocate)"; + ScopedActivateContext activation(context); + hipDeviceptr_t pointer = absl::bit_cast(location); + hipError_t res = wrap::hipFree(pointer); + if (res != hipSuccess) { + LOG(ERROR) << "failed to free unified memory at " << location + << "; result: " << ToString(res); + } else { + VLOG(2) << "deallocated unified memory at " << location << " for context " + << context->context(); + } } /* static */ void* GpuDriver::HostAllocate(GpuContext* context, diff --git a/xla/stream_executor/rocm/rocm_driver_wrapper.h b/xla/stream_executor/rocm/rocm_driver_wrapper.h index 74b18b6076c58f..d3ac52ead9f07f 100644 --- a/xla/stream_executor/rocm/rocm_driver_wrapper.h +++ b/xla/stream_executor/rocm/rocm_driver_wrapper.h @@ -135,6 +135,7 @@ namespace wrap { __macro(hipLaunchHostFunc) \ __macro(hipLaunchKernel) \ __macro(hipMalloc) \ + __macro(hipMallocManaged) \ __macro(hipMemGetAddressRange) \ __macro(hipMemGetInfo) \ __macro(hipMemcpyDtoD) \ From eafa3204121f28b34074fd073a5ec58f552b1c7c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 29 Jul 2024 06:18:07 -0700 Subject: [PATCH 232/376] [XLA:GPU] Disable the constant folding for pad(broadcast(), constant) It does not give us significant benefits. At the same time for the big outputs like 1m parameters it needs significant compile time. PiperOrigin-RevId: 657177207 --- xla/service/hlo_constant_folding.cc | 13 +++++++++++++ xla/service/hlo_constant_folding_test.cc | 20 ++++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/xla/service/hlo_constant_folding.cc b/xla/service/hlo_constant_folding.cc index e603ce771cb55e..cb2852fbb2e619 100644 --- a/xla/service/hlo_constant_folding.cc +++ b/xla/service/hlo_constant_folding.cc @@ -167,6 +167,16 @@ absl::StatusOr HloConstantFolding::Run( continue; } + if (instruction->opcode() == HloOpcode::kPad && + instruction->operand(0)->opcode() == HloOpcode::kBroadcast && + instruction->operand(1)->opcode() == HloOpcode::kConstant) { + // Reduce the compile time by skipping the constant folding of pad + // instruction with broadcast operand. With 45m shape limit the compile + // time could be more than 30 seconds. According to the current + // benchmarks it does not affect the performance. + continue; + } + // Don't constant fold unless output and operand sizes are small. if (instruction->shape().IsArray()) { int64_t elements_in_operands = 0; @@ -181,6 +191,9 @@ absl::StatusOr HloConstantFolding::Run( static const int64_t kMaximumConstantSizeElements = 45 * 1000 * 1000; if (std::max(elements_in_constant, elements_in_operands) > kMaximumConstantSizeElements) { + VLOG(2) << "Ignore constant folding: result shape size is " + << elements_in_constant << " total size of arguments is " + << elements_in_operands; continue; } } diff --git a/xla/service/hlo_constant_folding_test.cc b/xla/service/hlo_constant_folding_test.cc index 3aeca15d8ce927..255012f734a5f1 100644 --- a/xla/service/hlo_constant_folding_test.cc +++ b/xla/service/hlo_constant_folding_test.cc @@ -303,6 +303,26 @@ TEST_F(HloConstantFoldingTest, DoesNotFoldLargePad) { GmockMatch(m::Pad(m::Constant(), m::Constant()))); } +TEST_F(HloConstantFoldingTest, DoesNotFoldPadBroadcast) { + const char* const kConstantFoldPadBroadcast = R"( + HloModule ConstantFoldLargePad + + ENTRY r { + a = f32[] constant(239) + broadcast_a = f32[4] broadcast(a), dimensions={} + b = f32[] constant(42) + ROOT pad = f32[8] pad(f32[4] broadcast_a, f32[] b), padding=4_0 + })"; + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule(kConstantFoldPadBroadcast)); + HloConstantFolding const_folder; + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + EXPECT_FALSE(result); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Pad(m::Broadcast(), m::Constant()))); +} + TEST_F(HloConstantFoldingTest, DoesNotFoldSlicesWithLargeOperand) { const char* const kModuleStr = R"( HloModule test From bc057a22cae4dc5d0d0d99084c3d2d3159717d1f Mon Sep 17 00:00:00 2001 From: Greg Olechwierowicz Date: Mon, 29 Jul 2024 07:58:36 -0700 Subject: [PATCH 233/376] [XLA:GPU] Add comment describing CollectiveQuantizer. PiperOrigin-RevId: 657200982 --- xla/service/gpu/gpu_compiler.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index 676ed086bd163f..15debb4020b854 100644 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -934,6 +934,8 @@ absl::Status RunCollectiveOptimizationPasses( // Remove dead computations left over after ar/rs promotion. collectives_pipeline.AddPass(); + // Moves collectives' subsequent quantization before the collective to + // minimize data transfers. collectives_pipeline.AddPass(); // Remove dead computations after collective quantization. collectives_pipeline.AddPass(); From 761ed600c48f21f9278e4a14fbf724a017113395 Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Mon, 29 Jul 2024 08:43:48 -0700 Subject: [PATCH 234/376] [XLA:GPU] Add pipeline parallelism tests with circular repeat PiperOrigin-RevId: 657213596 --- .../collective_pipeline_parallelism_test.cc | 337 ++++++++++++++++++ 1 file changed, 337 insertions(+) diff --git a/xla/tests/collective_pipeline_parallelism_test.cc b/xla/tests/collective_pipeline_parallelism_test.cc index 509bb9d2cfcf22..a88b1000f4737a 100644 --- a/xla/tests/collective_pipeline_parallelism_test.cc +++ b/xla/tests/collective_pipeline_parallelism_test.cc @@ -406,5 +406,342 @@ XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch5Replica4) { ErrorSpec{1e-5, 1e-5})); } +// Naive implementation of pipeline parallelism: +// - 4 devices +// - 4 microbatches +// - 2 circular repeat +// - no disabled collectives +// - no collective pipelining +// +// Every stage of the pipeline is a single linear layer. +XLA_TEST_F(CollectivePipelineParallelismTest, + NaiveDFSMicrobatch4CircularRepeat2Replica4) { + const absl::string_view kModuleStr = R"( + HloModule test + + get_circ_buffer_index { + offset = u32[] parameter(0) + index = u32[] parameter(1) + size = u32[] parameter(2) + t0 = u32[] add(offset, index) + t1 = u32[] divide(t0, size) + t2 = u32[] multiply(t1, size) + ROOT t4 = u32[] subtract(t0, t2) + } + + is_input_replica { + replica_id = u32[] replica-id() + c0 = u32[] constant(0) + ROOT predicate = pred[] compare(replica_id, c0), direction=EQ + } + + is_output_replica { + replica_id = u32[] replica-id() + c3 = u32[] constant(3) + ROOT predicate = pred[] compare(replica_id, c3), direction=EQ + } + + is_read_input { + is_input_replica = pred[] call(), to_apply=is_input_replica + i = u32[] parameter(0) + c4 = u32[] constant(4) + is_input_iteration = pred[] compare(i, c4), direction=LT + ROOT is_read_input = pred[] and(is_input_replica, is_input_iteration) + } + + while_condition { + tuple = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) parameter(0) + i = u32[] get-tuple-element(tuple), index=4 + n = u32[] constant(11) + ROOT predicate = pred[] compare(i, n), direction=LT + } + + while_body { + tuple = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) parameter(0) + weights = f32[16,16] get-tuple-element(tuple), index=0 + input = f32[4,16] get-tuple-element(tuple), index=1 + output = f32[4,16] get-tuple-element(tuple), index=2 + tmp = f32[16] get-tuple-element(tuple), index=3 + i = u32[] get-tuple-element(tuple), index=4 + + c1 = u32[] constant(1) + c0 = u32[] constant(0) + c4 = u32[] constant(4) + + input_idx = u32[] call(c0, i, c4), to_apply=get_circ_buffer_index + input_slice = f32[1,16] dynamic-slice(input, input_idx, c0), + dynamic_slice_sizes={1,16} + input_slice_ = f32[16] reshape(input_slice) + + prev_stage_slice = f32[16] collective-permute(tmp), + source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}} + + is_read_input = pred[] call(i), to_apply=is_read_input + compute_in = f32[16] select(is_read_input, input_slice_, prev_stage_slice) + + compute_out = f32[16] dot(weights, compute_in), lhs_contracting_dims={1}, + rhs_contracting_dims={0} + + output_index = u32[] call(c1, i, c4), to_apply=get_circ_buffer_index + output_slice = f32[1,16] reshape(compute_out) + output_ = f32[4,16] dynamic-update-slice(output, output_slice, output_index, + c0) + + i_ = add(i, c1) + + ROOT tuple1 = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) + tuple(weights, input, output_, compute_out, i_) + } + + ENTRY main { + weights = f32[16,16] parameter(0) + input = f32[4,16] parameter(1) + + cf0 = f32[] constant(0) + output = f32[4,16] broadcast(cf0), dimensions={} + tmp = f32[16] broadcast(cf0), dimensions={} + c0 = u32[] constant(0) + + tuple = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) tuple(weights, + input, output, tmp, c0) + tuple_ = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) while(tuple), + condition=while_condition, body=while_body + + ROOT output_ = f32[4,16] get-tuple-element(tuple_), index=2 + } + )"; + + const int64_t kNumReplicas = 4; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + + // This pipeline consists of a total of 8 layers (2 per replica), each of + // which is a single linear layer. We assign the weights to the replicas such + // that the layers scale the input data by 1.0, 2.0, 3.0 and 4.0 in the first + // and second cycle. The combined effect is to scale the input data by 576.0 + // (24.0 * 24.0). + const int64_t kInputSize = 16; + Literal weights_r0 = LiteralUtil::MakeScalarMatrixR2(kInputSize, 1.0); + Literal weights_r1 = LiteralUtil::MakeScalarMatrixR2(kInputSize, 2.0); + Literal weights_r2 = LiteralUtil::MakeScalarMatrixR2(kInputSize, 3.0); + Literal weights_r3 = LiteralUtil::MakeScalarMatrixR2(kInputSize, 4.0); + + // Only the first replica holds the input to the pipeline in this naive + // implementation. The remaining replicas get zero/dummy input. + const int64_t kMicrobatches = 4; + Literal real_input = + LiteralUtil::CreateFingerprintMatixR2(kMicrobatches, kInputSize); + Literal fake_input = + LiteralUtil::CreateFull({kMicrobatches, kInputSize}, 0.0); + + // Check pipeline output for last replica. + // The combined effect of the pipeline is to scale the input data by 576.0 + // (24.0 * 24.0). + const float kExpectedFactor = 1.0 * 2.0 * 3.0 * 4.0 * 1.0 * 2.0 * 3.0 * 4.0; + Literal expected_output = LiteralUtil::CreateFingerprintMatixR2( + kMicrobatches, kInputSize, /*scale=*/kExpectedFactor); + std::vector> args = {{&weights_r0, &real_input}, + {&weights_r1, &fake_input}, + {&weights_r2, &fake_input}, + {&weights_r3, &fake_input}}; + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), args, kNumReplicas, + /*run_hlo_passes=*/true)); + EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected_output, results[3], + ErrorSpec{1e-5, 1e-5})); +} + +// Naive implementation if pipeline parallelism: +// - 4 devices +// - 5 microbatches +// - 2 circular repeat +// - no disabled collectives +// - no collective pipelining +// +// Every stage of the pipeline is a single linear layer. +XLA_TEST_F(CollectivePipelineParallelismTest, + NaiveDFSMicrobatch5CircularRepeat2Replica4) { + const absl::string_view kModuleStr = R"( + HloModule test + + get_circ_buffer_index { + offset = u32[] parameter(0) + index = u32[] parameter(1) + size = u32[] parameter(2) + t0 = u32[] add(offset, index) + t1 = u32[] divide(t0, size) + t2 = u32[] multiply(t1, size) + ROOT t4 = u32[] subtract(t0, t2) + } + + read_buffer { + buffer = f32[5,16] parameter(0) + offset = u32[] parameter(1) + index = u32[] parameter(2) + c0 = u32[] constant(0) + c5 = u32[] constant(5) + index_ = u32[] add(index, offset) + index__ = u32[] remainder(index_, c5) + slice = f32[1,16] dynamic-slice(buffer, index__, c0), + dynamic_slice_sizes={1,16} + ROOT slice_ = f32[16] reshape(slice) + } + + update_buffer { + buffer = f32[5,16] parameter(0) + update = f32[16] parameter(1) + offset = u32[] parameter(2) + index = u32[] parameter(3) + c0 = u32[] constant(0) + c5 = u32[] constant(5) + index_ = u32[] add(index, offset) + index__ = u32[] remainder(index_, c5) + update_ = f32[1,16] reshape(update) + ROOT buffer_ = f32[5,16] dynamic-update-slice(buffer, update_, index__, c0) + } + + is_input_replica { + replica_id = u32[] replica-id() + c0 = u32[] constant(0) + ROOT predicate = pred[] compare(replica_id, c0), direction=EQ + } + + is_output_replica { + replica_id = u32[] replica-id() + c3 = u32[] constant(3) + ROOT predicate = pred[] compare(replica_id, c3), direction=EQ + } + + is_read_input { + is_input_replica = pred[] call(), to_apply=is_input_replica + i = u32[] parameter(0) + c5 = u32[] constant(5) + is_input_iteration = pred[] compare(i, c5), direction=LT + ROOT is_read_input = pred[] and(is_input_replica, is_input_iteration) + } + + while_condition { + tuple = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[]) + parameter(0) + i = u32[] get-tuple-element(tuple), index=5 + n = u32[] constant(13) + ROOT predicate = pred[] compare(i, n), direction=LT + } + + while_body { + tuple = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[]) + parameter(0) + weights = f32[16,16] get-tuple-element(tuple), index=0 + input = f32[5,16] get-tuple-element(tuple), index=1 + output = f32[5,16] get-tuple-element(tuple), index=2 + buffer = f32[5,16] get-tuple-element(tuple), index=3 + prev_iteration_compute_out = f32[16] get-tuple-element(tuple), index=4 + i = u32[] get-tuple-element(tuple), index=5 + + c0 = u32[] constant(0) + c1 = u32[] constant(1) + c2 = u32[] constant(2) + c3 = u32[] constant(3) + c5 = u32[] constant(5) + + input_idx = u32[] call(c0, i, c5), to_apply=get_circ_buffer_index + input_slice = f32[1,16] dynamic-slice(input, input_idx, c0), + dynamic_slice_sizes={1,16} + input_slice_ = f32[16] reshape(input_slice) + + buffer_slice = f32[16] call(buffer, c3, i), to_apply=read_buffer + + is_output_replica = pred[] call(), to_apply=is_output_replica + next_stage_slice = select(is_output_replica, buffer_slice, + prev_iteration_compute_out) + + prev_stage_slice = f32[16] collective-permute(next_stage_slice), + source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}} + + is_read_input = pred[] call(i), to_apply=is_read_input + compute_in = f32[16] select(is_read_input, input_slice_, prev_stage_slice) + + compute_out = f32[16] dot(weights, compute_in), lhs_contracting_dims={1}, + rhs_contracting_dims={0} + + output_ = f32[5,16] call(output, compute_out, c2, i), to_apply=update_buffer + + buffer_ = f32[5,16] call(buffer, compute_out, c0, i), to_apply=update_buffer + + i_ = add(i, c1) + + ROOT tuple_ = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[]) + tuple(weights, input, output_, buffer_, compute_out, i_) + } + + ENTRY main { + weights = f32[16,16] parameter(0) + input = f32[5,16] parameter(1) + + cf0 = f32[] constant(0) + output = f32[5,16] broadcast(cf0), dimensions={} + buffer = f32[5,16] broadcast(cf0), dimensions={} + prev_iteration_compute_out = f32[16] broadcast(cf0), dimensions={} + c0 = u32[] constant(0) + + tuple = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[]) + tuple(weights, input, output, buffer, prev_iteration_compute_out, c0) + tuple_ = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[]) + while(tuple), condition=while_condition, body=while_body + + ROOT output_ = f32[5,16] get-tuple-element(tuple_), index=2 + } + )"; + + const int64_t kNumReplicas = 4; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + + // This pipeline consists of a total of 8 layers (2 per replica), each of + // which is a single linear layer. We assign the weights to the replicas such + // that the layers scale the input data by 1.0, 2.0, 3.0 and 4.0 in the first + // and second cycle. The combined effect is to scale the input data by 576.0 + // (24.0 * 24.0). + const int64_t kInputSize = 16; + Literal weights_r0 = LiteralUtil::MakeScalarMatrixR2(kInputSize, 1.0); + Literal weights_r1 = LiteralUtil::MakeScalarMatrixR2(kInputSize, 2.0); + Literal weights_r2 = LiteralUtil::MakeScalarMatrixR2(kInputSize, 3.0); + Literal weights_r3 = LiteralUtil::MakeScalarMatrixR2(kInputSize, 4.0); + + // Only the first replica holds the input to the pipeline in this naive + // implementation. The remaining replicas get zero/dummy input. + const int64_t kMicrobatches = 5; + Literal real_input = + LiteralUtil::CreateFingerprintMatixR2(kMicrobatches, kInputSize); + Literal fake_input = + LiteralUtil::CreateFull({kMicrobatches, kInputSize}, 0.0); + + // Check pipeline output for last replica. + // The combined effect of the pipeline is to scale the input data by 576.0 + // (24.0 * 24.0). + const float kExpectedFactor = 1.0 * 2.0 * 3.0 * 4.0 * 1.0 * 2.0 * 3.0 * 4.0; + Literal expected_output = LiteralUtil::CreateFingerprintMatixR2( + kMicrobatches, kInputSize, /*scale=*/kExpectedFactor); + std::vector> args = {{&weights_r0, &real_input}, + {&weights_r1, &fake_input}, + {&weights_r2, &fake_input}, + {&weights_r3, &fake_input}}; + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), args, kNumReplicas, + /*run_hlo_passes=*/true)); + EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected_output, results[3], + ErrorSpec{1e-5, 1e-5})); +} + } // namespace } // namespace xla From 0ff318582d730dcbe2afab23e56bce060c624997 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Mon, 29 Jul 2024 09:00:39 -0700 Subject: [PATCH 235/376] PR #15229: [NFC] Add a documentation page about determinism. Imported from GitHub PR https://github.com/openxla/xla/pull/15229 Copybara import of the project: -- 5a4bd2b93ebe025d5d6d56949cd25e5dbbed070e by Ilia Sergachev : [NFC] Add a documentation page about determinism. Merging this change closes #15229 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/15229 from openxla:determinism_doc 5a4bd2b93ebe025d5d6d56949cd25e5dbbed070e PiperOrigin-RevId: 657217844 --- docs/_toc.yaml | 2 ++ docs/determinism.md | 17 +++++++++++++++++ 2 files changed, 19 insertions(+) create mode 100644 docs/determinism.md diff --git a/docs/_toc.yaml b/docs/_toc.yaml index b133e29f7c87c8..d8ef492e6c4cca 100644 --- a/docs/_toc.yaml +++ b/docs/_toc.yaml @@ -24,6 +24,8 @@ toc: path: /xla/custom_call - title: Persisted autotuning path: /xla/persisted_autotuning + - title: Determinism + path: /xla/determinism - title: XLA Tooling path: /xla/tools - title: Using LSP autocompletion diff --git a/docs/determinism.md b/docs/determinism.md new file mode 100644 index 00000000000000..d8cd934e5cb1fc --- /dev/null +++ b/docs/determinism.md @@ -0,0 +1,17 @@ +# Determinism (GPU) + +## Compilation + +XLA compilation is deterministic if +[persisted autotuning](./persisted_autotuning) is used to perform autotuning +once and avoid it in subsequent compilations. Otherwise due to fluctuations in +measurements different kernels can be picked as the fastest ones in different +compilation runs. + +## Execution + +Programs compiled by XLA can be non-deterministic on operations like scatter, +select-and-scatter, GEMMs, convolutions, multi-headed attention. The flag +`--xla_gpu_exclude_nondeterministic_ops` switches these operations to +deterministic and potentially slower implementations and makes compilation fail +on select-and-scatter which does not have a deterministic implementaiton. From b4424c695a9b8cd4f60b3d6119e6b32fd00a2052 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 29 Jul 2024 10:17:35 -0700 Subject: [PATCH 236/376] Integrate Triton up to [baa16342](https://github.com/openai/triton/commits/baa1634263394e3d91677b528ae9f6b4f27e274a) PiperOrigin-RevId: 657242760 --- .../triton/llvm_integration/cl656020169.patch | 12 ------------ third_party/triton/workspace.bzl | 4 ++-- xla/service/gpu/tests/sparse_add_layout.mlir | 4 ++-- 3 files changed, 4 insertions(+), 16 deletions(-) delete mode 100644 third_party/triton/llvm_integration/cl656020169.patch diff --git a/third_party/triton/llvm_integration/cl656020169.patch b/third_party/triton/llvm_integration/cl656020169.patch deleted file mode 100644 index 7586a90b14ccf6..00000000000000 --- a/third_party/triton/llvm_integration/cl656020169.patch +++ /dev/null @@ -1,12 +0,0 @@ -diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp ---- a/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp -+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp -@@ -117,7 +117,7 @@ private: - auto operands = callOp.getOperands(); - auto result = callOp.getResult(); - -- LLVM::LLVMFunctionType calleeType = callOp.getCalleeType().value(); -+ LLVM::LLVMFunctionType calleeType = callOp.getVarCalleeType().value(); - Type returnType = calleeType.getReturnType(); - - auto loc = callOp.getLoc(); diff --git a/third_party/triton/workspace.bzl b/third_party/triton/workspace.bzl index d1a4940f567dd9..f321b4e1a46b1c 100644 --- a/third_party/triton/workspace.bzl +++ b/third_party/triton/workspace.bzl @@ -8,8 +8,8 @@ load("//third_party/triton/xla_extensions:series.bzl", "extensions_files_patch_l def repo(): """Imports Triton.""" - TRITON_COMMIT = "cl655158651" - TRITON_SHA256 = "ac136693d2aeae327896d33e1a4de4852f25c1c2cdca49f85a2b9ac8b6d03b44" + TRITON_COMMIT = "cl657175856" + TRITON_SHA256 = "316f421a7d7ead2b7e5adc2e8bb68ce1a8f7809db73dbed8abd54c35bd0c1576" tf_http_archive( name = "triton", sha256 = TRITON_SHA256, diff --git a/xla/service/gpu/tests/sparse_add_layout.mlir b/xla/service/gpu/tests/sparse_add_layout.mlir index 6457691c211c3e..10b3e45f7e6278 100644 --- a/xla/service/gpu/tests/sparse_add_layout.mlir +++ b/xla/service/gpu/tests/sparse_add_layout.mlir @@ -35,10 +35,10 @@ module { // CHECK-NEXT: %[[CVT:.*]] = triton_gpu.convert_layout %[[D]] // CHECK-SAME: : tensor<64x64xf32, #[[BLOCKED4x4]]> // CHECK-SAME: -> tensor<64x64xf32, #[[BLOCKED1x1]]> - // CHECK-NEXT: tt.print "" {hex = false} : %[[CVT]] + // CHECK-NEXT: tt.print "" {hex = false, isSigned = array} : %[[CVT]] // CHECK-SAME: : tensor<64x64xf32, #[[BLOCKED1x1]]> // A use with side effects so we don't DCE the whole function. - tt.print "" { hex = false } : %d : tensor<64x64xf32> + tt.print "" { hex = false, isSigned = array} : %d : tensor<64x64xf32> // CHECK-NEXT: tt.return tt.return From c4fa9f417a9a159fb659adb7b5cd4c7d2f0de09d Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Mon, 29 Jul 2024 10:36:35 -0700 Subject: [PATCH 237/376] Combine StreamExecutor::GetKernel and ::CreateKernel into a single new method ::LoadKernel. PiperOrigin-RevId: 657249905 --- xla/backends/interpreter/executor.h | 4 ++-- xla/stream_executor/cuda/cuda_executor.cc | 28 +++++++++------------- xla/stream_executor/gpu/gpu_executor.h | 6 ++--- xla/stream_executor/host/host_executor.cc | 14 ++++------- xla/stream_executor/host/host_executor.h | 6 ++--- xla/stream_executor/kernel_factory.h | 6 +---- xla/stream_executor/mock_stream_executor.h | 6 ++--- xla/stream_executor/rocm/rocm_executor.cc | 22 +++++++---------- xla/stream_executor/stream_executor.h | 14 +++-------- 9 files changed, 36 insertions(+), 70 deletions(-) diff --git a/xla/backends/interpreter/executor.h b/xla/backends/interpreter/executor.h index 8ca0cd9c357ef0..c653fc7317b595 100644 --- a/xla/backends/interpreter/executor.h +++ b/xla/backends/interpreter/executor.h @@ -86,8 +86,8 @@ class XlaInterpreterExecutor : public StreamExecutorCommon { absl::Status Init() override { return absl::OkStatus(); } int device_ordinal() const override { return device_ordinal_; }; - absl::Status GetKernel(const MultiKernelLoaderSpec &spec, - Kernel *kernel) override { + absl::StatusOr> LoadKernel( + const MultiKernelLoaderSpec &spec) override { return absl::UnimplementedError("Not Implemented"); } absl::Status Launch(Stream *stream, const ThreadDim &thread_dims, diff --git a/xla/stream_executor/cuda/cuda_executor.cc b/xla/stream_executor/cuda/cuda_executor.cc index 63df37d3c037d9..7f478df047be84 100644 --- a/xla/stream_executor/cuda/cuda_executor.cc +++ b/xla/stream_executor/cuda/cuda_executor.cc @@ -216,21 +216,19 @@ absl::Status GpuExecutor::LoadModuleFromHsaco(const char* hsaco, "Feature not supported on CUDA platform (LoadModuleFromHsaco)"); } -absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, - Kernel* kernel) { - GpuKernel* cuda_kernel = AsGpuKernel(kernel); +absl::StatusOr> GpuExecutor::LoadKernel( + const MultiKernelLoaderSpec& spec) { + auto cuda_kernel = std::make_unique(this); CUmodule module; const std::string* kernel_name; - VLOG(3) << "GetKernel on kernel " << kernel << " : " << kernel->name(); - if (spec.has_cuda_cubin_in_memory()) { absl::MutexLock lock{&in_memory_modules_mu_}; kernel_name = &spec.cuda_cubin_in_memory().kernel_name(); const char* cubin = reinterpret_cast( spec.cuda_cubin_in_memory().cubin_bytes().data()); TF_RETURN_IF_ERROR(LoadModuleFromCuBin(cubin, &module)); - kernel_to_gpu_binary_[kernel] = cubin; + kernel_to_gpu_binary_[cuda_kernel.get()] = cubin; } else if (spec.has_cuda_ptx_in_memory()) { kernel_name = &spec.cuda_ptx_in_memory().kernel_name(); @@ -249,7 +247,7 @@ absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, absl::MutexLock lock{&in_memory_modules_mu_}; TF_RETURN_IF_ERROR(LoadModuleFromPtx(ptx, &module)); - kernel_to_gpu_binary_[kernel] = ptx; + kernel_to_gpu_binary_[cuda_kernel.get()] = ptx; } else if (spec.has_in_process_symbol()) { kernel_name = &spec.in_process_symbol().kernel_name(); @@ -265,7 +263,7 @@ absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, } else { return absl::InternalError("No method of loading CUDA kernel provided"); } - + VLOG(3) << "LoadKernel on kernel : " << *kernel_name; // If we resolved kernel from a symbol pointer, there is no need to load it // from a module, as CUDA runtime did that automatically for us. if (!spec.has_in_process_symbol()) { @@ -284,11 +282,11 @@ absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, cuda_kernel->set_arity(spec.arity()); KernelMetadata kernel_metadata; - TF_RETURN_IF_ERROR(GetKernelMetadata(cuda_kernel, &kernel_metadata)); - kernel->set_metadata(kernel_metadata); - kernel->set_name(*kernel_name); - kernel->set_args_packing(spec.kernel_args_packing()); - return absl::OkStatus(); + TF_RETURN_IF_ERROR(GetKernelMetadata(cuda_kernel.get(), &kernel_metadata)); + cuda_kernel->set_metadata(kernel_metadata); + cuda_kernel->set_name(*kernel_name); + cuda_kernel->set_args_packing(spec.kernel_args_packing()); + return std::move(cuda_kernel); } absl::StatusOr> @@ -793,10 +791,6 @@ absl::StatusOr> GpuExecutor::CreateStream( } } -absl::StatusOr> GpuExecutor::CreateKernel() { - return std::make_unique(this); -} - absl::StatusOr> GpuExecutor::CreateCommandBuffer( CommandBuffer::Mode mode) { VLOG(2) << "Create CUDA command buffer (CUDA graph)"; diff --git a/xla/stream_executor/gpu/gpu_executor.h b/xla/stream_executor/gpu/gpu_executor.h index c19fa1cceeba0c..13b9b944d1beb2 100644 --- a/xla/stream_executor/gpu/gpu_executor.h +++ b/xla/stream_executor/gpu/gpu_executor.h @@ -122,8 +122,8 @@ class GpuExecutor : public StreamExecutorCommon { int device_ordinal() const override { return device_ordinal_; }; - absl::Status GetKernel(const MultiKernelLoaderSpec& spec, - Kernel* kernel) override; + absl::StatusOr> LoadKernel( + const MultiKernelLoaderSpec& spec) override; // (supported on CUDA only) void UnloadKernel(const Kernel* kernel) override; @@ -240,8 +240,6 @@ class GpuExecutor : public StreamExecutorCommon { std::optional> priority = std::nullopt) override; - absl::StatusOr> CreateKernel() override; - absl::StatusOr> CreateCommandBuffer( CommandBuffer::Mode mode) override; diff --git a/xla/stream_executor/host/host_executor.cc b/xla/stream_executor/host/host_executor.cc index ac1d22583d0fde..8d8eeb7e421de1 100644 --- a/xla/stream_executor/host/host_executor.cc +++ b/xla/stream_executor/host/host_executor.cc @@ -74,24 +74,18 @@ absl::Status HostExecutor::Init() { return absl::OkStatus(); } -absl::StatusOr> HostExecutor::CreateKernel() { - return std::make_unique(thread_pool_); -} - -absl::Status HostExecutor::GetKernel(const MultiKernelLoaderSpec& spec, - Kernel* kernel) { - HostKernel* host_kernel = AsHostKernel(kernel); +absl::StatusOr> HostExecutor::LoadKernel( + const MultiKernelLoaderSpec& spec) { + auto host_kernel = std::make_unique(thread_pool_); host_kernel->SetArity(spec.arity()); - VLOG(3) << "GetKernel on kernel " << kernel << " : " << kernel->name(); - for (auto& loader : KernelFunctionLoaderRegistry()) { auto loaded = loader(spec); if (!loaded.has_value()) continue; TF_ASSIGN_OR_RETURN(auto kernel_function, *std::move(loaded)); host_kernel->SetKernelFunction(std::move(kernel_function)); - return absl::OkStatus(); + return std::move(host_kernel); } return absl::InternalError("No method of loading host kernel provided"); diff --git a/xla/stream_executor/host/host_executor.h b/xla/stream_executor/host/host_executor.h index 18ec5a739faca5..4e2a2230ffbd4c 100644 --- a/xla/stream_executor/host/host_executor.h +++ b/xla/stream_executor/host/host_executor.h @@ -70,10 +70,8 @@ class HostExecutor : public StreamExecutorCommon { absl::Status Init() override; - absl::Status GetKernel(const MultiKernelLoaderSpec& spec, - Kernel* kernel) override; - - absl::StatusOr> CreateKernel() override; + absl::StatusOr> LoadKernel( + const MultiKernelLoaderSpec& spec) override; absl::Status Launch(Stream* stream, const ThreadDim& thread_dims, const BlockDim& block_dims, const Kernel& kernel, diff --git a/xla/stream_executor/kernel_factory.h b/xla/stream_executor/kernel_factory.h index 24e594ed89d10e..17e07cd0f97950 100644 --- a/xla/stream_executor/kernel_factory.h +++ b/xla/stream_executor/kernel_factory.h @@ -22,8 +22,6 @@ limitations under the License. #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/stream_executor.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" namespace stream_executor { @@ -33,9 +31,7 @@ class KernelFactory { // Creates kernel on a given executor from a given kernel specification. static inline absl::StatusOr> Create( StreamExecutor *executor, const MultiKernelLoaderSpec &spec) { - TF_ASSIGN_OR_RETURN(auto kernel, executor->CreateKernel()); - TF_RETURN_IF_ERROR(executor->GetKernel(spec, kernel.get())); - return kernel; + return executor->LoadKernel(spec); } }; diff --git a/xla/stream_executor/mock_stream_executor.h b/xla/stream_executor/mock_stream_executor.h index 3787be1133b5d4..f58a553f9ebdd8 100644 --- a/xla/stream_executor/mock_stream_executor.h +++ b/xla/stream_executor/mock_stream_executor.h @@ -59,8 +59,8 @@ class MockStreamExecutor : public StreamExecutor { MockStreamExecutor() = default; MOCK_METHOD(absl::Status, Init, (), (override)); MOCK_METHOD(int, device_ordinal, (), (const, override)); - MOCK_METHOD(absl::Status, GetKernel, - (const MultiKernelLoaderSpec& spec, Kernel* kernel), (override)); + MOCK_METHOD(absl::StatusOr>, LoadKernel, + (const MultiKernelLoaderSpec& spec), (override)); MOCK_METHOD(bool, UnloadModule, (ModuleHandle module_handle), (override)); MOCK_METHOD(absl::Status, LoadModule, (const MultiModuleLoaderSpec& spec, ModuleHandle* module_handle), @@ -124,8 +124,6 @@ class MockStreamExecutor : public StreamExecutor { MOCK_METHOD(blas::BlasSupport*, AsBlas, (), (override)); MOCK_METHOD(fft::FftSupport*, AsFft, (), (override)); MOCK_METHOD(dnn::DnnSupport*, AsDnn, (), (override)); - MOCK_METHOD(absl::StatusOr>, CreateKernel, (), - (override)); MOCK_METHOD(absl::StatusOr>, CreateCommandBuffer, (CommandBuffer::Mode mode), (override)); MOCK_METHOD(std::optional, GetAllocatorStats, (), (override)); diff --git a/xla/stream_executor/rocm/rocm_executor.cc b/xla/stream_executor/rocm/rocm_executor.cc index 49fc9c646868ae..19a367a37ec27a 100644 --- a/xla/stream_executor/rocm/rocm_executor.cc +++ b/xla/stream_executor/rocm/rocm_executor.cc @@ -255,9 +255,9 @@ absl::StatusOr GpuExecutor::DelayKernelIsSupported(GpuStream* stream) { return false; } -absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, - Kernel* kernel) { - GpuKernel* rocm_kernel = AsGpuKernel(kernel); +absl::StatusOr> GpuExecutor::LoadKernel( + const MultiKernelLoaderSpec& spec) { + auto rocm_kernel = std::make_unique(this); hipModule_t module = nullptr; const std::string* kernel_name; @@ -272,7 +272,7 @@ absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, if (module == nullptr) { TF_RETURN_IF_ERROR(GpuDriver::LoadHsaco(context_, hsaco, &module)); } - kernel_to_gpu_binary_[kernel] = hsaco; + kernel_to_gpu_binary_[rocm_kernel.get()] = hsaco; } else if (spec.has_in_process_symbol()) { kernel_name = &spec.in_process_symbol().kernel_name(); void* symbol = spec.in_process_symbol().symbol(); @@ -310,12 +310,12 @@ absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, // unable to get kernel metadata for in-process kernel if (!spec.has_in_process_symbol()) { KernelMetadata kernel_metadata; - TF_RETURN_IF_ERROR(GetKernelMetadata(rocm_kernel, &kernel_metadata)); - kernel->set_metadata(kernel_metadata); + TF_RETURN_IF_ERROR(GetKernelMetadata(rocm_kernel.get(), &kernel_metadata)); + rocm_kernel->set_metadata(kernel_metadata); } - kernel->set_name(*kernel_name); - kernel->set_args_packing(spec.kernel_args_packing()); - return absl::OkStatus(); + rocm_kernel->set_name(*kernel_name); + rocm_kernel->set_args_packing(spec.kernel_args_packing()); + return std::move(rocm_kernel); } absl::Status GpuExecutor::GetKernelMetadata(GpuKernel* rocm_kernel, @@ -669,10 +669,6 @@ absl::StatusOr> GpuExecutor::CreateStream( } } -absl::StatusOr> GpuExecutor::CreateKernel() { - return std::make_unique(this); -} - absl::StatusOr> GpuExecutor::CreateCommandBuffer( CommandBuffer::Mode mode) { VLOG(2) << "Create ROCm command buffer (ROCm graph)"; diff --git a/xla/stream_executor/stream_executor.h b/xla/stream_executor/stream_executor.h index 60fc20de835fb7..53c7ab9d33a08a 100644 --- a/xla/stream_executor/stream_executor.h +++ b/xla/stream_executor/stream_executor.h @@ -107,15 +107,13 @@ class StreamExecutor { return AllocateArray(1); } - // Retrieves (loads) a kernel, if one exists. + // Loads a kernel from a MultiKernelLoaderSpec. // // Parameters: // spec: The MultiKernelLoaderSpec is usually generated as a compile-time // constant into an appropriate namespace. - // kernel: Outparam that the kernel is loaded into. A given Kernel - // instantiation should not be loaded into more than once. - virtual absl::Status GetKernel(const MultiKernelLoaderSpec& spec, - Kernel* kernel) { + virtual absl::StatusOr> LoadKernel( + const MultiKernelLoaderSpec& spec) { return absl::UnimplementedError("Not Implemented"); } @@ -314,12 +312,6 @@ class StreamExecutor { // underlying platform. virtual dnn::DnnSupport* AsDnn() { return nullptr; } - // Creates a new Kernel object. - // TODO(klucke) Combine with GetKernel. - virtual absl::StatusOr> CreateKernel() { - return absl::UnimplementedError("Kernels are not implemented"); - } - // Creates a new CommandBuffer object. virtual absl::StatusOr> CreateCommandBuffer( CommandBuffer::Mode mode) { From 5e295788b08fa886b8bec4dc1137c76a320db2fe Mon Sep 17 00:00:00 2001 From: Dmitri Gribenko Date: Mon, 29 Jul 2024 10:57:18 -0700 Subject: [PATCH 238/376] Integrate LLVM at llvm/llvm-project@4ce3993ee2b6 Updates LLVM usage to match [4ce3993ee2b6](https://github.com/llvm/llvm-project/commit/4ce3993ee2b6) PiperOrigin-RevId: 657257440 --- third_party/llvm/generated.patch | 22 ------------------- third_party/llvm/workspace.bzl | 4 ++-- third_party/shardy/workspace.bzl | 4 ++-- .../tsl/third_party/llvm/generated.patch | 22 ------------------- .../tsl/third_party/llvm/workspace.bzl | 4 ++-- 5 files changed, 6 insertions(+), 50 deletions(-) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 21d431fd26ed47..c3926cd2b6eeef 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -36,25 +36,3 @@ diff -ruN --strip-trailing-cr a/clang/test/SemaCXX/pr100095.cpp b/clang/test/Sem +struct D : Helper { + D(Key<> f, int n, int e) : Helper(f, Parser<>, n, e) {} +}; -diff -ruN --strip-trailing-cr a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h ---- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h -+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h -@@ -9,6 +9,7 @@ - #define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_ - - #include "mlir/Conversion/LLVMCommon/Pattern.h" -+#include "mlir/Dialect/Arith/IR/Arith.h" - #include "mlir/Dialect/GPU/IR/GPUDialect.h" - #include "mlir/Dialect/LLVMIR/LLVMDialect.h" - #include "mlir/IR/Builders.h" -diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel ---- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel -+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel -@@ -5744,6 +5744,7 @@ - "lib/Conversion/GPUCommon/OpToFuncCallLowering.h", - ], - deps = [ -+ ":ArithDialect", - ":GPUDialect", - ":IR", - ":LLVMCommonConversion", diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 072de9d9c6420f..5e0e665f2d4238 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "99bb9a719cec9513e72ad275c1c0302b76b6c408" - LLVM_SHA256 = "af060bd4edd9340fd0b90ddd246c78e87dd374d5998a4c154f31d11f8888a076" + LLVM_COMMIT = "4ce3993ee2b6ee883ef62100df68db9e10ef1dc9" + LLVM_SHA256 = "f0ab7ef30dfad130ce5b7421d4ea33decb9027561a3e944c54be115b81bfe64d" tf_http_archive( name = name, diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index c82f3275766f90..090e5cd279fda4 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "8f92b38a2400ce5dc72f97067b02c635ed4f3d00" - SHARDY_SHA256 = "3d91370627e81ce5285e5a6ec0d6dbefc786ae32f6d1ebcb4aa61fd247378b91" + SHARDY_COMMIT = "0419e7d4717291ccdcdd81f404613bea5a0c12ba" + SHARDY_SHA256 = "39a8c62e95eea71d6afb3b24e77253a82d12864085335c048c27e854568cff4f" tf_http_archive( name = "shardy", diff --git a/third_party/tsl/third_party/llvm/generated.patch b/third_party/tsl/third_party/llvm/generated.patch index 21d431fd26ed47..c3926cd2b6eeef 100644 --- a/third_party/tsl/third_party/llvm/generated.patch +++ b/third_party/tsl/third_party/llvm/generated.patch @@ -36,25 +36,3 @@ diff -ruN --strip-trailing-cr a/clang/test/SemaCXX/pr100095.cpp b/clang/test/Sem +struct D : Helper { + D(Key<> f, int n, int e) : Helper(f, Parser<>, n, e) {} +}; -diff -ruN --strip-trailing-cr a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h ---- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h -+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h -@@ -9,6 +9,7 @@ - #define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_ - - #include "mlir/Conversion/LLVMCommon/Pattern.h" -+#include "mlir/Dialect/Arith/IR/Arith.h" - #include "mlir/Dialect/GPU/IR/GPUDialect.h" - #include "mlir/Dialect/LLVMIR/LLVMDialect.h" - #include "mlir/IR/Builders.h" -diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel ---- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel -+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel -@@ -5744,6 +5744,7 @@ - "lib/Conversion/GPUCommon/OpToFuncCallLowering.h", - ], - deps = [ -+ ":ArithDialect", - ":GPUDialect", - ":IR", - ":LLVMCommonConversion", diff --git a/third_party/tsl/third_party/llvm/workspace.bzl b/third_party/tsl/third_party/llvm/workspace.bzl index 072de9d9c6420f..5e0e665f2d4238 100644 --- a/third_party/tsl/third_party/llvm/workspace.bzl +++ b/third_party/tsl/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "99bb9a719cec9513e72ad275c1c0302b76b6c408" - LLVM_SHA256 = "af060bd4edd9340fd0b90ddd246c78e87dd374d5998a4c154f31d11f8888a076" + LLVM_COMMIT = "4ce3993ee2b6ee883ef62100df68db9e10ef1dc9" + LLVM_SHA256 = "f0ab7ef30dfad130ce5b7421d4ea33decb9027561a3e944c54be115b81bfe64d" tf_http_archive( name = name, From 0b0d556b09fa787f4cd1965676a0e5240b4d3cd1 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Mon, 29 Jul 2024 11:37:05 -0700 Subject: [PATCH 239/376] Remove KernelFactory in favor of just calling StreamExecutor::LoadKernel directly. PiperOrigin-RevId: 657272603 --- xla/service/gpu/BUILD | 1 - xla/service/gpu/kernels/BUILD | 3 -- .../cutlass_gemm_custom_kernel_benchmarks.cc | 6 +-- .../cutlass_gemm_custom_kernel_test.cc | 11 ++--- .../gpu/kernels/topk_custom_kernel_test.cc | 11 ++--- xla/service/gpu/runtime/BUILD | 2 - xla/service/gpu/runtime/command_buffer_cmd.cc | 3 +- xla/service/gpu/runtime/kernel_thunk.cc | 7 ++-- xla/service/gpu/stream_executor_util.cc | 3 +- xla/stream_executor/BUILD | 14 ------- xla/stream_executor/host/BUILD | 1 - xla/stream_executor/host/host_kernel_test.cc | 7 +--- xla/stream_executor/kernel_factory.h | 40 ------------------- xla/stream_executor/typed_kernel_factory.h | 3 +- 14 files changed, 18 insertions(+), 94 deletions(-) delete mode 100644 xla/stream_executor/kernel_factory.h diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 1b97905c38eaa8..02810f9c61ece8 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -4241,7 +4241,6 @@ cc_library( "//xla/stream_executor", "//xla/stream_executor:data_type", "//xla/stream_executor:dnn", - "//xla/stream_executor:kernel_factory", "//xla/stream_executor:kernel_spec", "//xla/stream_executor:launch_dim", "//xla/stream_executor:typed_kernel_factory", diff --git a/xla/service/gpu/kernels/BUILD b/xla/service/gpu/kernels/BUILD index 9d04094c5fd7cc..532c908de0791c 100644 --- a/xla/service/gpu/kernels/BUILD +++ b/xla/service/gpu/kernels/BUILD @@ -231,7 +231,6 @@ xla_test( "//xla:xla_data_proto_cc", "//xla/service:platform_util", "//xla/stream_executor", - "//xla/stream_executor:kernel_factory", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", "//xla/stream_executor/cuda:cuda_platform", @@ -281,7 +280,6 @@ xla_test( ":cutlass_gemm_custom_kernel", "//xla:xla_data_proto_cc", "//xla/stream_executor", - "//xla/stream_executor:kernel_factory", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", "//xla/stream_executor/cuda:cuda_platform", @@ -302,7 +300,6 @@ xla_cc_binary( "//xla:xla_data_proto_cc", "//xla/service:gpu_plugin", "//xla/stream_executor", - "//xla/stream_executor:kernel_factory", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", "//xla/stream_executor/cuda:cuda_platform", diff --git a/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc b/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc index 8d44bb024294e3..22edaea5634405 100644 --- a/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc +++ b/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc @@ -21,7 +21,6 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/kernel_factory.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" @@ -59,9 +58,8 @@ static void BM_RowMajorGemm(benchmark::State& state) { /*indices=*/{0, 1, 2}, /*slices=*/{}, device)); const auto& custom_kernel = custom_kernels[0]; - TF_ASSERT_OK_AND_ASSIGN( - auto gemm, - se::KernelFactory::Create(executor, custom_kernel.kernel_spec())); + TF_ASSERT_OK_AND_ASSIGN(auto gemm, + executor->LoadKernel(custom_kernel.kernel_spec())); // Prepare arguments: a=1.1, b=1.2, c=0.0 se::DeviceMemory a = executor->AllocateArray(m * k, 0); diff --git a/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc b/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc index 458f31ae88a836..e53a1166a1f5db 100644 --- a/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc +++ b/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc @@ -22,7 +22,6 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/kernel_factory.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" @@ -50,9 +49,8 @@ TEST(CutlassGemmKernelTest, SimpleGemm) { executor->GetDeviceDescription())); auto custom_kernel = custom_kernels[0]; - TF_ASSERT_OK_AND_ASSIGN( - auto gemm, - se::KernelFactory::Create(executor, custom_kernel.kernel_spec())); + TF_ASSERT_OK_AND_ASSIGN(auto gemm, + executor->LoadKernel(custom_kernel.kernel_spec())); int64_t length = 4 * 4; int64_t byte_length = sizeof(float) * length; @@ -101,9 +99,8 @@ TEST(CutlassGemmKernelTest, LoadFromSharedLibrary) { "cutlass_gemm", kernel_lib_path, PrimitiveType::F32, 4, 4, 4, /*indices=*/{0, 1, 2}, /*slices=*/{}, executor->GetDeviceDescription()); - TF_ASSERT_OK_AND_ASSIGN( - auto gemm, - se::KernelFactory::Create(executor, custom_kernel->kernel_spec())); + TF_ASSERT_OK_AND_ASSIGN(auto gemm, + executor->LoadKernel(custom_kernel->kernel_spec())); int64_t length = 4 * 4; int64_t byte_length = sizeof(float) * length; diff --git a/xla/service/gpu/kernels/topk_custom_kernel_test.cc b/xla/service/gpu/kernels/topk_custom_kernel_test.cc index 0a8a4d9342b81d..974cc975ea5393 100644 --- a/xla/service/gpu/kernels/topk_custom_kernel_test.cc +++ b/xla/service/gpu/kernels/topk_custom_kernel_test.cc @@ -28,7 +28,6 @@ limitations under the License. #include "absl/strings/substitute.h" #include "xla/service/platform_util.h" #include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/kernel_factory.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" @@ -111,9 +110,8 @@ TEST_P(TopKKernelTest, TopKFloat) { auto custom_kernel = GetTopKKernel("topk", PrimitiveType::F32, n, k, batch_size); - TF_ASSERT_OK_AND_ASSIGN( - auto kernel, - se::KernelFactory::Create(executor, custom_kernel->kernel_spec())); + TF_ASSERT_OK_AND_ASSIGN(auto kernel, + executor->LoadKernel(custom_kernel->kernel_spec())); // Launch topk kernel with device memory arguments. se::KernelArgsDeviceMemoryArray arr( @@ -166,9 +164,8 @@ TEST_P(TopKKernelTest, TopKPackedNegative) { auto custom_kernel = GetTopKKernel("topk", PrimitiveType::F32, n, k, batch_size); - TF_ASSERT_OK_AND_ASSIGN( - auto kernel, - se::KernelFactory::Create(executor, custom_kernel->kernel_spec())); + TF_ASSERT_OK_AND_ASSIGN(auto kernel, + executor->LoadKernel(custom_kernel->kernel_spec())); // Launch topk kernel with device memory arguments. se::KernelArgsDeviceMemoryArray arr( diff --git a/xla/service/gpu/runtime/BUILD b/xla/service/gpu/runtime/BUILD index 3b2345c858b91e..af8e91f94752d4 100644 --- a/xla/service/gpu/runtime/BUILD +++ b/xla/service/gpu/runtime/BUILD @@ -87,7 +87,6 @@ cc_library( "//xla/stream_executor", "//xla/stream_executor:command_buffer", "//xla/stream_executor:dnn", - "//xla/stream_executor:kernel_factory", "//xla/stream_executor:lazy_op_runner", "//xla/stream_executor:trace_command_buffer_factory", "//xla/stream_executor/gpu:gpu_stream_header", @@ -757,7 +756,6 @@ cc_library( "//xla/service/gpu:stream_executor_util", "//xla/service/gpu/kernels:custom_kernel", "//xla/stream_executor", - "//xla/stream_executor:kernel_factory", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", diff --git a/xla/service/gpu/runtime/command_buffer_cmd.cc b/xla/service/gpu/runtime/command_buffer_cmd.cc index fcc38cd2d65514..5a05bb47d2b03b 100644 --- a/xla/service/gpu/runtime/command_buffer_cmd.cc +++ b/xla/service/gpu/runtime/command_buffer_cmd.cc @@ -64,7 +64,6 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/kernel_factory.h" #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/lazy_op_runner.h" #include "xla/stream_executor/stream.h" @@ -714,7 +713,7 @@ absl::Status CustomKernelLaunchCmd::Initialize( TF_ASSIGN_OR_RETURN( std::unique_ptr kernel, - se::KernelFactory::Create(params.executor, custom_kernel_.kernel_spec())); + params.executor->LoadKernel(custom_kernel_.kernel_spec())); absl::MutexLock lock(&mutex_); kernels_.emplace(params.executor, std::move(kernel)); diff --git a/xla/service/gpu/runtime/kernel_thunk.cc b/xla/service/gpu/runtime/kernel_thunk.cc index 3ea5a010658af4..a26de45ddaa853 100644 --- a/xla/service/gpu/runtime/kernel_thunk.cc +++ b/xla/service/gpu/runtime/kernel_thunk.cc @@ -37,7 +37,6 @@ limitations under the License. #include "xla/service/gpu/stream_executor_util.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/kernel_factory.h" #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/stream_executor.h" #include "tsl/platform/logging.h" @@ -187,9 +186,9 @@ absl::Status CustomKernelThunk::Initialize(const InitializeParams& params) { auto it = kernel_cache_.find(params.executor); if (kernel_cache_.end() == it) { - TF_ASSIGN_OR_RETURN(std::unique_ptr kernel, - se::KernelFactory::Create( - params.executor, custom_kernel_.kernel_spec())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr kernel, + params.executor->LoadKernel(custom_kernel_.kernel_spec())); kernel_cache_.emplace(params.executor, std::move(kernel)); } diff --git a/xla/service/gpu/stream_executor_util.cc b/xla/service/gpu/stream_executor_util.cc index cde9b554bd504d..8d4020f859c794 100644 --- a/xla/service/gpu/stream_executor_util.cc +++ b/xla/service/gpu/stream_executor_util.cc @@ -52,7 +52,6 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/kernel_factory.h" #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" @@ -377,7 +376,7 @@ absl::StatusOr> CreateKernel( } TF_ASSIGN_OR_RETURN(std::unique_ptr kernel, - se::KernelFactory::Create(stream_exec, loader_spec)); + stream_exec->LoadKernel(loader_spec)); se::KernelMetadata m; m.set_shared_memory_bytes(shared_mem_bytes); diff --git a/xla/stream_executor/BUILD b/xla/stream_executor/BUILD index 2c084fee8894d4..16beff87147f96 100644 --- a/xla/stream_executor/BUILD +++ b/xla/stream_executor/BUILD @@ -616,25 +616,11 @@ cc_library( ], ) -cc_library( - name = "kernel_factory", - hdrs = ["kernel_factory.h"], - deps = [ - ":kernel", - ":kernel_spec", - ":stream_executor_h", - "@com_google_absl//absl/status:statusor", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:statusor", - ], -) - cc_library( name = "typed_kernel_factory", hdrs = ["typed_kernel_factory.h"], deps = [ ":kernel", - ":kernel_factory", ":kernel_spec", ":stream_executor_h", "@com_google_absl//absl/status:statusor", diff --git a/xla/stream_executor/host/BUILD b/xla/stream_executor/host/BUILD index 4612a36429af5c..64152a8058400d 100644 --- a/xla/stream_executor/host/BUILD +++ b/xla/stream_executor/host/BUILD @@ -138,7 +138,6 @@ xla_cc_test( ":ptr_host_kernel_function", "//xla/stream_executor", "//xla/stream_executor:device_memory", - "//xla/stream_executor:kernel_factory", "//xla/stream_executor:kernel_spec", "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/functional:any_invocable", diff --git a/xla/stream_executor/host/host_kernel_test.cc b/xla/stream_executor/host/host_kernel_test.cc index 98157266a74eef..7f0a229f1c92a9 100644 --- a/xla/stream_executor/host/host_kernel_test.cc +++ b/xla/stream_executor/host/host_kernel_test.cc @@ -28,7 +28,6 @@ limitations under the License. #include "absl/types/span.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/host/host_kernel_c_api.h" -#include "xla/stream_executor/kernel_factory.h" #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" @@ -157,8 +156,7 @@ TEST(HostKernelTest, Addition3D) { TF_ASSERT_OK_AND_ASSIGN(auto executor, NewStreamExecutor()); TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); - TF_ASSERT_OK_AND_ASSIGN(auto add, - KernelFactory::Create(executor.get(), spec)); + TF_ASSERT_OK_AND_ASSIGN(auto add, executor->LoadKernel(spec)); const KernelArgsDeviceMemoryArray kargs{args, /*shared_memory_bytes=*/0}; TF_ASSERT_OK(stream->Launch(ThreadDim(2, 2, 3), BlockDim(1), *add, kargs)); @@ -184,8 +182,7 @@ TEST(HostKernelTest, JitAddition) { TF_ASSERT_OK_AND_ASSIGN(auto executor, NewStreamExecutor()); TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); - TF_ASSERT_OK_AND_ASSIGN(auto add, - KernelFactory::Create(executor.get(), spec)); + TF_ASSERT_OK_AND_ASSIGN(auto add, executor->LoadKernel(spec)); const KernelArgsDeviceMemoryArray kargs{args, /*shared_memory_bytes=*/0}; TF_ASSERT_OK(stream->Launch(ThreadDim(4), BlockDim(1), *add, kargs)); diff --git a/xla/stream_executor/kernel_factory.h b/xla/stream_executor/kernel_factory.h deleted file mode 100644 index 17e07cd0f97950..00000000000000 --- a/xla/stream_executor/kernel_factory.h +++ /dev/null @@ -1,40 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_STREAM_EXECUTOR_KERNEL_FACTORY_H_ -#define XLA_STREAM_EXECUTOR_KERNEL_FACTORY_H_ - -#include - -#include "absl/status/statusor.h" -#include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/kernel_spec.h" -#include "xla/stream_executor/stream_executor.h" - -namespace stream_executor { - -// Creates Kernels from kernel specifications. -class KernelFactory { - public: - // Creates kernel on a given executor from a given kernel specification. - static inline absl::StatusOr> Create( - StreamExecutor *executor, const MultiKernelLoaderSpec &spec) { - return executor->LoadKernel(spec); - } -}; - -} // namespace stream_executor - -#endif // XLA_STREAM_EXECUTOR_KERNEL_FACTORY_H_ diff --git a/xla/stream_executor/typed_kernel_factory.h b/xla/stream_executor/typed_kernel_factory.h index 21600d128ea758..65ed14883152e5 100644 --- a/xla/stream_executor/typed_kernel_factory.h +++ b/xla/stream_executor/typed_kernel_factory.h @@ -25,7 +25,6 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/kernel_factory.h" #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/stream_executor.h" #include "tsl/platform/statusor.h" @@ -41,7 +40,7 @@ class TypedKernelFactory { static absl::StatusOr> Create( StreamExecutor *executor, const MultiKernelLoaderSpec &spec) { TF_ASSIGN_OR_RETURN(std::unique_ptr kernel, - KernelFactory::Create(executor, spec)); + executor->LoadKernel(spec)); return TypedKernel(std::move(kernel)); } From a538b2dd0fbd0be853769682063456b121afca5f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 29 Jul 2024 11:43:38 -0700 Subject: [PATCH 240/376] [XLA] Fixing pattern_matcher include header warnings PiperOrigin-RevId: 657274833 --- xla/service/BUILD | 13 +++++++++++++ xla/service/pattern_matcher.h | 11 ++++++++++- xla/service/pattern_matcher_gmock.h | 3 +++ xla/service/pattern_matcher_gmock_test.cc | 8 ++++++++ xla/service/pattern_matcher_test.cc | 11 +++++++++++ 5 files changed, 45 insertions(+), 1 deletion(-) diff --git a/xla/service/BUILD b/xla/service/BUILD index b4a2419c0a016a..8b2b34fa88eed5 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -964,12 +964,17 @@ cc_library( hdrs = ["pattern_matcher.h"], deps = [ ":hlo_parser", + "//xla:comparison_util", + "//xla:literal", "//xla:shape_util", + "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/ir:ptrvec", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_absl//absl/utility", @@ -982,11 +987,16 @@ xla_cc_test( deps = [ ":hlo_parser", ":pattern_matcher", + "//xla:comparison_util", + "//xla:literal_util", + "//xla:shape_util", "//xla:test", + "//xla:util", "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/strings", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", ], ) @@ -997,7 +1007,9 @@ cc_library( hdrs = ["pattern_matcher_gmock.h"], deps = [ ":pattern_matcher", + "//xla:shape_util", "//xla:test", + "//xla/hlo/ir:hlo", "@tsl//tsl/platform:test", ], ) @@ -1010,6 +1022,7 @@ xla_cc_test( ":pattern_matcher_gmock", "//xla:shape_util", "//xla:test", + "//xla/hlo/ir:hlo", "//xla/tests:xla_internal_test_main", "@tsl//tsl/platform:test", ], diff --git a/xla/service/pattern_matcher.h b/xla/service/pattern_matcher.h index b17c53a9baf699..9b5a95381c3d86 100644 --- a/xla/service/pattern_matcher.h +++ b/xla/service/pattern_matcher.h @@ -16,34 +16,43 @@ limitations under the License. #ifndef XLA_SERVICE_PATTERN_MATCHER_H_ #define XLA_SERVICE_PATTERN_MATCHER_H_ -#include +#include +#include #include #include #include #include #include #include +#include #include #include #include #include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "absl/utility/utility.h" +#include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/ir/ptrvec.h" +#include "xla/layout.h" #include "xla/layout_util.h" +#include "xla/literal.h" #include "xla/service/hlo_parser.h" +#include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" namespace xla { diff --git a/xla/service/pattern_matcher_gmock.h b/xla/service/pattern_matcher_gmock.h index e183211d645d50..eeb7b1caabb4e1 100644 --- a/xla/service/pattern_matcher_gmock.h +++ b/xla/service/pattern_matcher_gmock.h @@ -18,7 +18,10 @@ limitations under the License. #include +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/layout.h" #include "xla/service/pattern_matcher.h" +#include "xla/shape.h" #include "xla/test.h" #include "tsl/platform/test.h" diff --git a/xla/service/pattern_matcher_gmock_test.cc b/xla/service/pattern_matcher_gmock_test.cc index 81cff291024fe8..c0a279537f686d 100644 --- a/xla/service/pattern_matcher_gmock_test.cc +++ b/xla/service/pattern_matcher_gmock_test.cc @@ -15,7 +15,15 @@ limitations under the License. #include "xla/service/pattern_matcher_gmock.h" +#include +#include +#include + +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/layout.h" +#include "xla/layout_util.h" #include "xla/service/pattern_matcher.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" #include "tsl/platform/test.h" diff --git a/xla/service/pattern_matcher_test.cc b/xla/service/pattern_matcher_test.cc index cd020c821b0c00..73da06ae7c1eea 100644 --- a/xla/service/pattern_matcher_test.cc +++ b/xla/service/pattern_matcher_test.cc @@ -15,14 +15,25 @@ limitations under the License. #include "xla/service/pattern_matcher.h" +#include +#include #include #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/layout.h" +#include "xla/layout_util.h" +#include "xla/literal_util.h" #include "xla/service/hlo_parser.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/util.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { From e6ef8a63b774c1fb385505b7375affc34a2cd08c Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Mon, 29 Jul 2024 13:01:00 -0700 Subject: [PATCH 241/376] [XLA:GPU] Remove wrong expectation from skipped Triton support test. This was not caught because the test is commented out, but we definitely don't intend to fulfill that expectation once we re-enable that test. PiperOrigin-RevId: 657298606 --- xla/service/gpu/fusions/triton/triton_support_test.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/xla/service/gpu/fusions/triton/triton_support_test.cc b/xla/service/gpu/fusions/triton/triton_support_test.cc index d93817af5efc6e..a47dd530cd228b 100644 --- a/xla/service/gpu/fusions/triton/triton_support_test.cc +++ b/xla/service/gpu/fusions/triton/triton_support_test.cc @@ -488,7 +488,6 @@ ENTRY triton_computation { TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); - EXPECT_TRUE(IsTritonSupportedInstruction(ti.Instruction(), cc)); RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}, cc); } From a517228de8d0ff179ff77c0ecb7acddb65ae24d6 Mon Sep 17 00:00:00 2001 From: Sara Smoot Date: Mon, 29 Jul 2024 13:21:24 -0700 Subject: [PATCH 242/376] [XLA:GPU] Rename the debug flag xla_gpu_enable_address_computation_fusion to xla_gpu_enable_dynamic_slice_fusion for consistency. "AddressComputation" is confusing, it simply fuses dynamic slice (and dynamic update slice) into other thunks via buffer assignment tricks PiperOrigin-RevId: 657304729 --- xla/debug_options_flags.cc | 9 +++--- .../gpu/dynamic_slice_fusion_rewriter_test.cc | 10 +++--- .../gpu/fusions/dynamic_slice_fusion_test.cc | 32 +++++++++---------- xla/service/gpu/gpu_compiler.cc | 2 +- xla/service/gpu/gpu_compiler_test.cc | 6 ++-- xla/xla.proto | 2 +- 6 files changed, 29 insertions(+), 32 deletions(-) diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc index c35ea757728c8c..1b9a00e7078add 100644 --- a/xla/debug_options_flags.cc +++ b/xla/debug_options_flags.cc @@ -144,7 +144,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_enable_dumping(true); opts.set_xla_gpu_enable_custom_fusions(false); - opts.set_xla_gpu_enable_address_computation_fusion(true); + opts.set_xla_gpu_enable_dynamic_slice_fusion(true); opts.set_xla_gpu_nccl_termination_timeout_seconds(-1); opts.set_xla_gpu_enable_shared_constants(true); opts.set_xla_gpu_enable_nccl_user_buffers(false); @@ -1298,10 +1298,9 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "expression. Default is all custom fusions registerered in a current " "process.")); flag_list->push_back(tsl::Flag( - "xla_gpu_enable_address_computation_fusion", - bool_setter_for( - &DebugOptions::set_xla_gpu_enable_address_computation_fusion), - debug_options->xla_gpu_enable_address_computation_fusion(), + "xla_gpu_enable_dynamic_slice_fusion", + bool_setter_for(&DebugOptions::set_xla_gpu_enable_dynamic_slice_fusion), + debug_options->xla_gpu_enable_dynamic_slice_fusion(), "Whether to enable XLA address computation fusion")); flag_list->push_back(tsl::Flag( "xla_gpu_nccl_termination_timeout_seconds", diff --git a/xla/service/gpu/dynamic_slice_fusion_rewriter_test.cc b/xla/service/gpu/dynamic_slice_fusion_rewriter_test.cc index 3d3eef1e4a3687..a539fb5e6ca5cd 100644 --- a/xla/service/gpu/dynamic_slice_fusion_rewriter_test.cc +++ b/xla/service/gpu/dynamic_slice_fusion_rewriter_test.cc @@ -942,7 +942,7 @@ TEST_F(DynamicSliceFusionRewriterTest, SimpleCustomCall) { xla::ProgramShape(computation.proto().host_program_shape()), /*ignore_layouts=*/false); DebugOptions debug_options = GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_address_computation_fusion(false); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); @@ -990,7 +990,7 @@ TEST_F(DynamicSliceFusionRewriterTest, SimpleCustomCallLegacy) { xla::ProgramShape(computation.proto().host_program_shape()), /*ignore_layouts=*/false); DebugOptions debug_options = GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_address_computation_fusion(false); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); @@ -1050,7 +1050,7 @@ TEST_F(DynamicSliceFusionRewriterTest, TupleSliceCustomCallLegacy) { xla::ProgramShape(computation.proto().host_program_shape()), /*ignore_layouts=*/false); DebugOptions debug_options = GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_address_computation_fusion(false); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); @@ -1122,7 +1122,7 @@ TEST_F(DynamicSliceFusionRewriterTest, TupledOutputCustomCallLegacy) { xla::ProgramShape(computation.proto().host_program_shape()), /*ignore_layouts=*/false); DebugOptions debug_options = GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_address_computation_fusion(false); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); @@ -1183,7 +1183,7 @@ TEST_F(DynamicSliceFusionRewriterTest, UnalignedSlice) { xla::ProgramShape(computation.proto().host_program_shape()), /*ignore_layouts=*/false); DebugOptions debug_options = GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_address_computation_fusion(false); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); diff --git a/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc b/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc index 954d4b656acb0c..f53dc13077729e 100644 --- a/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc +++ b/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc @@ -867,7 +867,7 @@ TEST_F(DynamicSliceFusionTest, CustomCallSimple) { xla::ProgramShape(computation.proto().host_program_shape()), /*ignore_layouts=*/false); DebugOptions debug_options = GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_address_computation_fusion(false); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); @@ -1010,12 +1010,12 @@ TEST_F(DynamicSliceFusionTest, CustomCallWithTuple) { xla::ProgramShape(computation.proto().host_program_shape()), /*ignore_layouts=*/true); DebugOptions debug_options = GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_address_computation_fusion(false); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); - debug_options.set_xla_gpu_enable_address_computation_fusion(true); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(true); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); @@ -1059,12 +1059,12 @@ TEST_F(DynamicSliceFusionTest, NilTuple) { xla::ProgramShape(computation.proto().host_program_shape()), /*ignore_layouts=*/false); DebugOptions debug_options = GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_address_computation_fusion(false); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); - debug_options.set_xla_gpu_enable_address_computation_fusion(true); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(true); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); @@ -1103,12 +1103,12 @@ TEST_F(DynamicSliceFusionTest, CustomCallLegacyAPI) { xla::ProgramShape(computation.proto().host_program_shape()), /*ignore_layouts=*/false); DebugOptions debug_options = GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_address_computation_fusion(false); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); - debug_options.set_xla_gpu_enable_address_computation_fusion(true); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(true); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); @@ -1141,12 +1141,12 @@ TEST_F(DynamicSliceFusionTest, NilTupleLegacyAPI) { xla::ProgramShape(computation.proto().host_program_shape()), /*ignore_layouts=*/false); DebugOptions debug_options = GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_address_computation_fusion(false); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); - debug_options.set_xla_gpu_enable_address_computation_fusion(true); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(true); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); @@ -2460,7 +2460,7 @@ TEST_F(DynamicSliceFusionTest, DynamicCustomCallSimple) { xla::ProgramShape(computation.proto().host_program_shape()), /*ignore_layouts=*/false); DebugOptions debug_options = GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_address_computation_fusion(false); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); @@ -2532,12 +2532,12 @@ TEST_F(DynamicSliceFusionTest, DynamicCustomCallWithTuple) { xla::ProgramShape(computation.proto().host_program_shape()), /*ignore_layouts=*/true); DebugOptions debug_options = GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_address_computation_fusion(false); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); - debug_options.set_xla_gpu_enable_address_computation_fusion(true); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(true); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); @@ -2639,12 +2639,12 @@ TEST_F(DynamicSliceFusionTest, CustomCallDUS) { xla::ProgramShape(computation.proto().host_program_shape()), /*ignore_layouts=*/false); DebugOptions debug_options = GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_address_computation_fusion(false); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); - debug_options.set_xla_gpu_enable_address_computation_fusion(true); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(true); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); @@ -2735,12 +2735,12 @@ TEST_F(DynamicSliceFusionTest, CustomCallDUSTuple) { xla::ProgramShape(computation.proto().host_program_shape()), /*ignore_layouts=*/false); DebugOptions debug_options = GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_address_computation_fusion(false); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); - debug_options.set_xla_gpu_enable_address_computation_fusion(true); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(true); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index 15debb4020b854..73c3556019d6c2 100644 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -1255,7 +1255,7 @@ absl::Status GpuCompiler::OptimizeHloModule( // This is a "low effort, high impact" fusion that should be run first. if (hlo_module->config() .debug_options() - .xla_gpu_enable_address_computation_fusion()) { + .xla_gpu_enable_dynamic_slice_fusion()) { HloPassPipeline pipeline("dynamic-slice"); TF_ASSIGN_OR_RETURN(se::Platform * platform, se::PlatformManager::PlatformWithId(PlatformId())); diff --git a/xla/service/gpu/gpu_compiler_test.cc b/xla/service/gpu/gpu_compiler_test.cc index d95b411c7ac700..7e9e3a419890d1 100644 --- a/xla/service/gpu/gpu_compiler_test.cc +++ b/xla/service/gpu/gpu_compiler_test.cc @@ -428,8 +428,7 @@ ENTRY main { HloModuleConfig config; DebugOptions triton_enabled_debug_options = GetDebugOptionsForTest(); - triton_enabled_debug_options.set_xla_gpu_enable_address_computation_fusion( - false); + triton_enabled_debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false); triton_enabled_debug_options .set_xla_gpu_require_complete_aot_autotune_results(true); config.set_debug_options(triton_enabled_debug_options); @@ -448,8 +447,7 @@ ENTRY main { GetOptimizedModule(std::move(module))); AutotunerUtil::ClearAutotuneResults(); DebugOptions triton_disabled_debug_options = GetDebugOptionsForTest(); - triton_disabled_debug_options.set_xla_gpu_enable_address_computation_fusion( - false); + triton_disabled_debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false); triton_disabled_debug_options.set_xla_gpu_enable_triton_gemm(false); config.set_debug_options(triton_disabled_debug_options); TF_ASSERT_OK_AND_ASSIGN(module, diff --git a/xla/xla.proto b/xla/xla.proto index 35e795b1680df8..c5e81a932147ec 100644 --- a/xla/xla.proto +++ b/xla/xla.proto @@ -476,7 +476,7 @@ message DebugOptions { // Enables address computation fusion to optimize dynamic-slice and // dynamic-update-slice operations around library calls. - bool xla_gpu_enable_address_computation_fusion = 105; + bool xla_gpu_enable_dynamic_slice_fusion = 105; reserved 233; // was xla_gpu_enable_gpu2_runtime reserved 234; // was xla_gpu_enable_gpu2_hal From 65923bd8d9a6b0ac47bc4480c79af115a912d2b4 Mon Sep 17 00:00:00 2001 From: pizzud Date: Mon, 29 Jul 2024 13:42:09 -0700 Subject: [PATCH 243/376] [NFC][xla_compile] Only load autotune results when autotuning is enabled. Not all autotuning operations respect xla_gpu_autotune_level if autotune results have been loaded, for reasons that have proven elusive. PiperOrigin-RevId: 657312079 --- xla/service/gpu/BUILD | 2 + xla/service/xla_compile_main.cc | 5 +- xla/tools/BUILD | 3 + xla/tools/xla_compile_lib.cc | 35 ++++++---- xla/tools/xla_compile_lib.h | 9 +++ xla/tools/xla_compile_lib_test.cc | 106 ++++++++++++++++++++++++++++++ 6 files changed, 144 insertions(+), 16 deletions(-) diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 02810f9c61ece8..a8b7d3a18d6fb1 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -65,6 +65,8 @@ filegroup( ]), ) +exports_files(srcs = ["gpu_compiler_test_autotune_db.textproto"]) + tf_proto_library( name = "backend_configs", srcs = ["backend_configs.proto"], diff --git a/xla/service/xla_compile_main.cc b/xla/service/xla_compile_main.cc index 1e607a0b674c3b..0e217d53cb7343 100644 --- a/xla/service/xla_compile_main.cc +++ b/xla/service/xla_compile_main.cc @@ -71,7 +71,7 @@ int main(int argc, char* argv[]) { "an attached GPU will be used."), tsl::Flag("autotune_results", &options.gpu_options.autotune_results_path, "The path to AutotuneResults, optional when compiling for" - " GPU"), + " GPU. Only used if autotuning is enabled in XLA_FLAGS."), tsl::Flag("symbol_repo", &options.repo_options.symbol_repo, "Which SymbolRepository to look up --symbol_reference in. If " "the repository contains a GpuTargetConfig, " @@ -83,7 +83,8 @@ int main(int argc, char* argv[]) { "optimized_symbol_reference", &options.repo_options.optimized_symbol_id, "Optimized symbol ID to look up in a SymbolRepository. Overrides " - "--autotune_results_path."), + "--autotune_results_path. Any autotuning results that are present " + "will be used as long as autotuning is enabled in XLA_FLAGS."), tsl::Flag("use_attached_device", &options.gpu_options.use_attached_device, "Whether to use the attached GPU or not. Overrides the " diff --git a/xla/tools/BUILD b/xla/tools/BUILD index 32c2e1c4b6a294..0bbc059f110b0c 100644 --- a/xla/tools/BUILD +++ b/xla/tools/BUILD @@ -819,6 +819,7 @@ xla_test( data = [ ":data/add.hlo", "//xla/service:xla_aot_compile_test_gpu_target_config.prototxt", + "//xla/service/gpu:gpu_compiler_test_autotune_db.textproto", ], local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", @@ -830,6 +831,8 @@ xla_test( "//xla/service:platform_util", "//xla/service:symbol_repository", "//xla/service:xla_compile_result_proto_cc_impl", + "//xla/service/gpu:autotuner_util", + "//xla/service/gpu:gpu_symbol_repository", "//xla/stream_executor:device_description_proto_cc", "//xla/tests:hlo_test_base", "//xla/tests:test_macros_header", diff --git a/xla/tools/xla_compile_lib.cc b/xla/tools/xla_compile_lib.cc index ce195d1a925bbd..fe4289b0e1c7f2 100644 --- a/xla/tools/xla_compile_lib.cc +++ b/xla/tools/xla_compile_lib.cc @@ -232,37 +232,43 @@ ReadModuleFromSymbolRepo(absl::string_view symbol_repo, return mod; } -static absl::StatusOr LoadAutotuneDataFromModule( +static std::unique_ptr ReadTargetConfigFromModule( HloModuleAndMetadata* mod, BackendType backend) { if (backend == BackendType::kGpu) { #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM if (auto* data = static_cast( mod->backend_specific_data.get()); - data != nullptr && data->autotune_results.has_value()) { - TF_RETURN_IF_ERROR( - gpu::AutotunerUtil::LoadAutotuneResults(*data->autotune_results)); - return true; + data != nullptr) { + return std::move(mod->target_config); } #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } - return false; + + return nullptr; } -static std::unique_ptr ReadTargetConfigFromModule( - HloModuleAndMetadata* mod, BackendType backend) { +namespace internal { + +absl::StatusOr LoadAutotuneDataFromModule(HloModuleAndMetadata* mod, + BackendType backend) { if (backend == BackendType::kGpu) { #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM if (auto* data = static_cast( mod->backend_specific_data.get()); - data != nullptr) { - return std::move(mod->target_config); + data != nullptr && data->autotune_results.has_value() && + mod->hlo_module->config().debug_options().xla_gpu_autotune_level() > + 0) { + TF_RETURN_IF_ERROR( + gpu::AutotunerUtil::LoadAutotuneResults(*data->autotune_results)); + return true; } #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } - - return nullptr; + return false; } +} // namespace internal + absl::Status XlaCompileMain(const XlaCompileOptions& options) { std::unique_ptr hlo_module; std::unique_ptr target_config; @@ -299,7 +305,7 @@ absl::Status XlaCompileMain(const XlaCompileOptions& options) { ReadModuleFromSymbolRepo(symbol_repo, optimized_symbol_id, backend)); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - TF_ASSIGN_OR_RETURN(found_autotune, LoadAutotuneDataFromModule( + TF_ASSIGN_OR_RETURN(found_autotune, internal::LoadAutotuneDataFromModule( optimized_mod.get(), backend)); #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } @@ -340,7 +346,8 @@ absl::Status XlaCompileMain(const XlaCompileOptions& options) { if (absl::string_view autotune_results_path = options.gpu_options.autotune_results_path; - !found_autotune && !autotune_results_path.empty()) { + !found_autotune && !autotune_results_path.empty() && + hlo_module->config().debug_options().xla_gpu_autotune_level() > 0) { TF_RETURN_IF_ERROR(gpu::AutotunerUtil::LoadAutotuneResultsFromFile( autotune_results_path)); } diff --git a/xla/tools/xla_compile_lib.h b/xla/tools/xla_compile_lib.h index 8d4f9e0dae8e01..3892f156c913be 100644 --- a/xla/tools/xla_compile_lib.h +++ b/xla/tools/xla_compile_lib.h @@ -84,6 +84,15 @@ struct XlaCompileOptions { // correspond to fields in XlaCompileOptions. absl::Status XlaCompileMain(const XlaCompileOptions& compile_options); +namespace internal { + +// Loads autotuning data if autotuning is enabled and autotuning results are +// present. Returns true if data was present and successfully loaded, false +// otherwise. +absl::StatusOr LoadAutotuneDataFromModule(HloModuleAndMetadata* mod, + BackendType backend); + +} // namespace internal } // namespace xla #endif // XLA_TOOLS_XLA_COMPILE_LIB_H_ diff --git a/xla/tools/xla_compile_lib_test.cc b/xla/tools/xla_compile_lib_test.cc index 6bf9051f221c83..d9586238dd2247 100644 --- a/xla/tools/xla_compile_lib_test.cc +++ b/xla/tools/xla_compile_lib_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/gpu_symbol_repository.h" #include "xla/service/platform_util.h" #include "xla/service/symbol_repository.h" #include "xla/service/xla_compile_result.pb.h" @@ -43,6 +44,10 @@ limitations under the License. #include "tsl/protobuf/error_codes.pb.h" #include "tsl/protobuf/status.pb.h" +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#include "xla/service/gpu/autotuner_util.h" +#endif + namespace xla { namespace { @@ -216,5 +221,106 @@ TEST_F(XlaCompileLibTest, DISABLED_ON_CPU(MainForGpu)) { EXPECT_EQ(result.status().code(), tensorflow::error::OK); } +TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(LoadAutotuneDataCpu)) { + HloModuleAndMetadata mod; + mod.hlo_module = std::move(module_); + + EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kCpu), + IsOkAndHolds(false)); +} + +TEST_F(XlaCompileLibTest, + DISABLED_ON_CPU(LoadAutotuneDataGpuDataPresentAndAutotuningEnabled)) { +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + gpu::AutotunerUtil::ClearAutotuneResults(); + + HloModuleAndMetadata mod; + mod.hlo_module = std::move(module_); + auto data = std::make_unique(); + + AutotuneResults autotune_results; + TF_ASSERT_OK(tsl::ReadTextProto( + tsl::Env::Default(), + tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", "gpu", + "gpu_compiler_test_autotune_db.textproto"), + &autotune_results)); + data->autotune_results = autotune_results; + mod.backend_specific_data = std::move(data); + + DebugOptions opts = mod.hlo_module->config().debug_options(); + opts.set_xla_gpu_autotune_level(3); + mod.hlo_module->mutable_config().set_debug_options(opts); + + EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kGpu), + IsOkAndHolds(true)); + EXPECT_FALSE(gpu::AutotunerUtil::ResultCacheIsEmpty()); +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +} + +TEST_F(XlaCompileLibTest, + DISABLED_ON_CPU(LoadAutotuneDataGpuDataPresentAndAutotuningDisabled)) { +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + gpu::AutotunerUtil::ClearAutotuneResults(); + + HloModuleAndMetadata mod; + mod.hlo_module = std::move(module_); + auto data = std::make_unique(); + + AutotuneResults autotune_results; + TF_ASSERT_OK(tsl::ReadTextProto( + tsl::Env::Default(), + tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", "gpu", + "gpu_compiler_test_autotune_db.textproto"), + &autotune_results)); + data->autotune_results = autotune_results; + mod.backend_specific_data = std::move(data); + + DebugOptions opts = mod.hlo_module->config().debug_options(); + opts.set_xla_gpu_autotune_level(0); + mod.hlo_module->mutable_config().set_debug_options(opts); + + EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kGpu), + IsOkAndHolds(false)); + EXPECT_TRUE(gpu::AutotunerUtil::ResultCacheIsEmpty()); +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +} + +TEST_F(XlaCompileLibTest, + DISABLED_ON_CPU(LoadAutotuneDataGpuDataNotPresentAndAutotuningEnabled)) { +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + gpu::AutotunerUtil::ClearAutotuneResults(); + + HloModuleAndMetadata mod; + mod.hlo_module = std::move(module_); + + DebugOptions opts = mod.hlo_module->config().debug_options(); + opts.set_xla_gpu_autotune_level(3); + mod.hlo_module->mutable_config().set_debug_options(opts); + + EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kGpu), + IsOkAndHolds(false)); + EXPECT_TRUE(gpu::AutotunerUtil::ResultCacheIsEmpty()); +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +} + +TEST_F( + XlaCompileLibTest, + DISABLED_ON_CPU(LoadAutotuneDataGpuDataNotPresentAndAutotuningDisabled)) { +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + gpu::AutotunerUtil::ClearAutotuneResults(); + + HloModuleAndMetadata mod; + mod.hlo_module = std::move(module_); + + DebugOptions opts = mod.hlo_module->config().debug_options(); + opts.set_xla_gpu_autotune_level(0); + mod.hlo_module->mutable_config().set_debug_options(opts); + + EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kGpu), + IsOkAndHolds(false)); + EXPECT_TRUE(gpu::AutotunerUtil::ResultCacheIsEmpty()); +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +} + } // namespace } // namespace xla From ffd724e235ea674825ed2f7c22538804e263ca4c Mon Sep 17 00:00:00 2001 From: Fangrui Song Date: Mon, 29 Jul 2024 15:02:28 -0700 Subject: [PATCH 244/376] Integrate LLVM at llvm/llvm-project@63e1647827f3 Updates LLVM usage to match [63e1647827f3](https://github.com/llvm/llvm-project/commit/63e1647827f3) PiperOrigin-RevId: 657337524 --- third_party/llvm/workspace.bzl | 4 ++-- third_party/shardy/workspace.bzl | 4 ++-- third_party/tsl/third_party/llvm/workspace.bzl | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 5e0e665f2d4238..508b58c2033e18 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "4ce3993ee2b6ee883ef62100df68db9e10ef1dc9" - LLVM_SHA256 = "f0ab7ef30dfad130ce5b7421d4ea33decb9027561a3e944c54be115b81bfe64d" + LLVM_COMMIT = "63e1647827f3427c5f3ad37461d84a63ba5fcdaf" + LLVM_SHA256 = "1d2d4cee6b7c5f558635d16ea97e14c688071f4cc7c6f29cddd23420817263f2" tf_http_archive( name = name, diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index 090e5cd279fda4..8f4cba75927cd4 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "0419e7d4717291ccdcdd81f404613bea5a0c12ba" - SHARDY_SHA256 = "39a8c62e95eea71d6afb3b24e77253a82d12864085335c048c27e854568cff4f" + SHARDY_COMMIT = "4ef678c17a82b55c03028d0e76df647c5c6be471" + SHARDY_SHA256 = "9e771b112753a406c81d9cff4828be4e3ee8cf302ab2b7ebee0b381d9868b194" tf_http_archive( name = "shardy", diff --git a/third_party/tsl/third_party/llvm/workspace.bzl b/third_party/tsl/third_party/llvm/workspace.bzl index 5e0e665f2d4238..508b58c2033e18 100644 --- a/third_party/tsl/third_party/llvm/workspace.bzl +++ b/third_party/tsl/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "4ce3993ee2b6ee883ef62100df68db9e10ef1dc9" - LLVM_SHA256 = "f0ab7ef30dfad130ce5b7421d4ea33decb9027561a3e944c54be115b81bfe64d" + LLVM_COMMIT = "63e1647827f3427c5f3ad37461d84a63ba5fcdaf" + LLVM_SHA256 = "1d2d4cee6b7c5f558635d16ea97e14c688071f4cc7c6f29cddd23420817263f2" tf_http_archive( name = name, From 79a1522854d6b581666ea83a368ebebc7337f9d4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 29 Jul 2024 16:02:58 -0700 Subject: [PATCH 245/376] Add matmul test case to collective_permute_decomposer test PiperOrigin-RevId: 657356595 --- xla/hlo/utils/BUILD | 1 + xla/hlo/utils/hlo_query.cc | 18 +++ xla/hlo/utils/hlo_query.h | 8 ++ xla/service/BUILD | 1 + .../collective_permute_decomposer_test.cc | 118 ++++++++++++++++++ ...ollective_permute_cycle_decomposer_test.cc | 13 +- .../collective_pipeline_parallelism_test.cc | 10 +- 7 files changed, 161 insertions(+), 8 deletions(-) diff --git a/xla/hlo/utils/BUILD b/xla/hlo/utils/BUILD index e6f54bb0f623f0..94a0f473f8a0db 100644 --- a/xla/hlo/utils/BUILD +++ b/xla/hlo/utils/BUILD @@ -158,6 +158,7 @@ cc_library( "//xla/service:pattern_matcher", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings:string_view", ], ) diff --git a/xla/hlo/utils/hlo_query.cc b/xla/hlo/utils/hlo_query.cc index 85e41fff68a149..69a6fef79857fe 100644 --- a/xla/hlo/utils/hlo_query.cc +++ b/xla/hlo/utils/hlo_query.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -268,5 +269,22 @@ HloInstruction* GetUniqueGteInstruction(const HloInstruction* operand, return gte; } +bool IsBeforeInComputation(const HloComputation* computation, + absl::string_view inst1, absl::string_view inst2) { + int index1 = -1; + int index2 = -1; + int current_index = 0; + for (auto instruction : computation->instructions()) { + if (instruction->name() == inst1) { + index1 = current_index; + } + if (instruction->name() == inst2) { + index2 = current_index; + } + current_index++; + } + current_index++; + return index1 < index2; +} } // namespace hlo_query } // namespace xla diff --git a/xla/hlo/utils/hlo_query.h b/xla/hlo/utils/hlo_query.h index cda265362d452b..950082accf14f0 100644 --- a/xla/hlo/utils/hlo_query.h +++ b/xla/hlo/utils/hlo_query.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -152,6 +153,13 @@ bool HasX64TransformedHostTransfer(const HloModule& module); HloInstruction* GetUniqueGteInstruction(const HloInstruction* operand, int64_t index); +// TODO: b/356153995 - refactor hlo_test_base +// Check that one instruction comes before another one for a given computation. +// The function returns true if the first instruction comes before the second +// one, and false otherwise. This is useful for partial checks on the +// transformed IR without going through a full file check. +bool IsBeforeInComputation(const HloComputation* computation, + absl::string_view inst1, absl::string_view inst2); } // namespace hlo_query } // namespace xla diff --git a/xla/service/BUILD b/xla/service/BUILD index 8b2b34fa88eed5..6765f1c1ff9ae4 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -501,6 +501,7 @@ xla_cc_test( ":hlo_parser", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", + "//xla/hlo/utils:hlo_query", "//xla/service/gpu:backend_configs_cc", "//xla/tests:hlo_test_base", "@com_google_googletest//:gtest", diff --git a/xla/service/collective_permute_decomposer_test.cc b/xla/service/collective_permute_decomposer_test.cc index b80a52b51e9f1a..8de403e030def4 100644 --- a/xla/service/collective_permute_decomposer_test.cc +++ b/xla/service/collective_permute_decomposer_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/utils/hlo_matchers.h" +#include "xla/hlo/utils/hlo_query.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/hlo_parser.h" @@ -315,6 +316,123 @@ TEST_F(CollectivePermuteDecomposerTest, ForwardPipeline2) { HasSubstr("_xla_send_recv_pipeline=\"1\"")); } +TEST_F(CollectivePermuteDecomposerTest, ForwardPipelineWithMatmul) { + // The HLO module below is generated by passing the HLO in + // CollectiveOpsTest.CollectivePermute_CircularPipelinePreOptimization through + // the collective_permute_cycle_decomposer.transformation. + const char* const kModuleStr = R"( + HloModule test + + while_body { + inputs = (u32[], f32[2,2], f32[2,2]) parameter(0) + iter = u32[] get-tuple-element(inputs), index=0 + iter_increment = u32[] constant(1) + next_iter = u32[] add(iter, iter_increment) + partition-id = u32[] partition-id() + zero = u32[] constant(0) + compare = pred[] compare(partition-id, zero), direction=EQ + broadcast = pred[2,2] broadcast(compare), dimensions={} + + weights = f32[2,2] get-tuple-element(inputs), index=2 + data = f32[2,2] get-tuple-element(inputs), index=1 + + cp_back = f32[2,2] collective-permute(data), channel_id=1, + source_target_pairs={{3,0}}, + frontend_attributes={_xla_send_recv_validation="{{3,10}}"} + cp_forward = f32[2,2] collective-permute(data), channel_id=2, + source_target_pairs={{0,1},{1,2},{2,3}}, + frontend_attributes={_xla_send_recv_validation="{{0,7},{1,8},{2,9}}"} + + select = f32[2,2] select(broadcast, cp_back, cp_forward) + + matmul = f32[2,2] dot(weights, select), lhs_contracting_dims={1}, rhs_contracting_dims={0} + + ROOT result = (u32[], f32[2,2], f32[2,2]) tuple(next_iter, matmul, weights) + } + + while_cond { + inputs = (u32[], f32[2,2], f32[2,2]) parameter(0) + iter = u32[] get-tuple-element(inputs), index=0 + max_iter = u32[] constant(3) + ROOT compare = pred[] compare(iter, max_iter), direction=LT + } + + ENTRY test_computation { + start_iter = u32[] constant(0) + input_data = f32[2,2] parameter(0) + input_weights = f32[2,2] parameter(1) + input = (u32[], f32[2,2], f32[2,2]) tuple(start_iter, input_data, input_weights) + while_result = (u32[], f32[2,2], f32[2,2]) while(input), condition=while_cond, body=while_body + ROOT data_out = f32[2,2] get-tuple-element(while_result), index=1 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((kModuleStr))); + CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/0); + TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); + EXPECT_TRUE(changed); + HloModule* transformed_module = module.get(); + // Check the annotations and ordering of the decomposed send-recv pairs. + // We expect the recv to come before the send in the while body, both for the + // forward edge ({0,1},{1,2},{2,3}}) and the backward edge ({3,0}). This is + // an XLA invariant that shouldn't be broken (see + // https://openxla.org/xla/operation_semantics#send for details of the + // semantics). + HloInstruction* recv_bwd = FindInstruction(transformed_module, "recv"); + EXPECT_EQ(recv_bwd->channel_id().value(), 1); + auto recv_bwd_frontend_attributes = recv_bwd->frontend_attributes().map(); + EXPECT_EQ(recv_bwd_frontend_attributes.size(), 3); + EXPECT_EQ(recv_bwd_frontend_attributes.at(kSendRecvValidationAttr), + "{{3,10}}"); + EXPECT_EQ(recv_bwd_frontend_attributes.at(kSendRecvPipelineAttr), "0"); + EXPECT_EQ(recv_bwd_frontend_attributes.at(kSendRecvSourceTargetPairsAttr), + "{{3,0}}"); + + HloInstruction* send_bwd = FindInstruction(transformed_module, "send"); + auto send_bwd_frontend_attributes = send_bwd->frontend_attributes().map(); + EXPECT_THAT(send_bwd_frontend_attributes.at(kSendRecvSourceTargetPairsAttr), + "{{3,0}}"); + + HloInstruction* recv_fwd = FindInstruction(transformed_module, "recv.1"); + EXPECT_EQ(recv_fwd->channel_id().value(), 2); + auto recv_fwd_frontend_attributes = recv_fwd->frontend_attributes().map(); + EXPECT_EQ(recv_fwd_frontend_attributes.size(), 3); + EXPECT_EQ(recv_fwd_frontend_attributes.at(kSendRecvPipelineAttr), "1"); + EXPECT_EQ(recv_fwd_frontend_attributes.at(kSendRecvSourceTargetPairsAttr), + "{{0,1},{1,2},{2,3}}"); + + HloInstruction* send_fwd = FindInstruction(transformed_module, "send.1"); + auto send_fwd_frontend_attributes = send_fwd->frontend_attributes().map(); + EXPECT_EQ(send_fwd_frontend_attributes.size(), 3); + EXPECT_EQ(send_fwd_frontend_attributes.at(kSendRecvPipelineAttr), "1"); + EXPECT_EQ(send_fwd_frontend_attributes.at(kSendRecvSourceTargetPairsAttr), + "{{0,1},{1,2},{2,3}}"); + + HloComputation* while_body = + FindComputation(transformed_module, "while_body"); + EXPECT_NE(while_body, nullptr); + EXPECT_TRUE(hlo_query::IsBeforeInComputation(while_body, "recv", "send")); + EXPECT_TRUE( + hlo_query::IsBeforeInComputation(while_body, "recv", "recv-done")); + EXPECT_TRUE( + hlo_query::IsBeforeInComputation(while_body, "send", "recv-done")); + EXPECT_TRUE( + hlo_query::IsBeforeInComputation(while_body, "send", "send-done")); + EXPECT_TRUE( + hlo_query::IsBeforeInComputation(while_body, "send-done", "send-done.1")); + EXPECT_TRUE( + hlo_query::IsBeforeInComputation(while_body, "recv-done", "send-done.1")); + EXPECT_TRUE(hlo_query::IsBeforeInComputation(while_body, "recv-done.1", + "send-done.1")); + auto recv_done_fwd = FindInstruction(transformed_module, "recv-done"); + auto recv_done_bwd = FindInstruction(transformed_module, "recv-done.1"); + + // TODO: b/356201477 - Investigate potential NCCL deadlock in + // collective_permute_decomposer + EXPECT_EQ(recv_done_fwd->control_predecessors()[0], send_bwd); + EXPECT_EQ(recv_done_bwd->control_predecessors()[0], send_fwd); +} + TEST_F(CollectivePermuteDecomposerTest, BackwardPipeline2) { const char* const kModuleStr = R"( HloModule module diff --git a/xla/service/gpu/collective_permute_cycle_decomposer_test.cc b/xla/service/gpu/collective_permute_cycle_decomposer_test.cc index 7f297ad1e615f1..19436eeded5f42 100644 --- a/xla/service/gpu/collective_permute_cycle_decomposer_test.cc +++ b/xla/service/gpu/collective_permute_cycle_decomposer_test.cc @@ -150,11 +150,14 @@ TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleWithMatmul) { iter = u32[] get-tuple-element(param), index=0 data = f32[2,2] get-tuple-element(param), index=1 weights = f32[2,2] get-tuple-element(param), index=2 - matmul = f32[2,2] dot(weights, data), lhs_contracting_dims={1}, rhs_contracting_dims={0} - cp = f32[2,2] collective-permute(matmul), channel_id=1, source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}} + cp = f32[2,2] collective-permute(data), + channel_id=1, + source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}}, + frontend_attributes={_xla_send_recv_validation="{{0,7},{1,8},{2,9},{3,10}}"} + matmul = f32[2,2] dot(weights, cp), lhs_contracting_dims={1}, rhs_contracting_dims={0} iter_increment = u32[] constant(1) next_iter = u32[] add(iter, iter_increment) - ROOT result = (u32[], f32[2,2], f32[2,2]) tuple(next_iter, cp, weights) + ROOT result = (u32[], f32[2,2], f32[2,2]) tuple(next_iter, matmul, weights) } ENTRY test_computation { @@ -178,8 +181,12 @@ TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleWithMatmul) { DynCast( FindInstruction(module.get(), "collective-permute.1")); EXPECT_THAT(cp1->ToString(), HasSubstr("source_target_pairs={{3,0}}")); + EXPECT_THAT(cp1->ToString(), + HasSubstr("_xla_send_recv_validation=\"{{3,10}}\"")); EXPECT_THAT(cp2->ToString(), HasSubstr("source_target_pairs={{0,1},{1,2},{2,3}}")); + EXPECT_THAT(cp2->ToString(), + HasSubstr("_xla_send_recv_validation=\"{{0,7},{1,8},{2,9}}\"")); } TEST_F(CollectivePermuteCycleDecomposerTest, BackwardCycle) { diff --git a/xla/tests/collective_pipeline_parallelism_test.cc b/xla/tests/collective_pipeline_parallelism_test.cc index a88b1000f4737a..48641e2c17cc52 100644 --- a/xla/tests/collective_pipeline_parallelism_test.cc +++ b/xla/tests/collective_pipeline_parallelism_test.cc @@ -73,13 +73,13 @@ XLA_TEST_F(CollectivePipelineParallelismTest, iter = u32[] get-tuple-element(param), index=0 data = f32[2,2] get-tuple-element(param), index=1 weights = f32[2,2] get-tuple-element(param), index=2 - matmul = f32[2,2] dot(weights, data), lhs_contracting_dims={1}, - rhs_contracting_dims={0} - cp = f32[2,2] collective-permute(matmul), - source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}} + cp = f32[2,2] collective-permute(data), + source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}} + matmul = f32[2,2] dot(weights, cp), + lhs_contracting_dims={1}, rhs_contracting_dims={0} iter_increment = u32[] constant(1) next_iter = u32[] add(iter, iter_increment) - ROOT result = (u32[], f32[2,2], f32[2,2]) tuple(next_iter, cp, weights) + ROOT result = (u32[], f32[2,2], f32[2,2]) tuple(next_iter, matmul, weights) } ENTRY test_computation { From 4228782cc8ecaa8411e988bae08aa9251507e8e9 Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Mon, 29 Jul 2024 16:10:52 -0700 Subject: [PATCH 246/376] [xla] Replace debug option xla_use_shardy with execution option use_shardy_partitioner. Replace the use of xla_use_shardy with use_shardy_partitioner and remove xla_use_shardy. PiperOrigin-RevId: 657359119 --- xla/debug_options_flags.cc | 5 ----- xla/pjrt/cpu/cpu_client.cc | 5 +---- xla/pjrt/pjrt_stream_executor_client.cc | 5 +---- xla/python/xla_client.py | 2 +- xla/python/xla_compiler.cc | 10 +++++----- xla/python/xla_extension/__init__.pyi | 3 +-- xla/service/cpu/cpu_compiler.cc | 2 +- xla/service/gpu/gpu_spmd_pipeline.cc | 2 +- xla/service/gpu/gpu_spmd_pipeline_test.cc | 2 +- xla/service/spmd/shardy/BUILD | 1 - xla/service/spmd/shardy/shardy_call_inliner.cc | 5 +---- xla/service/spmd/shardy/shardy_call_inliner_test.cc | 1 + xla/xla.proto | 5 +---- 13 files changed, 15 insertions(+), 33 deletions(-) diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc index 1b9a00e7078add..a9d316399f78fe 100644 --- a/xla/debug_options_flags.cc +++ b/xla/debug_options_flags.cc @@ -271,8 +271,6 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_nccl_terminate_on_error(false); - opts.set_xla_use_shardy(false); - opts.set_xla_gpu_shard_autotuning(false); opts.set_xla_syntax_sugar_async_ops(false); @@ -1796,9 +1794,6 @@ void MakeDebugOptionsFlags(std::vector* flag_list, bool_setter_for(&DebugOptions::set_xla_gpu_nccl_terminate_on_error), debug_options->xla_gpu_nccl_terminate_on_error(), "If set, then NCCL errors will terminate the process.")); - flag_list->push_back(tsl::Flag( - "xla_use_shardy", bool_setter_for(&DebugOptions::set_xla_use_shardy), - debug_options->xla_use_shardy(), "Whether to use Shardy.")); flag_list->push_back(tsl::Flag( "xla_gpu_shard_autotuning", bool_setter_for(&DebugOptions::set_xla_gpu_shard_autotuning), diff --git a/xla/pjrt/cpu/cpu_client.cc b/xla/pjrt/cpu/cpu_client.cc index a832ab568d3408..3b2caa64aa7f70 100644 --- a/xla/pjrt/cpu/cpu_client.cc +++ b/xla/pjrt/cpu/cpu_client.cc @@ -857,10 +857,7 @@ absl::StatusOr> TfrtCpuClient::Compile( TF_RETURN_IF_ERROR(MlirToXlaComputation( module, xla_computation, /*use_tuple_args=*/options.parameter_is_tupled_arguments, - /*return_tuple=*/false, - exec_build_options.has_debug_options() - ? exec_build_options.debug_options().xla_use_shardy() - : false)); + /*return_tuple=*/false, exec_build_options.use_shardy_partitioner())); return Compile(xla_computation, options); } diff --git a/xla/pjrt/pjrt_stream_executor_client.cc b/xla/pjrt/pjrt_stream_executor_client.cc index 9d3c820550a3b7..7204b158f339f7 100644 --- a/xla/pjrt/pjrt_stream_executor_client.cc +++ b/xla/pjrt/pjrt_stream_executor_client.cc @@ -3524,10 +3524,7 @@ PjRtStreamExecutorClient::Compile(mlir::ModuleOp module, TF_RETURN_IF_ERROR(MlirToXlaComputation( module, xla_computation, /*use_tuple_args=*/options.parameter_is_tupled_arguments, - /*return_tuple=*/false, - exec_build_options.has_debug_options() - ? exec_build_options.debug_options().xla_use_shardy() - : false)); + /*return_tuple=*/false, exec_build_options.use_shardy_partitioner())); // If the compile options specify argument layout, then let's // fall back to using the options to determine layouts. diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py index 97f01bec9bb0d5..7d8a2e3c05a9fe 100644 --- a/xla/python/xla_client.py +++ b/xla/python/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 279 +_version = 280 # Version number for MLIR:Python components. mlir_api_version = 57 diff --git a/xla/python/xla_compiler.cc b/xla/python/xla_compiler.cc index 2259083a2da478..7e2504c7ba44e2 100644 --- a/xla/python/xla_compiler.cc +++ b/xla/python/xla_compiler.cc @@ -1199,10 +1199,7 @@ void BuildXlaCompilerSubmodule(nb::module_& m) { &DebugOptions::xla_gpu_dump_autotune_logs_to, [](DebugOptions* self, std::string value) { self->set_xla_gpu_dump_autotune_logs_to(value); - }) - // TODO(b/352486192): Move this to `ExecutableBuildOptions`. - .def_prop_rw("xla_use_shardy", &DebugOptions::xla_use_shardy, - &DebugOptions::set_xla_use_shardy); + }); nb::class_(m, "ExecutableBuildOptions") .def(nb::init<>()) @@ -1276,7 +1273,10 @@ void BuildXlaCompilerSubmodule(nb::module_& m) { [](ExecutableBuildOptions& options, std::vector values) { absl::InlinedVector v(values.begin(), values.end()); options.set_allow_spmd_sharding_propagation_to_output(v); - }); + }) + .def_prop_rw("use_shardy_partitioner", + &ExecutableBuildOptions::use_shardy_partitioner, + &ExecutableBuildOptions::set_use_shardy_partitioner); nb::enum_ op_sharding_type(m, "OpSharding_Type", nb::is_arithmetic()); diff --git a/xla/python/xla_extension/__init__.pyi b/xla/python/xla_extension/__init__.pyi index e19bf8546491ab..a0e9008f81de68 100644 --- a/xla/python/xla_extension/__init__.pyi +++ b/xla/python/xla_extension/__init__.pyi @@ -318,8 +318,6 @@ class DebugOptions: xla_gpu_dump_autotune_results_to: str xla_gpu_load_autotune_results_from: str xla_gpu_dump_autotune_logs_to: str - # TODO(b/352486192): Move this to `ExecutableBuildOptions`. - xla_use_shardy: bool class CompiledMemoryStats: generated_code_size_in_bytes: int @@ -348,6 +346,7 @@ class ExecutableBuildOptions: use_auto_spmd_partitioning: bool auto_spmd_partitioning_mesh_shape: List[int] auto_spmd_partitioning_mesh_ids: List[int] + use_shardy_partitioner: bool class PrecisionConfig_Precision(enum.IntEnum): DEFAULT: int diff --git a/xla/service/cpu/cpu_compiler.cc b/xla/service/cpu/cpu_compiler.cc index 7254f2b1380f03..024a85edfb9632 100644 --- a/xla/service/cpu/cpu_compiler.cc +++ b/xla/service/cpu/cpu_compiler.cc @@ -447,7 +447,7 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn( spmd_pipeline.AddPass(); spmd_pipeline.AddPass(); spmd_pipeline.AddPass(); - if (module->config().debug_options().xla_use_shardy()) { + if (module->config().use_shardy_partitioner()) { spmd_pipeline.AddPass(); } else { spmd_pipeline.AddPass( diff --git a/xla/service/gpu/gpu_spmd_pipeline.cc b/xla/service/gpu/gpu_spmd_pipeline.cc index 4f7635813e28be..d84797d21c462e 100644 --- a/xla/service/gpu/gpu_spmd_pipeline.cc +++ b/xla/service/gpu/gpu_spmd_pipeline.cc @@ -89,7 +89,7 @@ void AddSPMDPasses( const HloModuleConfig& config = hlo_module->config(); - if (config.debug_options().xla_use_shardy()) { + if (config.use_shardy_partitioner()) { spmd_pipeline.AddPass(); } else { spmd_pipeline.AddPass(); diff --git a/xla/service/gpu/gpu_spmd_pipeline_test.cc b/xla/service/gpu/gpu_spmd_pipeline_test.cc index 765b73b9590de6..42a9e7dcad49f9 100644 --- a/xla/service/gpu/gpu_spmd_pipeline_test.cc +++ b/xla/service/gpu/gpu_spmd_pipeline_test.cc @@ -48,6 +48,7 @@ class GpuSpmdPartitioningTest : public HloTestBase, HloModuleConfig config = GetModuleConfigForTest( /*replica_count=*/1, /*num_partitions=*/num_devices); config.set_num_partitions(num_devices); + config.set_use_shardy_partitioner(UseShardy()); TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_module, config)); @@ -67,7 +68,6 @@ class GpuSpmdPartitioningTest : public HloTestBase, DebugOptions GetDebugOptionsForTest() override { DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); - debug_options.set_xla_use_shardy(UseShardy()); return debug_options; } }; diff --git a/xla/service/spmd/shardy/BUILD b/xla/service/spmd/shardy/BUILD index 4a4af6b9ee1b61..a96d6250b421b5 100644 --- a/xla/service/spmd/shardy/BUILD +++ b/xla/service/spmd/shardy/BUILD @@ -37,7 +37,6 @@ cc_library( xla_cc_test( name = "shardy_call_inliner_test", srcs = ["shardy_call_inliner_test.cc"], - env = {"XLA_FLAGS": "--xla_use_shardy=true"}, deps = [ ":shardy_call_inliner", "//xla/hlo/ir:hlo", diff --git a/xla/service/spmd/shardy/shardy_call_inliner.cc b/xla/service/spmd/shardy/shardy_call_inliner.cc index 73a8479dcc4fc9..9f863e23a6715d 100644 --- a/xla/service/spmd/shardy/shardy_call_inliner.cc +++ b/xla/service/spmd/shardy/shardy_call_inliner.cc @@ -24,10 +24,7 @@ namespace xla { bool ShardyCallInliner::IsInlineableCallOp(HloInstruction* instruction) const { return CallInliner::IsInlineableCallOp(instruction) && !instruction->has_backend_config() && - !(instruction->GetModule() - ->config() - .debug_options() - .xla_use_shardy() && + !(instruction->GetModule()->config().use_shardy_partitioner() && absl::StrContains(instruction->to_apply()->name(), "shmap_body")); } diff --git a/xla/service/spmd/shardy/shardy_call_inliner_test.cc b/xla/service/spmd/shardy/shardy_call_inliner_test.cc index 861e934fee5779..00d952b3b80461 100644 --- a/xla/service/spmd/shardy/shardy_call_inliner_test.cc +++ b/xla/service/spmd/shardy/shardy_call_inliner_test.cc @@ -45,6 +45,7 @@ TEST_F(ShardyCallInlinerTest, MhloToHloShmapBodyNotInlined) { ROOT %custom-call.9 = f32[8,8]{1,0} custom-call(f32[1,8]{1,0} %custom-call.8), custom_call_target="SPMDShardToFullShape", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=7} })"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hloString)); + module->mutable_config().set_use_shardy_partitioner(true); TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyCallInliner().Run(module.get())); VLOG(1) << module->ToString(); // The single call in the module is not inlined. diff --git a/xla/xla.proto b/xla/xla.proto index c5e81a932147ec..50993c6c411ea4 100644 --- a/xla/xla.proto +++ b/xla/xla.proto @@ -806,10 +806,7 @@ message DebugOptions { // If true, Nccl errors will terminate the process. bool xla_gpu_nccl_terminate_on_error = 301; - // Use Shardy, a new partitioner, to replace the existing - // ShardingPropagation and SpmdPartitioner. See go/xla-sdy-pipeline for - // details. - bool xla_use_shardy = 302; + reserved 302; // was xla_use_shardy bool xla_gpu_shard_autotuning = 304; From cbec301b478c0c810869654c032fdcde4fda728e Mon Sep 17 00:00:00 2001 From: Heiner Date: Mon, 29 Jul 2024 17:11:50 -0700 Subject: [PATCH 247/376] Add Gloo support for MacOS. This is an alternative to #7726. Gloo supports MacOS, but requires using libuv as the transport mechanism. This closes https://github.com/openxla/xla/pull/15027 PiperOrigin-RevId: 657376600 --- third_party/gloo/gloo.BUILD | 2 +- third_party/uv/BUILD | 1 + third_party/uv/uv.BUILD | 29 +++++++++++++++++++++++++++++ third_party/uv/workspace.bzl | 17 +++++++++++++++++ workspace2.bzl | 2 ++ xla/python/BUILD | 6 +++++- xla/python/xla.cc | 25 +++++++++++++++++++++++-- 7 files changed, 78 insertions(+), 4 deletions(-) create mode 100644 third_party/uv/BUILD create mode 100644 third_party/uv/uv.BUILD create mode 100644 third_party/uv/workspace.bzl diff --git a/third_party/gloo/gloo.BUILD b/third_party/gloo/gloo.BUILD index 1cfc72f1ec35be..2eb62cd7416136 100644 --- a/third_party/gloo/gloo.BUILD +++ b/third_party/gloo/gloo.BUILD @@ -22,7 +22,7 @@ substitions = { "#cmakedefine01 GLOO_USE_REDIS": "#define GLOO_USE_REDIS 0", "#cmakedefine01 GLOO_USE_IBVERBS": "#define GLOO_USE_IBVERBS 0", "#cmakedefine01 GLOO_USE_MPI": "#define GLOO_USE_MPI 0", - "#cmakedefine01 GLOO_USE_LIBUV": "#define GLOO_USE_LIBUV 0", + "#cmakedefine01 GLOO_USE_LIBUV": "#define GLOO_USE_LIBUV (__APPLE__ ? 1 : 0)", "#cmakedefine01 GLOO_HAVE_TRANSPORT_TCP": "#define GLOO_HAVE_TRANSPORT_TCP 1", "#cmakedefine01 GLOO_HAVE_TRANSPORT_TCP_TLS": "#define GLOO_HAVE_TRANSPORT_TCP_TLS 0", "#cmakedefine01 GLOO_HAVE_TRANSPORT_IBVERBS": "#define GLOO_HAVE_TRANSPORT_IBVERBS 0", diff --git a/third_party/uv/BUILD b/third_party/uv/BUILD new file mode 100644 index 00000000000000..3c413807167aeb --- /dev/null +++ b/third_party/uv/BUILD @@ -0,0 +1 @@ +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) diff --git a/third_party/uv/uv.BUILD b/third_party/uv/uv.BUILD new file mode 100644 index 00000000000000..75a2df39c435d0 --- /dev/null +++ b/third_party/uv/uv.BUILD @@ -0,0 +1,29 @@ +# Description: +# libuv is a cross-platform asynchronous I/O library. + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) + +exports_files(["LICENSE"]) + +cc_library( + name = "uv", + srcs = glob(["src/*.c"]), + hdrs = [ + "include/uv.h", + ], + copts = [ + "-fexceptions", + "-Wno-unused-variable", + ], + includes = [ + "include", + "src", + ], + textual_hdrs = [ + "include/uv.h", + ], +) diff --git a/third_party/uv/workspace.bzl b/third_party/uv/workspace.bzl new file mode 100644 index 00000000000000..8d26ab4dcd41b5 --- /dev/null +++ b/third_party/uv/workspace.bzl @@ -0,0 +1,17 @@ +"""Provides the repository macro to import libuv.""" + +load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") + +def repo(): + """Imports libuv.""" + + UV_VERSION = "v1.38.0" + UV_SHA256 = "71344f62c5020ed3643ad0bcba98ae4d7d6037285923c5416844d7c141a3ff93" + + tf_http_archive( + name = "uv", + sha256 = UV_SHA256, + strip_prefix = "libuv-{version}".format(version = UV_VERSION), + urls = tf_mirror_urls("https://dist.libuv.org/dist/{version}/libuv-{version}.tar.gz".format(version = UV_VERSION)), + build_file = "//third_party/uv:uv.BUILD", + ) diff --git a/workspace2.bzl b/workspace2.bzl index 5c9d4650408134..aca054224c0f58 100644 --- a/workspace2.bzl +++ b/workspace2.bzl @@ -16,6 +16,7 @@ load("//third_party/robin_map:workspace.bzl", robin_map = "repo") load("//third_party/shardy:workspace.bzl", shardy = "repo") load("//third_party/stablehlo:workspace.bzl", stablehlo = "repo") load("//third_party/triton:workspace.bzl", triton = "repo") +load("//third_party/uv:workspace.bzl", uv = "repo") def _initialize_third_party(): """ Load third party repositories. See above load() statements. """ @@ -27,6 +28,7 @@ def _initialize_third_party(): shardy() stablehlo() triton() + uv() # Define all external repositories required by TensorFlow def _tf_repositories(): diff --git a/xla/python/BUILD b/xla/python/BUILD index 8354fb5f097b88..93ecd1a9a7dd6b 100644 --- a/xla/python/BUILD +++ b/xla/python/BUILD @@ -18,6 +18,7 @@ load( "//xla/tsl:tsl.bzl", "if_cuda_or_rocm", "if_google", + "if_oss", "internal_visibility", ) load("//xla/tsl:tsl.default.bzl", "get_compatible_with_portable", "tsl_pybind_extension") @@ -1400,7 +1401,10 @@ tsl_pybind_extension( "@tsl//tsl/platform/cloud:gcs_file_system", ] + select({ # gloo transport only builds on linux - "//xla/tsl:macos": [], + "//xla/tsl:macos": [ + "//xla/pjrt/cpu:gloo_collectives", + "//xla/pjrt/cpu:gloo_kv_store", + ] + if_oss(["@gloo//:transport_uv"]), "//xla/tsl:windows": [], "//conditions:default": [ "//xla/pjrt/cpu:gloo_collectives", diff --git a/xla/python/xla.cc b/xla/python/xla.cc index 19a9d94e1d1b7b..b9f9b6e671c94a 100644 --- a/xla/python/xla.cc +++ b/xla/python/xla.cc @@ -71,6 +71,10 @@ limitations under the License. #include "gloo/transport/tcp/device.h" #include "xla/pjrt/cpu/gloo_collectives.h" #include "xla/pjrt/cpu/gloo_kv_store.h" +#elif __APPLE__ +#include "gloo/transport/uv/device.h" +#include "xla/pjrt/cpu/gloo_collectives.h" +#include "xla/pjrt/cpu/gloo_kv_store.h" #endif // __linux__ #if !defined(_WIN32) && !defined(PLATFORM_GOOGLE) @@ -257,7 +261,7 @@ NB_MODULE(xla_extension, m_nb) { std::optional hostname, std::optional interface) -> std::shared_ptr { -#ifdef __linux__ +#if defined(__linux__) std::shared_ptr kv_store = nullptr; if (distributed_client != nullptr) { kv_store = GetDistributedKeyValueStore(distributed_client, @@ -274,9 +278,26 @@ NB_MODULE(xla_extension, m_nb) { auto tcp_device = gloo::transport::tcp::CreateDevice(tcp_attrs); return std::make_shared(std::move(gloo_kv_store), std::move(tcp_device)); +#elif defined(__APPLE__) + std::shared_ptr kv_store = nullptr; + if (distributed_client != nullptr) { + kv_store = GetDistributedKeyValueStore(distributed_client, + /*key_prefix=*/"cpu:"); + } + auto gloo_kv_store = std::make_unique(kv_store); + auto uv_attrs = gloo::transport::uv::attr(); + if (hostname) { + uv_attrs.hostname = *hostname; + } + if (interface) { + uv_attrs.iface = *interface; + } + auto uv_device = gloo::transport::uv::CreateDevice(uv_attrs); + return std::make_shared(std::move(gloo_kv_store), + std::move(uv_device)); #else // __linux__ throw xla::XlaRuntimeError( - "make_gloo_tcp_collectives only implemented for linux"); + "make_gloo_tcp_collectives only implemented for linux and macos"); #endif // __linux__ }, nb::arg("distributed_client"), nb::arg("hostname").none() = std::nullopt, From f146cf0f8fecf65a770fd1af9b4cf689d6761e5a Mon Sep 17 00:00:00 2001 From: Fangrui Song Date: Mon, 29 Jul 2024 17:43:52 -0700 Subject: [PATCH 248/376] Integrate LLVM at llvm/llvm-project@de5aa8d0060c Updates LLVM usage to match [de5aa8d0060c](https://github.com/llvm/llvm-project/commit/de5aa8d0060c) PiperOrigin-RevId: 657385511 --- third_party/llvm/generated.patch | 37 ------------------- third_party/llvm/workspace.bzl | 4 +- third_party/shardy/workspace.bzl | 4 +- .../tsl/third_party/llvm/generated.patch | 37 ------------------- .../tsl/third_party/llvm/workspace.bzl | 4 +- 5 files changed, 6 insertions(+), 80 deletions(-) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index c3926cd2b6eeef..509398da979e83 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,38 +1 @@ Auto generated patch. Do not edit or delete it, even if empty. -diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaTemplateDeduction.cpp b/clang/lib/Sema/SemaTemplateDeduction.cpp ---- a/clang/lib/Sema/SemaTemplateDeduction.cpp -+++ b/clang/lib/Sema/SemaTemplateDeduction.cpp -@@ -951,9 +951,11 @@ - - // Skip over the pack elements that were expanded into separate arguments. - // If we partially expanded, this is the number of partial arguments. -+ // FIXME: `&& FixedNumExpansions` is a workaround for UB described in -+ // https://github.com/llvm/llvm-project/issues/100095 - if (IsPartiallyExpanded) - PackElements += NumPartialPackArgs; -- else if (IsExpanded) -+ else if (IsExpanded && FixedNumExpansions) - PackElements += *FixedNumExpansions; - - for (auto &Pack : Packs) { -diff -ruN --strip-trailing-cr a/clang/test/SemaCXX/pr100095.cpp b/clang/test/SemaCXX/pr100095.cpp ---- a/clang/test/SemaCXX/pr100095.cpp -+++ b/clang/test/SemaCXX/pr100095.cpp -@@ -0,0 +1,17 @@ -+// RUN: %clang_cc1 -fsyntax-only -std=c++11 %s -+// XFAIL: asserts -+ -+template struct Pair; -+template struct Tuple { -+ template Tuple(_Up); -+}; -+template struct StatusOr; -+template using ElementType = int; -+template -+using Key = Tuple...>; -+template -+StatusOr>> Parser(); -+struct Helper { Helper(Tuple<>, Tuple<>, int, int); }; -+struct D : Helper { -+ D(Key<> f, int n, int e) : Helper(f, Parser<>, n, e) {} -+}; diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 508b58c2033e18..c9d1d7e658a079 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "63e1647827f3427c5f3ad37461d84a63ba5fcdaf" - LLVM_SHA256 = "1d2d4cee6b7c5f558635d16ea97e14c688071f4cc7c6f29cddd23420817263f2" + LLVM_COMMIT = "de5aa8d0060cbe286c9cbae90ca8f197b92a3956" + LLVM_SHA256 = "dbad8fbbf8b639a05b171d8fb467ab4b88b8dda7aa46148c7ba6c6c194ca12d7" tf_http_archive( name = name, diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index 8f4cba75927cd4..6f858b22baa66d 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "4ef678c17a82b55c03028d0e76df647c5c6be471" - SHARDY_SHA256 = "9e771b112753a406c81d9cff4828be4e3ee8cf302ab2b7ebee0b381d9868b194" + SHARDY_COMMIT = "052cf0bd0ddcae0e295c25dffa14f5ffe8d07d2d" + SHARDY_SHA256 = "66694f413683bdfd82b7e5a289117d8f6c6451012658274189207ff755cd9626" tf_http_archive( name = "shardy", diff --git a/third_party/tsl/third_party/llvm/generated.patch b/third_party/tsl/third_party/llvm/generated.patch index c3926cd2b6eeef..509398da979e83 100644 --- a/third_party/tsl/third_party/llvm/generated.patch +++ b/third_party/tsl/third_party/llvm/generated.patch @@ -1,38 +1 @@ Auto generated patch. Do not edit or delete it, even if empty. -diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaTemplateDeduction.cpp b/clang/lib/Sema/SemaTemplateDeduction.cpp ---- a/clang/lib/Sema/SemaTemplateDeduction.cpp -+++ b/clang/lib/Sema/SemaTemplateDeduction.cpp -@@ -951,9 +951,11 @@ - - // Skip over the pack elements that were expanded into separate arguments. - // If we partially expanded, this is the number of partial arguments. -+ // FIXME: `&& FixedNumExpansions` is a workaround for UB described in -+ // https://github.com/llvm/llvm-project/issues/100095 - if (IsPartiallyExpanded) - PackElements += NumPartialPackArgs; -- else if (IsExpanded) -+ else if (IsExpanded && FixedNumExpansions) - PackElements += *FixedNumExpansions; - - for (auto &Pack : Packs) { -diff -ruN --strip-trailing-cr a/clang/test/SemaCXX/pr100095.cpp b/clang/test/SemaCXX/pr100095.cpp ---- a/clang/test/SemaCXX/pr100095.cpp -+++ b/clang/test/SemaCXX/pr100095.cpp -@@ -0,0 +1,17 @@ -+// RUN: %clang_cc1 -fsyntax-only -std=c++11 %s -+// XFAIL: asserts -+ -+template struct Pair; -+template struct Tuple { -+ template Tuple(_Up); -+}; -+template struct StatusOr; -+template using ElementType = int; -+template -+using Key = Tuple...>; -+template -+StatusOr>> Parser(); -+struct Helper { Helper(Tuple<>, Tuple<>, int, int); }; -+struct D : Helper { -+ D(Key<> f, int n, int e) : Helper(f, Parser<>, n, e) {} -+}; diff --git a/third_party/tsl/third_party/llvm/workspace.bzl b/third_party/tsl/third_party/llvm/workspace.bzl index 508b58c2033e18..c9d1d7e658a079 100644 --- a/third_party/tsl/third_party/llvm/workspace.bzl +++ b/third_party/tsl/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "63e1647827f3427c5f3ad37461d84a63ba5fcdaf" - LLVM_SHA256 = "1d2d4cee6b7c5f558635d16ea97e14c688071f4cc7c6f29cddd23420817263f2" + LLVM_COMMIT = "de5aa8d0060cbe286c9cbae90ca8f197b92a3956" + LLVM_SHA256 = "dbad8fbbf8b639a05b171d8fb467ab4b88b8dda7aa46148c7ba6c6c194ca12d7" tf_http_archive( name = name, From 809a81a8e86d71fd69d779c7f889eae5bb0dd1af Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Tue, 30 Jul 2024 02:15:58 -0700 Subject: [PATCH 249/376] Update test expectations The attribute "largest" is not printed anymore if it has the default value. PiperOrigin-RevId: 657502082 --- xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir | 4 ++-- xla/translate/hlo_to_mhlo/tests/import.hlo | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir b/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir index 4324c8e7731b2b..49f1de75619860 100644 --- a/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir +++ b/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir @@ -2802,7 +2802,7 @@ func.func @tan_f32(%arg : tensor) -> tensor { // CHECK-SAME: (%[[ARG:.*]]: tensor<16x16xf32>) func.func @top_k(%arg : tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) { // CHECK-HIGH-LEVEL: mhlo.topk - // CHECK: %values, %indices = mhlo.topk(%arg0, k = 8, largest = true) : tensor<16x16xf32> -> (tensor<16x8xf32>, tensor<16x8xi32>) + // CHECK: %values, %indices = mhlo.topk(%arg0, k = 8) : tensor<16x16xf32> -> (tensor<16x8xf32>, tensor<16x8xi32>) %1:2 = chlo.top_k(%arg, k=8) : tensor<16x16xf32> -> (tensor<16x8xf32>, tensor<16x8xi32>) func.return %1#0, %1#1 : tensor<16x8xf32>, tensor<16x8xi32> } @@ -2814,7 +2814,7 @@ func.func @top_k(%arg : tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32 // CHECK-SAME: -> (tensor, tensor) func.func @dyn_top_k(%arg0: tensor) -> (tensor, tensor) { // CHECK-HIGH-LEVEL: mhlo.topk - // CHECK: %values, %indices = mhlo.topk(%arg0, k = 2, largest = true) : tensor -> (tensor, tensor) + // CHECK: %values, %indices = mhlo.topk(%arg0, k = 2) : tensor -> (tensor, tensor) %values, %indices = chlo.top_k(%arg0, k = 2) : tensor -> (tensor, tensor) return %values, %indices : tensor, tensor } diff --git a/xla/translate/hlo_to_mhlo/tests/import.hlo b/xla/translate/hlo_to_mhlo/tests/import.hlo index a8ce57c90f5d3d..0c175bc850e32e 100644 --- a/xla/translate/hlo_to_mhlo/tests/import.hlo +++ b/xla/translate/hlo_to_mhlo/tests/import.hlo @@ -2021,7 +2021,7 @@ add { } // CHECK-LABEL: func private @test_topk // CHECK-SAME: ([[ARG:%.*]]: tensor<4x4xf32>) -> tuple, tensor<4x2xi32>> -// CHECK: mhlo.topk([[ARG]], k = 2, largest = true) : tensor<4x4xf32> -> (tensor<4x2xf32>, tensor<4x2xi32>) +// CHECK: mhlo.topk([[ARG]], k = 2) : tensor<4x4xf32> -> (tensor<4x2xf32>, tensor<4x2xi32>) // FLATTEN-CHECK-LABEL: func private @test_topk // FLATTEN-CHECK-SAME: ([[ARG:%.*]]: tensor<4x4xf32>) -> (tensor<4x2xf32>, tensor<4x2xi32>) From 5cebc82aa03721c0d2f5446f7e5e027f79320f50 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Tue, 30 Jul 2024 02:52:58 -0700 Subject: [PATCH 250/376] Remove visit_arg callback from HloBfsConsumersFirstTraversal. This is not needed anymore, we can use GetParameters() instead for the two places that actually need to access the fusion adaptor operands as well. PiperOrigin-RevId: 657511596 --- xla/service/gpu/fusions/loop.cc | 24 ++++----- xla/service/gpu/fusions/reduction_base.cc | 10 ++-- xla/service/gpu/hlo_traversal.cc | 21 +++----- xla/service/gpu/hlo_traversal.h | 4 +- xla/service/gpu/hlo_traversal_test.cc | 60 ++++++++--------------- 5 files changed, 46 insertions(+), 73 deletions(-) diff --git a/xla/service/gpu/fusions/loop.cc b/xla/service/gpu/fusions/loop.cc index e9b7933b1c7895..522dc1d4a10452 100644 --- a/xla/service/gpu/fusions/loop.cc +++ b/xla/service/gpu/fusions/loop.cc @@ -81,13 +81,13 @@ int ComputeMaxUnrollFactor(int64_t num_elements) { std::pair RowVectorizationEnabled( const HloFusionAdaptor& fusion, int64_t out_rank) { auto roots = fusion.GetRoots(); - const auto is_row_major = [](auto instr) { + const auto is_row_major = [](const HloInstruction* instr) { // Only tested when the inputs are row-major. So only enable that case. // Maybe it would work if only the inner dimensions is contiguous. - return LayoutUtil::IsMonotonicWithDim0Major(instr.shape().layout()); + return LayoutUtil::IsMonotonicWithDim0Major(instr->shape().layout()); }; bool row_vectorized = roots.size() == 1 && !roots[0].shape().IsTuple() && - is_row_major(roots[0]); + is_row_major(&roots[0].instruction()); if (!row_vectorized) { return {false, 0}; } @@ -136,15 +136,17 @@ std::pair RowVectorizationEnabled( row_vectorized = false; return TraversalResult::kInterrupt; } - }, - [&](auto argument) { - if (argument.shape().rank() == out_rank) { - ++num_big_inputs; - } - if (!is_row_major(argument)) { - row_vectorized = false; - } }); + if (row_vectorized) { + for (const HloInstruction* argument : fusion.GetParameters()) { + if (argument->shape().rank() == out_rank) { + ++num_big_inputs; + } + if (!is_row_major(argument)) { + row_vectorized = false; + } + }; + } // Trigger only when there is a row broadcasting. return std::make_pair(row_vectorized && some_row_broadcasting, num_big_inputs); diff --git a/xla/service/gpu/fusions/reduction_base.cc b/xla/service/gpu/fusions/reduction_base.cc index cf2aa130082125..7895108db31cc1 100644 --- a/xla/service/gpu/fusions/reduction_base.cc +++ b/xla/service/gpu/fusions/reduction_base.cc @@ -183,11 +183,12 @@ ReductionGroups GroupDisjointReductions(const HloFusionAnalysis& analysis, } absl::flat_hash_set instructions; - + for (const HloInstruction* operand : analysis.fusion().GetParameters()) { + instructions.insert(HloInstructionAdaptor{*operand, &analysis.fusion()}); + } auto visit = [&](absl::Span roots) { HloBfsConsumersFirstTraversal( - roots, analysis.fusion(), - [&](HloInstructionAdaptor consumer) { + roots, analysis.fusion(), [&](HloInstructionAdaptor consumer) { auto& consumer_reachable = reachable_outputs[consumer]; for (auto producer : consumer.GetOperands()) { reachable_outputs[producer].insert(consumer_reachable.begin(), @@ -195,8 +196,7 @@ ReductionGroups GroupDisjointReductions(const HloFusionAnalysis& analysis, } instructions.insert(consumer); return TraversalResult::kAdvance; - }, - [&](HloInstructionAdaptor argument) { instructions.insert(argument); }); + }); }; // The legacy emitter grouping is buggy: it does not visit instructions in the diff --git a/xla/service/gpu/hlo_traversal.cc b/xla/service/gpu/hlo_traversal.cc index 4394226dfadc0b..dfa655276b0fdf 100644 --- a/xla/service/gpu/hlo_traversal.cc +++ b/xla/service/gpu/hlo_traversal.cc @@ -506,7 +506,6 @@ void HloBfsTraversal( const HloFusionAdaptor& fusion, const std::function& visit_node, - const std::function& visit_arg, bool visit_operands) { absl::flat_hash_set visited; std::queue q; @@ -514,12 +513,8 @@ void HloBfsTraversal( const auto& adjacent_nodes = visit_operands ? node.GetOperands() : node.GetUsers(); for (const auto& node : adjacent_nodes) { - if (visited.insert(node).second) { - if (fusion.ContainsInstruction(node)) { - q.push(node); - } else { - visit_arg(node); - } + if (fusion.ContainsInstruction(node) && visited.insert(node).second) { + q.push(node); } } }; @@ -548,9 +543,8 @@ void HloBfsConsumersFirstTraversal( absl::Span roots, const HloFusionAdaptor& fusion, const std::function& - visit_node, - const std::function& visit_arg) { - HloBfsTraversal(roots, fusion, visit_node, visit_arg, + visit_node) { + HloBfsTraversal(roots, fusion, visit_node, /*visit_operands=*/true); } @@ -559,9 +553,8 @@ void HloBfsProducersFirstTraversal( const HloFusionAdaptor& fusion, const std::function& visit_node) { - HloBfsTraversal( - producers, fusion, visit_node, [](HloInstructionAdaptor) {}, - /*visit_operands=*/false); + HloBfsTraversal(producers, fusion, visit_node, + /*visit_operands=*/false); } bool HloBfsAnyOf(absl::Span roots, @@ -592,7 +585,7 @@ std::optional HloBfsFindIf( } return TraversalResult::kAdvance; }, - [](HloInstructionAdaptor) {}, visit_operands); + visit_operands); return result; } diff --git a/xla/service/gpu/hlo_traversal.h b/xla/service/gpu/hlo_traversal.h index b49d4efc9377ce..67edd2258bb563 100644 --- a/xla/service/gpu/hlo_traversal.h +++ b/xla/service/gpu/hlo_traversal.h @@ -147,9 +147,7 @@ void HloBfsConsumersFirstTraversal( absl::Span roots, const HloFusionAdaptor& fusion, const std::function& - visit_node, - const std::function& visit_arg = - [](HloInstructionAdaptor) {}); + visit_node); // Visit the HLO nodes starting from `producers` in BFS order following the // `user` edges. Each node will be visited exactly once. diff --git a/xla/service/gpu/hlo_traversal_test.cc b/xla/service/gpu/hlo_traversal_test.cc index ee3a4b7ad1239f..fcab8f47a3100f 100644 --- a/xla/service/gpu/hlo_traversal_test.cc +++ b/xla/service/gpu/hlo_traversal_test.cc @@ -150,43 +150,31 @@ TEST_F(HloTraversalTest, AdaptorUsers) { TEST_F(HloTraversalTest, TraverseFusionConsumerFirst) { auto module = ParseAndReturnVerifiedModule(kTestModule).value(); std::vector visited_nodes; - std::vector visited_args; auto fusion = HloFusionAdaptor::ForInstruction( module->entry_computation()->GetInstructionWithName("fusion")); - HloBfsConsumersFirstTraversal( - fusion->GetRoots(), *fusion, - [&](HloInstructionAdaptor node) { - visited_nodes.emplace_back(node.name()); - return TraversalResult::kAdvance; - }, - [&](HloInstructionAdaptor arg) { - visited_args.emplace_back(arg.name()); - }); + HloBfsConsumersFirstTraversal(fusion->GetRoots(), *fusion, + [&](HloInstructionAdaptor node) { + visited_nodes.emplace_back(node.name()); + return TraversalResult::kAdvance; + }); EXPECT_THAT(visited_nodes, ElementsAre("reduce.1", "mul")); - EXPECT_THAT(visited_args, ElementsAre("p0", "negate")); } TEST_F(HloTraversalTest, TraverseFusionConsumerFirstFromFusionRootAndInnerNode) { auto module = ParseAndReturnVerifiedModule(kTestModule).value(); std::vector visited_nodes; - std::vector visited_args; auto fusion = HloFusionAdaptor::ForInstruction( module->entry_computation()->GetInstructionWithName("fusion")); auto root = fusion->GetRoots()[0]; - HloBfsConsumersFirstTraversal( - {root, root.GetOperand(0)}, *fusion, - [&](HloInstructionAdaptor node) { - visited_nodes.emplace_back(node.name()); - return TraversalResult::kAdvance; - }, - [&](HloInstructionAdaptor arg) { - visited_args.emplace_back(arg.name()); - }); + HloBfsConsumersFirstTraversal({root, root.GetOperand(0)}, *fusion, + [&](HloInstructionAdaptor node) { + visited_nodes.emplace_back(node.name()); + return TraversalResult::kAdvance; + }); EXPECT_THAT(visited_nodes, ElementsAre("reduce.1", "mul")); - EXPECT_THAT(visited_args, ElementsAre("p0", "negate")); } TEST_F(HloTraversalTest, TraverseFusionProducerFirst) { @@ -379,17 +367,13 @@ TEST_F(HloTraversalTest, FuseFusionConsumer) { EXPECT_TRUE(reduce_1.GetUsers().empty()); std::vector nodes; - std::vector params; - HloBfsConsumersFirstTraversal( - fusion->GetRoots(), *fusion, - [&](HloInstructionAdaptor node) { - nodes.emplace_back(node.name()); - return TraversalResult::kAdvance; - }, - [&](HloInstructionAdaptor param) { params.emplace_back(param.name()); }); + HloBfsConsumersFirstTraversal(fusion->GetRoots(), *fusion, + [&](HloInstructionAdaptor node) { + nodes.emplace_back(node.name()); + return TraversalResult::kAdvance; + }); EXPECT_THAT(nodes, ElementsAre("reduce.1", "mul", "negate")); - EXPECT_THAT(params, ElementsAre("p0", "sum")); } TEST_F(HloTraversalTest, FuseFusionProducer) { @@ -411,17 +395,13 @@ TEST_F(HloTraversalTest, FuseFusionProducer) { InstructionAdaptorName("fusion.1"))); std::vector nodes; - std::vector params; - HloBfsConsumersFirstTraversal( - fusion->GetRoots(), *fusion, - [&](HloInstructionAdaptor node) { - nodes.emplace_back(node.name()); - return TraversalResult::kAdvance; - }, - [&](HloInstructionAdaptor arg) { params.emplace_back(arg.name()); }); + HloBfsConsumersFirstTraversal(fusion->GetRoots(), *fusion, + [&](HloInstructionAdaptor node) { + nodes.emplace_back(node.name()); + return TraversalResult::kAdvance; + }); EXPECT_THAT(nodes, ElementsAre("difference", "reduce.2")); - EXPECT_THAT(params, ElementsAre("p0", "negate", "fusion.1")); } TEST_F(HloTraversalTest, FuseFusionConsumerAndProducer) { From 9cd46f5545e598545e7b4924e0104f4112b6c418 Mon Sep 17 00:00:00 2001 From: Dmitri Gribenko Date: Tue, 30 Jul 2024 02:53:10 -0700 Subject: [PATCH 251/376] Integrate LLVM at llvm/llvm-project@0e6f64cd5e5a Updates LLVM usage to match [0e6f64cd5e5a](https://github.com/llvm/llvm-project/commit/0e6f64cd5e5a) PiperOrigin-RevId: 657511639 --- third_party/llvm/generated.patch | 18 ++++++++++++++++++ third_party/llvm/workspace.bzl | 4 ++-- third_party/shardy/workspace.bzl | 4 ++-- .../tsl/third_party/llvm/generated.patch | 18 ++++++++++++++++++ third_party/tsl/third_party/llvm/workspace.bzl | 4 ++-- .../tools/mlir_interpreter/dialects/vector.cc | 2 +- 6 files changed, 43 insertions(+), 7 deletions(-) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 509398da979e83..dc682861fbb165 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1 +1,19 @@ Auto generated patch. Do not edit or delete it, even if empty. +diff -ruN --strip-trailing-cr a/clang/test/CodeGenCoroutines/coro-elide-thinlto.cpp b/clang/test/CodeGenCoroutines/coro-elide-thinlto.cpp +--- a/clang/test/CodeGenCoroutines/coro-elide-thinlto.cpp ++++ b/clang/test/CodeGenCoroutines/coro-elide-thinlto.cpp +@@ -3,10 +3,10 @@ + // This test is adapted from coro-elide.cpp and splits functions into two files. + // + // RUN: split-file %s %t +-// RUN: %clang --target=x86_64-linux -std=c++20 -O2 -flto=thin -I %S -c %t/coro-elide-callee.cpp -o coro-elide-callee.bc +-// RUN: %clang --target=x86_64-linux -std=c++20 -O2 -flto=thin -I %S -c %t/coro-elide-caller.cpp -o coro-elide-caller.bc +-// RUN: llvm-lto --thinlto coro-elide-callee.bc coro-elide-caller.bc -o summary +-// RUN: %clang_cc1 -O2 -x ir coro-elide-caller.bc -fthinlto-index=summary.thinlto.bc -emit-llvm -o - | FileCheck %s ++// RUN: %clang --target=x86_64-linux -std=c++20 -O2 -flto=thin -I %S -c %t/coro-elide-callee.cpp -o %t/coro-elide-callee.bc ++// RUN: %clang --target=x86_64-linux -std=c++20 -O2 -flto=thin -I %S -c %t/coro-elide-caller.cpp -o %t/coro-elide-caller.bc ++// RUN: llvm-lto --thinlto %t/coro-elide-callee.bc %t/coro-elide-caller.bc -o %t/summary ++// RUN: %clang_cc1 -O2 -x ir %t/coro-elide-caller.bc -fthinlto-index=%t/summary.thinlto.bc -emit-llvm -o - | FileCheck %s + + //--- coro-elide-task.h + #pragma once diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index c9d1d7e658a079..f5dd4fdd0bd288 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "de5aa8d0060cbe286c9cbae90ca8f197b92a3956" - LLVM_SHA256 = "dbad8fbbf8b639a05b171d8fb467ab4b88b8dda7aa46148c7ba6c6c194ca12d7" + LLVM_COMMIT = "0e6f64cd5e5a06bd78542d5541a762154546ced3" + LLVM_SHA256 = "d3b426b13175ac771a05a0908e11391be46913fc1ab7c459ae906b07b77474c0" tf_http_archive( name = name, diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index 6f858b22baa66d..129c3f2e9bf708 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "052cf0bd0ddcae0e295c25dffa14f5ffe8d07d2d" - SHARDY_SHA256 = "66694f413683bdfd82b7e5a289117d8f6c6451012658274189207ff755cd9626" + SHARDY_COMMIT = "0458df554c1d569c034c10986069ec8fc1d58828" + SHARDY_SHA256 = "20b84eec31de9728b91901bf57aadf9faa9942a8b0383bd4bde9d588b51beeb1" tf_http_archive( name = "shardy", diff --git a/third_party/tsl/third_party/llvm/generated.patch b/third_party/tsl/third_party/llvm/generated.patch index 509398da979e83..dc682861fbb165 100644 --- a/third_party/tsl/third_party/llvm/generated.patch +++ b/third_party/tsl/third_party/llvm/generated.patch @@ -1 +1,19 @@ Auto generated patch. Do not edit or delete it, even if empty. +diff -ruN --strip-trailing-cr a/clang/test/CodeGenCoroutines/coro-elide-thinlto.cpp b/clang/test/CodeGenCoroutines/coro-elide-thinlto.cpp +--- a/clang/test/CodeGenCoroutines/coro-elide-thinlto.cpp ++++ b/clang/test/CodeGenCoroutines/coro-elide-thinlto.cpp +@@ -3,10 +3,10 @@ + // This test is adapted from coro-elide.cpp and splits functions into two files. + // + // RUN: split-file %s %t +-// RUN: %clang --target=x86_64-linux -std=c++20 -O2 -flto=thin -I %S -c %t/coro-elide-callee.cpp -o coro-elide-callee.bc +-// RUN: %clang --target=x86_64-linux -std=c++20 -O2 -flto=thin -I %S -c %t/coro-elide-caller.cpp -o coro-elide-caller.bc +-// RUN: llvm-lto --thinlto coro-elide-callee.bc coro-elide-caller.bc -o summary +-// RUN: %clang_cc1 -O2 -x ir coro-elide-caller.bc -fthinlto-index=summary.thinlto.bc -emit-llvm -o - | FileCheck %s ++// RUN: %clang --target=x86_64-linux -std=c++20 -O2 -flto=thin -I %S -c %t/coro-elide-callee.cpp -o %t/coro-elide-callee.bc ++// RUN: %clang --target=x86_64-linux -std=c++20 -O2 -flto=thin -I %S -c %t/coro-elide-caller.cpp -o %t/coro-elide-caller.bc ++// RUN: llvm-lto --thinlto %t/coro-elide-callee.bc %t/coro-elide-caller.bc -o %t/summary ++// RUN: %clang_cc1 -O2 -x ir %t/coro-elide-caller.bc -fthinlto-index=%t/summary.thinlto.bc -emit-llvm -o - | FileCheck %s + + //--- coro-elide-task.h + #pragma once diff --git a/third_party/tsl/third_party/llvm/workspace.bzl b/third_party/tsl/third_party/llvm/workspace.bzl index c9d1d7e658a079..f5dd4fdd0bd288 100644 --- a/third_party/tsl/third_party/llvm/workspace.bzl +++ b/third_party/tsl/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "de5aa8d0060cbe286c9cbae90ca8f197b92a3956" - LLVM_SHA256 = "dbad8fbbf8b639a05b171d8fb467ab4b88b8dda7aa46148c7ba6c6c194ca12d7" + LLVM_COMMIT = "0e6f64cd5e5a06bd78542d5541a762154546ced3" + LLVM_SHA256 = "d3b426b13175ac771a05a0908e11391be46913fc1ab7c459ae906b07b77474c0" tf_http_archive( name = name, diff --git a/xla/mlir/tools/mlir_interpreter/dialects/vector.cc b/xla/mlir/tools/mlir_interpreter/dialects/vector.cc index a190c13e5a4ac9..7aaaf5af97215e 100644 --- a/xla/mlir/tools/mlir_interpreter/dialects/vector.cc +++ b/xla/mlir/tools/mlir_interpreter/dialects/vector.cc @@ -230,7 +230,7 @@ InterpreterValue MaskImpl(mlir::Operation* op, ArrayRef mask_sizes) { } InterpreterValue ConstantMask(InterpreterState&, vector::ConstantMaskOp mask) { - return MaskImpl(mask, ExtractVector(mask.getMaskDimSizes())); + return MaskImpl(mask, mask.getMaskDimSizes()); } // TODO(jreiffers): Support masked contractions. From 2d47ca58d6f0ba36e9806851dfbdfb02a3d51b74 Mon Sep 17 00:00:00 2001 From: Gregory Pataky Date: Tue, 30 Jul 2024 04:38:15 -0700 Subject: [PATCH 252/376] Add IsMinNormal and IsSubnormalOrMinNormal utility functions for exhaustive tests PiperOrigin-RevId: 657538740 --- .../exhaustive/exhaustive_op_test_utils.cc | 14 ++++++ .../exhaustive/exhaustive_op_test_utils.h | 49 ++++++++++++++++++- 2 files changed, 61 insertions(+), 2 deletions(-) diff --git a/xla/tests/exhaustive/exhaustive_op_test_utils.cc b/xla/tests/exhaustive/exhaustive_op_test_utils.cc index 4f606a11dc0220..991c550e8b925f 100644 --- a/xla/tests/exhaustive/exhaustive_op_test_utils.cc +++ b/xla/tests/exhaustive/exhaustive_op_test_utils.cc @@ -45,6 +45,12 @@ bool IsSubnormalReal(xla::complex128 value) { return IsSubnormal(value.real()); } +bool IsMinNormalReal(xla::complex64 value) { return IsMinNormal(value.real()); } + +bool IsMinNormalReal(xla::complex128 value) { + return IsMinNormal(value.real()); +} + bool IsSubnormalImaginary(xla::complex64 value) { return IsSubnormal(value.imag()); } @@ -53,6 +59,14 @@ bool IsSubnormalImaginary(xla::complex128 value) { return IsSubnormal(value.imag()); } +bool IsMinNormalImaginary(xla::complex64 value) { + return IsMinNormal(value.imag()); +} + +bool IsMinPositiveImaginary(xla::complex128 value) { + return IsMinNormal(value.imag()); +} + // For f64, f32, f16, and bf16, we need 17, 9, 5, and 4 decimal places of // precision to be guaranteed that we're printing the full number. // diff --git a/xla/tests/exhaustive/exhaustive_op_test_utils.h b/xla/tests/exhaustive/exhaustive_op_test_utils.h index 80c69703dfb96c..5504f4f9e3f5ba 100644 --- a/xla/tests/exhaustive/exhaustive_op_test_utils.h +++ b/xla/tests/exhaustive/exhaustive_op_test_utils.h @@ -55,19 +55,38 @@ limitations under the License. namespace xla { namespace exhaustive_op_test { +// Determines if the real component of the complex number is subnormal (either +// sign). // Determines if the real component of the complex number is subnormal. // // See also IsSubnormal to check if either component is subnormal. bool IsSubnormalReal(xla::complex64); bool IsSubnormalReal(xla::complex128); -// Determines if the imaginary component of the complex number is subnormal. +// Determines if the real component of the complex number is the minimum +// normal floating point value (either sign). +// +// See also IsMinPositive to check if either component is the minimum normal +// floating point value. +bool IsMinNormalReal(xla::complex64); +bool IsMinNormalReal(xla::complex128); + +// Determines if the imaginary component of the complex number is subnormal +// (either sign). // // See also IsSubnormal to check if either component is subnormal. bool IsSubnormalImaginary(xla::complex64); bool IsSubnormalImaginary(xla::complex128); -// Determines if the NativeT is subnormal. +// Determines if the imaginary component of the complex number is the minimum +// normal floating point value (either sign). +// +// See also IsMinPositive to check if either component is the minimum normal +// floating point value. +bool IsMinNormalImaginary(xla::complex64); +bool IsMinNormalImaginary(xla::complex128); + +// Determines if the NativeT is subnormal (either sign). // // For complex numbers, this will return true if either real or imaginary // component is subnormal. See IsSubnormalReal and IsSubnormalImaginary if you @@ -82,6 +101,32 @@ bool IsSubnormal(NativeT value) { } } +// Determines if the NativeT is the minimum normal floating point value +// (either sign). +// +// For complex numbers, this will return true if either real or imaginary +// component is the minimum normal floating point value. See IsMinPositiveReal +// and IsMinPositiveImaginary if you only care about one component. +template +bool IsMinNormal(NativeT value) { + if constexpr (std::is_same_v || + std::is_same_v) { + return IsMinNormalReal(value) || IsMinNormalImaginary(value); + } else { + return std::abs(value) == std::numeric_limits::min(); + } +} + +// Determines if the NativeT is subnormal or the minimum normal floating point +// value (either sign). +// +// For complex numbers, this will return true if either real or imaginary +// component is subnormal or the minimum normal floating point value. +template +bool IsSubnormalOrMinNormal(NativeT value) { + return IsSubnormal(value) || IsMinNormal(value); +} + struct ErrorSpec { double abs_err = 0; double rel_err = 0; From 3fc6bab16427a177f86bdb82f8fa2f9c42786777 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Tue, 30 Jul 2024 05:03:04 -0700 Subject: [PATCH 253/376] PR #15445: [GPU][NFC] Log progress of sharded GEMM fusion autotuning. Imported from GitHub PR https://github.com/openxla/xla/pull/15445 Copybara import of the project: -- e43fb81f5c81090bc51fd0faf9892720e87c4f06 by Ilia Sergachev : [GPU][NFC] Log progress of sharded GEMM fusion autotuning. Merging this change closes #15445 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/15445 from openxla:log_sharded_autotuning_progress e43fb81f5c81090bc51fd0faf9892720e87c4f06 PiperOrigin-RevId: 657545291 --- xla/service/gpu/gemm_fusion_autotuner.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/xla/service/gpu/gemm_fusion_autotuner.cc b/xla/service/gpu/gemm_fusion_autotuner.cc index 0a6188495febf2..5a5a0f3671342e 100644 --- a/xla/service/gpu/gemm_fusion_autotuner.cc +++ b/xla/service/gpu/gemm_fusion_autotuner.cc @@ -1176,10 +1176,13 @@ absl::Status ExchangeResults(KeyValueStoreInterface& key_value_store, TF_RETURN_IF_ERROR(key_value_store.Set( absl::StrFormat("%s_%d_%d", kKeyPrefix, module_id, shard_index), results_str)); + VLOG(2) << "Rank " << shard_index << ": published results"; for (int i = 0; i < shard_count; ++i) { if (i == shard_index) { continue; } + VLOG(2) << "Rank " << shard_index << ": waiting for results from rank " << i + << " / " << shard_count; TF_ASSIGN_OR_RETURN( std::string autotune_results_str, key_value_store.Get( From dd05f78df55e063229b57f0ab9c32419d56ced65 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Tue, 30 Jul 2024 05:06:13 -0700 Subject: [PATCH 254/376] Have a version of the shardings as mhlo.shardings during SDY round-tripping. PiperOrigin-RevId: 657546125 --- xla/service/spmd/shardy/sdy_round_trip/BUILD | 1 + .../shardy/sdy_round_trip/export_shardings.cc | 6 ----- .../shardy/sdy_round_trip/export_shardings.h | 4 +++ .../shardy/sdy_round_trip/import_shardings.cc | 9 +++++++ .../spmd/shardy/sdy_round_trip/pipelines.cc | 5 ++++ xla/service/spmd/shardy/shardy_xla_pass.cc | 1 - .../test/sdy_round_trip_export_pipeline.mlir | 25 ++++++++++--------- 7 files changed, 32 insertions(+), 19 deletions(-) diff --git a/xla/service/spmd/shardy/sdy_round_trip/BUILD b/xla/service/spmd/shardy/sdy_round_trip/BUILD index 2784d2d02b5963..ab41ada3a7e058 100644 --- a/xla/service/spmd/shardy/sdy_round_trip/BUILD +++ b/xla/service/spmd/shardy/sdy_round_trip/BUILD @@ -79,6 +79,7 @@ cc_library( ":export_shardings", ":import_shardings", "//xla/service:hlo_proto_cc", + "//xla/service/spmd/shardy/mhlo_round_trip:export_shardings", "//xla/service/spmd/shardy/round_trip_common:pipeline_passes", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", diff --git a/xla/service/spmd/shardy/sdy_round_trip/export_shardings.cc b/xla/service/spmd/shardy/sdy_round_trip/export_shardings.cc index aec0a20775c73a..b076d5b215785c 100644 --- a/xla/service/spmd/shardy/sdy_round_trip/export_shardings.cc +++ b/xla/service/spmd/shardy/sdy_round_trip/export_shardings.cc @@ -92,7 +92,6 @@ LogicalResult exportFunc(FuncOp funcOp, OpBuilder& builder) { if (auto oldSharding = funcOp.getArgAttrOfType( argNum, kShardingAttr)) { addFrontendAttribute(funcOp, kShardingRoundTripAttr, oldSharding, argNum); - funcOp.removeArgAttr(argNum, kShardingAttr); } } @@ -122,7 +121,6 @@ LogicalResult exportFunc(FuncOp funcOp, OpBuilder& builder) { TensorShardingPerValueAttr::get(customCallOp.getContext(), sharding), builder); returnOperand.set(customCallOp.getResult(0)); - funcOp.removeResultAttr(resultNum, builder.getStringAttr(kShardingAttr)); } } @@ -130,7 +128,6 @@ LogicalResult exportFunc(FuncOp funcOp, OpBuilder& builder) { if (auto oldShardingPerValue = op->getAttrOfType(kShardingAttr)) { saveOpShardingPerValueAttr(op, oldShardingPerValue, builder); - op->removeAttr(kShardingAttr); } }); @@ -155,8 +152,6 @@ class SdyRoundTripExportShardingsPass } SmallVector mhloMeshes; - mlir::SymbolTableCollection symbolTableCollection; - SymbolTable& symbolTable = symbolTableCollection.getSymbolTable(moduleOp); // Saves the MeshOps for MHLO<->HLO round-trip and removes them from the // ModuleOp. for (MeshOp meshOp : @@ -164,7 +159,6 @@ class SdyRoundTripExportShardingsPass mhloMeshes.emplace_back( meshOp.getSymNameAttr(), getStringAttribute(meshOp.getMeshAttr(), builder)); - symbolTable.erase(meshOp); } addFrontendAttribute(moduleOp, kMeshesRoundTripAttr, DictionaryAttr::get(context, mhloMeshes)); diff --git a/xla/service/spmd/shardy/sdy_round_trip/export_shardings.h b/xla/service/spmd/shardy/sdy_round_trip/export_shardings.h index dfbe7108694147..4b8ce6ab737419 100644 --- a/xla/service/spmd/shardy/sdy_round_trip/export_shardings.h +++ b/xla/service/spmd/shardy/sdy_round_trip/export_shardings.h @@ -29,6 +29,10 @@ void registerSdyRoundTripExportShardingsPass(); // Creates the pass that converts the shardings from `kShardingAttr` to // `kShardingRoundTripAttr` in the HLO frontend attributes and saves the // mesh symbols as `kMeshesRoundTripAttr` in the module frontend attributes. +// +// NOTE: The `kShardingAttr`s are not removed from the ops. They are kept around +// because part of the `SdyRoundTripExportPipeline` it also converts the +// `kShardingAttr`s to `kXlaShardingAttr`s. std::unique_ptr createSdyRoundTripExportShardingsPass(); } // namespace sdy diff --git a/xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc b/xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc index a5347a3b416c65..28cdc89c7c1125 100644 --- a/xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc +++ b/xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc @@ -96,6 +96,7 @@ void convertShardings(FuncOp funcOp) { // We need to wait until after we've converted all the Operations before // copying the result shardings. for (auto [argNum, argType] : llvm::enumerate(funcOp.getArgumentTypes())) { + funcOp.removeArgAttr(argNum, kXlaShardingAttr); // Attempt to extract the TensorShardingAttr from the frontend attributes of // the function argument/result. if (DictionaryAttr dictAttr = getFuncArgFrontendAttrs(funcOp, argNum)) { @@ -106,8 +107,16 @@ void convertShardings(FuncOp funcOp) { } } + // Due to `SdyRoundTripExportShardingsPass` keeping `mhlo.sharding`s, remove + // them purely for cleanliness of the module. + for (int64_t resNum = 0; resNum < funcOp.getNumResults(); ++resNum) { + funcOp.removeResultAttr( + resNum, StringAttr::get(funcOp.getContext(), kXlaShardingAttr)); + } + // Extract the round-tripped SDY shardings from the operations. funcOp.front().walk([&](Operation* op) { + op->removeAttr(kXlaShardingAttr); if (DictionaryAttr dictAttr = getFrontendAttrs(op)) { // NOTE: we are only setting the sharding on known custom-calls. For any // other op that has a `kShardingRoundTripAttr` we discard it. XLA diff --git a/xla/service/spmd/shardy/sdy_round_trip/pipelines.cc b/xla/service/spmd/shardy/sdy_round_trip/pipelines.cc index f9eda62025d762..ee348edad68c17 100644 --- a/xla/service/spmd/shardy/sdy_round_trip/pipelines.cc +++ b/xla/service/spmd/shardy/sdy_round_trip/pipelines.cc @@ -21,6 +21,7 @@ limitations under the License. #include "mlir/Pass/PassRegistry.h" #include "mlir/Support/LLVM.h" #include "xla/service/hlo.pb.h" +#include "xla/service/spmd/shardy/mhlo_round_trip/export_shardings.h" #include "xla/service/spmd/shardy/round_trip_common/pipeline_passes.h" #include "xla/service/spmd/shardy/sdy_round_trip/export_ops.h" #include "xla/service/spmd/shardy/sdy_round_trip/export_shardings.h" @@ -38,7 +39,11 @@ void addSdyRoundTripExportPipeline(mlir::OpPassManager& pm) { // `createSdyRoundTripExportShardingsPass` and make use of // `createSdyRoundTripImportShardingsPass` to import them. pm.addPass(createSdyRoundTripExportOpsPass()); + // Preserve the SDY shardings for `createExportMhloShardingsPass` so that + // we have both `mhlo.sharding`s and hidden `sdy.sharding`s on the module. We + // want to have `mhlo.sharding`s for Pathways to read from. pm.addPass(createSdyRoundTripExportShardingsPass()); + pm.addPass(createExportMhloShardingsPass()); } void addSdyRoundTripImportPipeline(mlir::OpPassManager& pm) { diff --git a/xla/service/spmd/shardy/shardy_xla_pass.cc b/xla/service/spmd/shardy/shardy_xla_pass.cc index 1735b3ccc30985..2514c46d91d3b2 100644 --- a/xla/service/spmd/shardy/shardy_xla_pass.cc +++ b/xla/service/spmd/shardy/shardy_xla_pass.cc @@ -16,7 +16,6 @@ limitations under the License. #include "xla/service/spmd/shardy/shardy_xla_pass.h" #include -#include #include #include #include diff --git a/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir b/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir index 58fa8f829f59a9..8853f1b498bd7d 100644 --- a/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir +++ b/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir @@ -13,15 +13,15 @@ sdy.mesh @mesh_2 = <"x"=8, "y"=4> // CHECK-SAME: mesh_2 = \22#sdy.mesh<\\22x\\22=8, \\22y\\22=4>\22}"}} { // CHECK-LABEL: func @multiple_shardings( -// CHECK-SAME: %arg0: tensor<8x8xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh_0, [{\22axis_2\22}, {\22axis_0\22, \22axis_1\22}]>"}}, -// CHECK-SAME: %arg1: tensor<8x8xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh_0, [{}, {\22axis_0\22, \22axis_2\22}]>"}}, -// CHECK-SAME: %arg2: tensor<8x16xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh_0, [{}, {\22axis_1\22}]>"}}) +// CHECK-SAME: %arg0: tensor<8x8xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh_0, [{\22axis_2\22}, {\22axis_0\22, \22axis_1\22}]>"}, mhlo.sharding = +// CHECK-SAME: %arg1: tensor<8x8xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh_0, [{}, {\22axis_0\22, \22axis_2\22}]>"}, mhlo.sharding = +// CHECK-SAME: %arg2: tensor<8x16xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh_0, [{}, {\22axis_1\22}]>"}, mhlo.sharding = // CHECK-SAME: -> tensor<8x16xf32> { func.func @multiple_shardings(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"axis_2"}, {"axis_0", "axis_1"}]>}, %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{}, {"axis_0", "axis_2"}]>}, %arg2: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{}, {"axis_1"}]>}) -> tensor<8x16xf32> { // CHECK-NEXT: mhlo.add -// CHECK-SAME: {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_0, [{\22axis_1\22, \22axis_0\22}, {}]>]>"}} +// CHECK-SAME: {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_0, [{\22axis_1\22, \22axis_0\22}, {}]>]>"}, mhlo.sharding = %0 = mhlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"axis_1","axis_0"}, {}]>]>} : tensor<8x8xf32> %1 = "mhlo.dot" (%0, %arg2) : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> return %1 : tensor<8x16xf32> @@ -31,7 +31,7 @@ func.func @multiple_shardings(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.shardi func.func @multi_result_op(%arg0: tensor<4x64x8xf32>, %arg1: tensor<4x64x8xf32>) -> (tensor<4x8xf32>, tensor<4x8xf32>) { %0 = mhlo.constant dense<0.000000e+00> : tensor // CHECK: mhlo.reduce -// CHECK-SAME: {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{}, {\22y\22}]>, <@mesh_2, [{\22y\22}, {}]>]>"}} +// CHECK-SAME: {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{}, {\22y\22}]>, <@mesh_2, [{\22y\22}, {}]>]>"}, mhlo.sharding = %1:2 = mhlo.reduce(%arg0 init: %0), (%arg1 init: %0) across dimensions = [1] {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{}, {"y"}]>, <@mesh_2, [{"y"}, {}]>]>} : (tensor<4x64x8xf32>, tensor<4x64x8xf32>, tensor, tensor) -> (tensor<4x8xf32>, tensor<4x8xf32>) @@ -44,20 +44,20 @@ func.func @multi_result_op(%arg0: tensor<4x64x8xf32>, %arg1: tensor<4x64x8xf32>) } // CHECK-LABEL: func @split_axes( -// CHECK-SAME: %arg0: tensor<8x8xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh_2, [{\22y\22}, {\22x\22:(2)2}]>"}}, -// CHECK-SAME: %arg1: tensor<8x16xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh_2, [{\22x\22:(1)2}, {\22x\22:(2)4}]>"}}) +// CHECK-SAME: %arg0: tensor<8x8xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh_2, [{\22y\22}, {\22x\22:(2)2}]>"}, mhlo.sharding = +// CHECK-SAME: %arg1: tensor<8x16xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh_2, [{\22x\22:(1)2}, {\22x\22:(2)4}]>"}, mhlo.sharding = // CHECK-SAME: -> tensor<8x16xf32> { func.func @split_axes(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"y"}, {"x":(2)2}]>}, %arg1: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x":(1)2}, {"x":(2)4}]>}) -> tensor<8x16xf32> { // CHECK-NEXT: "mhlo.dot" -// CHECK-SAME: {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\22x\22:(1)2, \22x\22:(4)2}, {}]>]>"}} +// CHECK-SAME: {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\22x\22:(1)2, \22x\22:(4)2}, {}]>]>"}, mhlo.sharding = %1 = "mhlo.dot" (%arg0, %arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"x":(1)2, "x":(4)2}, {}]>]>} : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> return %1 : tensor<8x16xf32> } // CHECK-LABEL: func @func_result_sharding_returning_func_arg( func.func @func_result_sharding_returning_func_arg( - // CHECK: %arg0: tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK: %arg0: tensor<8x16xf32>) -> (tensor<8x16xf32> {mhlo.sharding = %arg0: tensor<8x16xf32> ) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x", ?}, {"y"}p4]>}) { // CHECK: %[[CUSTOM_CALL:.*]] = mhlo.custom_call @xla.sdy.FuncResultSharding(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\22x\22, ?}, {\22y\22}p4]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> @@ -67,10 +67,11 @@ func.func @func_result_sharding_returning_func_arg( // CHECK-LABEL: func @func_result_sharding_returning_op_value( func.func @func_result_sharding_returning_op_value( - // CHECK: %arg0: tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32>) { + // CHECK: %arg0: tensor<8x16xf32>) + // CHECK-SAME: -> (tensor<8x16xf32> {mhlo.sharding = "{devices=[8,4]<=[32]}"}, tensor<8x16xf32> {mhlo.sharding = "{devices=[1,4,8]<=[8,4]T(1,0) last_tile_dim_replicate}"}, tensor<8x16xf32> {mhlo.sharding = "{devices=[8,4]<=[32]}"}) { %arg0: tensor<8x16xf32>) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x", ?}, {"y"}p4]>}, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{?}, {"y"}p4]>}, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x"}, {"y"}p1]>}) { // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %arg0, %arg0 : tensor<8x16xf32> - // CHECK-NEXT: %[[TEST_ONLY:.*]]:2 = mhlo.custom_call @sdy_testonly(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\22x\22, \22y\22}, {}]>, <@mesh_2, [{\22y\22, \22x\22}, {}]>]>"}} : (tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) + // CHECK-NEXT: %[[TEST_ONLY:.*]]:2 = mhlo.custom_call @sdy_testonly(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\22x\22, \22y\22}, {}]>, <@mesh_2, [{\22y\22, \22x\22}, {}]>]>"}, mhlo.sharding = // CHECK-NEXT: %[[ADD_RESULT_SHARDING:.*]] = mhlo.custom_call @xla.sdy.FuncResultSharding(%[[ADD]]) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\22x\22, ?}, {\22y\22}p4]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> // CHECK-NEXT: %[[TEST_ONLY_RES_SHARDING_0:.*]] = mhlo.custom_call @xla.sdy.FuncResultSharding(%[[TEST_ONLY]]#0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{?}, {\22y\22}p4]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> // CHECK-NEXT: %[[TEST_ONLY_RES_SHARDING_1:.*]] = mhlo.custom_call @xla.sdy.FuncResultSharding(%[[TEST_ONLY]]#1) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\22x\22}, {\22y\22}p1]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> @@ -83,7 +84,7 @@ func.func @func_result_sharding_returning_op_value( // CHECK-LABEL: func @sharding_constraint // CHECK-SAME: %arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { func.func @sharding_constraint(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { - // CHECK: mhlo.custom_call @Sharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\22x\22, ?}, {?}]>]>"}} : (tensor<8x8xf32>) -> tensor<8x8xf32> + // CHECK: mhlo.custom_call @Sharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\22x\22, ?}, {?}]>]>"}, mhlo.sharding = %0 = sdy.sharding_constraint %arg0 <@mesh_2, [{"x", ?}, {?}]> : tensor<8x8xf32> return %0 : tensor<8x8xf32> } From ea5422132c2121453614951d499c251d782e6e27 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Tue, 30 Jul 2024 05:07:20 -0700 Subject: [PATCH 255/376] Disable two failing tests. There is a bug in the generated parser for CHECK_ExpectCloseOp which became visible after updating the LLVM revision. PiperOrigin-RevId: 657546340 --- third_party/stablehlo/temporary.patch | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 8b137891791fe9..6a1b6b065a1370 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -1 +1,19 @@ +diff --ruN a/stablehlo/stablehlo/tests/math/ulp_difference_float32.mlir b/stablehlo/stablehlo/tests/math/ulp_difference_float32.mlir +--- stablehlo/stablehlo/tests/math/ulp_difference_float32.mlir ++++ stablehlo/stablehlo/tests/math/ulp_difference_float32.mlir +@@ -1,4 +1,5 @@ + // RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s | stablehlo-translate --interpret ++// XFAIL: * + // This file is generated, see build_tools/math/README.md for more information. + module @ulp_difference_float32 { + func.func public @main() { +diff --ruN a/stablehlo/stablehlo/tests/math/ulp_difference_float64.mlir b/stablehlo/stablehlo/tests/math/ulp_difference_float64.mlir +--- stablehlo/stablehlo/tests/math/ulp_difference_float64.mlir ++++ stablehlo/stablehlo/tests/math/ulp_difference_float64.mlir +@@ -1,4 +1,5 @@ + // RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s | stablehlo-translate --interpret ++// XFAIL: * + // This file is generated, see build_tools/math/README.md for more information. + module @ulp_difference_float64 { + func.func public @main() { From 3be611e4a259f663535f98c045fe70d7c6f59718 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Tue, 30 Jul 2024 05:24:33 -0700 Subject: [PATCH 256/376] Improve PreBufferAssignmentFusionInfo to check for DUS in-place conditions. Currently it assumes that a DUS can always be emitted in-place. With recent refactorings it becomes possible to check almost all conditions for whether a DUS fusion can be done in-place (the only missing piece is to check whether the buffer is shared between DUS operand and DUS output). In order for this to work correctly also for FusionAdaptors which are ProducerConsumer fusions, we need to work with HloInstructionAdaptor instead of HloInstruction. PiperOrigin-RevId: 657550870 --- xla/service/gpu/BUILD | 10 +- xla/service/gpu/fusions/fusions.cc | 4 +- xla/service/gpu/fusions/fusions.h | 5 +- .../fusions/in_place_dynamic_update_slice.cc | 8 +- .../fusions/in_place_dynamic_update_slice.h | 4 +- .../in_place_dynamic_update_slice_mlir.cc | 9 +- .../in_place_dynamic_update_slice_mlir.h | 2 +- xla/service/gpu/hlo_traversal.cc | 5 + xla/service/gpu/hlo_traversal.h | 2 + xla/service/gpu/ir_emission_utils.cc | 163 ++++++++---------- xla/service/gpu/ir_emission_utils.h | 14 +- xla/service/gpu/ir_emission_utils_test.cc | 47 +++-- 12 files changed, 139 insertions(+), 134 deletions(-) diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index a8b7d3a18d6fb1..58b83a2ca7922e 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -786,18 +786,14 @@ cc_library( "//xla/hlo/ir:backend_config", "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", - "//xla/service:hlo_parser", - "//xla/service/llvm_ir:buffer_assignment_util", "//xla/service/llvm_ir:llvm_type_conversion_util", "//xla/service/llvm_ir:llvm_util", - "//xla/translate/mhlo_to_hlo:location_exporter", - "//xla/translate/mhlo_to_hlo:type_to_shape", + "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -805,7 +801,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//llvm:TargetParser", "@tsl//tsl/lib/strings:proto_serialization", - "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:protobuf", "@tsl//tsl/platform:statusor", ], ) @@ -819,10 +815,10 @@ xla_cc_test( ":ir_emission_utils", "//xla:literal", "//xla:literal_util", + "//xla:shape_util", "//xla:types", "//xla:util", "//xla/hlo/ir:backend_config", - "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep diff --git a/xla/service/gpu/fusions/fusions.cc b/xla/service/gpu/fusions/fusions.cc index 00835aaaf7fd84..03425b525f0ab4 100644 --- a/xla/service/gpu/fusions/fusions.cc +++ b/xla/service/gpu/fusions/fusions.cc @@ -95,11 +95,11 @@ std::optional> HloFusionInfo::GetCopyFusion() bool HloFusionInfo::CanEmitDynamicUpdateSliceInPlace() const { auto ret = CanEmitFusedDynamicUpdateSliceInPlaceForGpu( - instr_, + analysis().fusion(), [this](const HloInstruction* instruction, const ShapeIndex& index) { return GetAllocationSlice(*buffer_assignment_, instruction, index); }, - analysis().fusion_roots()); + instr_); return ret.ok() && *ret; } diff --git a/xla/service/gpu/fusions/fusions.h b/xla/service/gpu/fusions/fusions.h index 9011c80d7f9f43..f7406b463b9117 100644 --- a/xla/service/gpu/fusions/fusions.h +++ b/xla/service/gpu/fusions/fusions.h @@ -73,8 +73,9 @@ class PreBufferAssignmentFusionInfo : public FusionInfo { : FusionInfo(analysis) {} bool CanEmitDynamicUpdateSliceInPlace() const override { - // Optimistically assume all DUS fusions are in-place. - return true; + auto ret = CanEmitFusedDynamicUpdateSliceInPlaceForGpu( + analysis().fusion(), /*get_allocation_slice=*/{}); + return ret.value_or(false); } std::optional> GetCopyFusion() diff --git a/xla/service/gpu/fusions/in_place_dynamic_update_slice.cc b/xla/service/gpu/fusions/in_place_dynamic_update_slice.cc index e362398cea60b1..464de3c81e2371 100644 --- a/xla/service/gpu/fusions/in_place_dynamic_update_slice.cc +++ b/xla/service/gpu/fusions/in_place_dynamic_update_slice.cc @@ -41,7 +41,7 @@ constexpr int kDUSUpdateIndex = 1; } // namespace LaunchDimensions InPlaceDynamicUpdateSliceFusion::launch_dimensions() const { - const auto& update_shape = dus_ops_.front()->operand(1)->shape(); + const auto& update_shape = dus_ops_.front().GetOperand(1).shape(); return CalculateLaunchDimensions(update_shape, analysis_.device_info()); } @@ -55,7 +55,7 @@ InPlaceDynamicUpdateSliceFusion::ComputeThreadIdToInputIndexing( auto launch_dims = launch_dimensions(); // It is guaranteed that all DUS ops have the same output shape at this point. const auto& update_shape = - dus_ops_.front()->operand(kDUSUpdateIndex)->shape(); + dus_ops_.front().GetOperand(kDUSUpdateIndex).shape(); return GetDefaultThreadIdIndexingMap(launch_dims, /*unroll_factor=*/1, update_shape, mlir_context); } @@ -72,7 +72,7 @@ absl::Status InPlaceDynamicUpdateSliceFusion::EmitKernel( // This condition should be enforced explicitly in the // 'CanEmitFusedDynamicUpdateSliceInPlaceForGpu' matcher. for (auto [op, output] : llvm::zip(dus_ops_, outputs)) { - output = output.CastToShape(op->shape(), builder); + output = output.CastToShape(op.shape(), builder); } auto* fused_computation = fusion.fused_instructions_computation(); @@ -93,7 +93,7 @@ absl::Status InPlaceDynamicUpdateSliceFusion::EmitKernel( dus_and_output_array.reserve(dus_ops_.size()); for (auto [op, output] : llvm::zip(dus_ops_, outputs)) { - dus_and_output_array.push_back(std::make_pair(op, output)); + dus_and_output_array.push_back(std::make_pair(&op.instruction(), output)); } return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace( diff --git a/xla/service/gpu/fusions/in_place_dynamic_update_slice.h b/xla/service/gpu/fusions/in_place_dynamic_update_slice.h index 08bcef5f9c2b5c..cfac87dccc1a5e 100644 --- a/xla/service/gpu/fusions/in_place_dynamic_update_slice.h +++ b/xla/service/gpu/fusions/in_place_dynamic_update_slice.h @@ -78,7 +78,7 @@ class InPlaceDynamicUpdateSliceFusion : public KernelFusionEmitterBase { std::optional ComputeThreadIdToInputIndexing( int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* indexing_context) const override; + mlir::MLIRContext* mlir_context) const override; protected: absl::Status EmitKernel(IrEmitterContext& ir_emitter_context, @@ -89,7 +89,7 @@ class InPlaceDynamicUpdateSliceFusion : public KernelFusionEmitterBase { llvm::IRBuilder<>* builder) const override; const HloFusionAnalysis& analysis_; - std::vector dus_ops_; + std::vector dus_ops_; }; } // namespace gpu diff --git a/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.cc b/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.cc index cc9d10ec7decaa..885d745b9a7978 100644 --- a/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.cc +++ b/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.cc @@ -68,7 +68,7 @@ constexpr int kDUSUpdateIndex = 1; LaunchDimensions MlirInPlaceDynamicUpdateSliceFusion::launch_dimensions() const { const auto& update_shape = - dus_ops_.front()->operand(kDUSUpdateIndex)->shape(); + dus_ops_.front().GetOperand(kDUSUpdateIndex).shape(); return CalculateLaunchDimensions(update_shape, analysis_.device_info()); } @@ -83,7 +83,7 @@ MlirInPlaceDynamicUpdateSliceFusion::ComputeThreadIdToInputIndexing( auto launch_dims = launch_dimensions(); // It is guaranteed that all DUS ops have the same output shape at this point. const auto& update_shape = - dus_ops_.front()->operand(kDUSUpdateIndex)->shape(); + dus_ops_.front().GetOperand(kDUSUpdateIndex).shape(); return GetDefaultThreadIdIndexingMap(launch_dims, /*unroll_factor=*/1, update_shape, indexing_context); } @@ -98,7 +98,7 @@ MlirInPlaceDynamicUpdateSliceFusion::GetEpilogues( llvm::zip(dus_ops_, analysis_.fusion_roots())) { epilogues.push_back( mlir_converter::EpilogueSpecification::FromIdentityIndexing( - dus_op, &root.instruction(), mlir_context)); + &dus_op.instruction(), &root.instruction(), mlir_context)); } return epilogues; } @@ -133,7 +133,8 @@ absl::Status MlirInPlaceDynamicUpdateSliceFusion::EmitEntryFunction( llvm::SmallVector results; for (auto [instr, root, output] : llvm::zip(dus_ops_, analysis_.fusion_roots(), output_tensors)) { - const auto* dus_instr = Cast(instr); + const auto* dus_instr = + Cast(&instr.instruction()); const auto& update_shape = dus_instr->update()->shape(); SmallVector update_indices; auto start_indices = ProvideParameterRange( diff --git a/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h b/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h index e1a5bc5310e88a..2ed84a06522b16 100644 --- a/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h +++ b/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h @@ -76,7 +76,7 @@ class MlirInPlaceDynamicUpdateSliceFusion : public MlirFusionEmitterBase { private: const HloFusionAnalysis& analysis_; - std::vector dus_ops_; + std::vector dus_ops_; }; } // namespace gpu diff --git a/xla/service/gpu/hlo_traversal.cc b/xla/service/gpu/hlo_traversal.cc index dfa655276b0fdf..c2318ab10f1584 100644 --- a/xla/service/gpu/hlo_traversal.cc +++ b/xla/service/gpu/hlo_traversal.cc @@ -500,6 +500,11 @@ bool operator==(const HloInstructionAdaptor& lhs, lhs.instruction_->unique_id() == rhs.instruction_->unique_id(); } +bool operator!=(const HloInstructionAdaptor& lhs, + const HloInstructionAdaptor& rhs) { + return !(lhs == rhs); +} + namespace { void HloBfsTraversal( absl::Span roots, diff --git a/xla/service/gpu/hlo_traversal.h b/xla/service/gpu/hlo_traversal.h index 67edd2258bb563..b4a5859875ba24 100644 --- a/xla/service/gpu/hlo_traversal.h +++ b/xla/service/gpu/hlo_traversal.h @@ -53,6 +53,8 @@ class HloInstructionAdaptor { friend bool operator==(const HloInstructionAdaptor& lhs, const HloInstructionAdaptor& rhs); + friend bool operator!=(const HloInstructionAdaptor& lhs, + const HloInstructionAdaptor& rhs); template friend H AbslHashValue(H h, const HloInstructionAdaptor& m); diff --git a/xla/service/gpu/ir_emission_utils.cc b/xla/service/gpu/ir_emission_utils.cc index 81d05f4d1347fa..e2882a7e8c9366 100644 --- a/xla/service/gpu/ir_emission_utils.cc +++ b/xla/service/gpu/ir_emission_utils.cc @@ -28,7 +28,6 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/log/log.h" -#include "absl/status/status.h" #include "absl/strings/escaping.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -42,6 +41,7 @@ limitations under the License. #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/IR/Verifier.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" #include "llvm/TargetParser/Triple.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -54,18 +54,16 @@ limitations under the License. #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/target_util.h" -#include "xla/service/llvm_ir/buffer_assignment_util.h" #include "xla/service/llvm_ir/llvm_type_conversion_util.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/translate/mhlo_to_hlo/location_exporter.h" -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" +#include "xla/stream_executor/device_description.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/lib/strings/proto_serialization.h" -#include "tsl/platform/errors.h" +#include "tsl/platform/protobuf.h" #include "tsl/platform/statusor.h" namespace xla { @@ -369,16 +367,16 @@ absl::StatusOr GetAllocationSlice( return buffer_assignment.GetUniqueSlice(instr, index); } -std::vector GetOutputDefiningDynamicUpdateSlices( +std::vector GetOutputDefiningDynamicUpdateSlices( absl::Span roots) { - std::vector dus_ops; + std::vector dus_ops; for (HloInstructionAdaptor root : roots) { while (root.opcode() == HloOpcode::kBitcast) { root = root.GetOperand(0); } if (root.opcode() == HloOpcode::kDynamicUpdateSlice) { - dus_ops.push_back(&root.instruction()); + dus_ops.push_back(root); } } return dus_ops; @@ -396,109 +394,86 @@ absl::InlinedVector GetStartIndices(T instr) { } absl::StatusOr CanEmitFusedDynamicUpdateSliceInPlaceForGpu( - const HloFusionInstruction* fusion, + const HloFusionAdaptor& fusion_adaptor, std::function( const HloInstruction* instr, const ShapeIndex& index)> get_allocation_slice, - absl::Span roots) { - std::vector dus_instrs = - GetOutputDefiningDynamicUpdateSlices(roots); - - // Get output buffers for fusion. - std::vector output_buffers; - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( - fusion->shape(), [&](const Shape& shape, const ShapeIndex index) { - if (shape.IsArray()) { - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice buffer, - get_allocation_slice(fusion, index)); - output_buffers.push_back(buffer); - } - return absl::OkStatus(); - })); + const HloInstruction* fusion) { + std::vector dus_instrs = + GetOutputDefiningDynamicUpdateSlices(fusion_adaptor.GetRoots()); // This check could probably be relaxed: if code generation is made to use a // separate parallel loop for each dynamic slice update, then it shouldn't be // necessary for every output to be a dynamic slice update, nor to have the // same shape. - if (dus_instrs.size() != output_buffers.size()) { + if (dus_instrs.size() != fusion_adaptor.GetRoots().size()) { return false; } - if (output_buffers.empty()) { - return Internal("Output buffers should not be empty"); - } - - Shape update_shape = dus_instrs[0]->operand(1)->shape(); + Shape update_shape = dus_instrs[0].GetOperand(1).shape(); for (int i = 0; i < dus_instrs.size(); ++i) { - auto* dus = Cast(dus_instrs[i]); + const auto& dus = dus_instrs[i]; - // Dynamic slice updates should have a single path to the root to avoid + // DynamicUpdateSlice ops should have a single path to the root to avoid // allowing a dynamic slice update to depend on another, as this would not // be guaranteed to work with the current codegen. - if (!dus->IsRoot() && dus->user_count() != 1) return false; - - // We follow DUS users until we find a root instruction. We support only - // few patterns: + // We follow DUS users until we find an instruction without users. We + // support only few patterns: // // (1) ROOT dynamic-update-slice // (2) ROOT tuple(dynamic-update-slice) // (3) ROOT bitcast(dynamic-update-slice) // (4) ROOT tuple(bitcast(dynamic-update-slice)) - HloInstruction* dus_user = dus->IsRoot() ? nullptr : dus->users().front(); - - // Since the direct consumer of an output dynamic slice update may be a - // bitcast, we also check that this bitcast is used a single time. - // This property is also important because reads and writes on the parameter - // to be updated are done using the shape and layout of the dynamic slice - // update. This is a valid approach only if a subsequent bitcast is not read - // by any other op within the fusion as this may result in codegen - // accessing elements using the wrong physical layout. - if (dus_user && dus_user->opcode() == HloOpcode::kBitcast) { - if (!dus_user->IsRoot() && dus_user->user_count() != 1) return false; - - // Stop following DUS users if we found a root. - dus_user = dus_user->IsRoot() ? nullptr : dus_user->users().front(); - } - - // Check that last DUS user is a tuple operation at ROOT position. - if (dus_user && dus_user->opcode() == HloOpcode::kTuple) { - if (!dus_user->IsRoot()) return false; - - // Stop following DUS users if we found a root. - dus_user = nullptr; + // + // In case there is a root tuple, the search will stop at the tuple operand, + // as the root tuple is not considered a real user by HloInstructionAdaptor. + // Note that due to AlgebraicSimplifier we will never have a chain of + // bitcasts. + HloInstructionAdaptor real_root = dus; + auto users = real_root.GetUsers(); + while (!users.empty()) { + if (users.size() > 1) { + return false; + } + real_root = users.front(); + if (real_root.opcode() != HloOpcode::kBitcast) { + return false; + } + users = real_root.GetUsers(); } - // We can't emit DUS fusion if we have unsupported DUS users. - if (dus_user != nullptr) return false; - // Find "real" DUS operand by skipping bitcasted operands. - const HloInstruction* operand = dus->operand(0); - if (operand->opcode() == HloOpcode::kBitcast) { - operand = operand->operand(0); + HloInstructionAdaptor operand = dus.GetOperand(0); + if (fusion_adaptor.ContainsInstruction(operand) && + operand.opcode() == HloOpcode::kBitcast) { + operand = operand.GetOperand(0); } // Operand to a DUS (or Bitcast) must be a fusion parameter. - auto* parameter = DynCast(operand); - if (!parameter) return false; + // HloInstructionAdaptor skips parameters, so we need to check whether + // 'operand' is outside of the fusion. + if (fusion_adaptor.ContainsInstruction(operand)) { + return false; + } // We require that the parameter being updated is only read at the same // index positions by all users, since we otherwise risk a race condition // when updating the parameter inplace. - std::queue q; + std::queue q; absl::flat_hash_set visited; - q.push(parameter); - visited.insert(parameter); + q.push(operand); + visited.insert(&operand.instruction()); // We have already checked above that the DUS only has one user. So we don't // need to visit it during the breadth-first search. - visited.insert(dus); + visited.insert(&dus.instruction()); while (!q.empty()) { - const HloInstruction* instr = q.front(); + HloInstructionAdaptor instr = q.front(); q.pop(); - for (const HloInstruction* user : instr->users()) { - if (user->opcode() == HloOpcode::kDynamicSlice && - dus->operand(0) == user->operand(0) && - update_shape == user->shape()) { + for (const HloInstructionAdaptor& user : instr.GetUsers()) { + if (user.opcode() == HloOpcode::kDynamicSlice && + dus.GetOperand(0) == user.GetOperand(0) && + update_shape == user.shape()) { // We can still emit in-place in this case if the same slice is // accessed by the DUS and the DS. If they don't access the same // slice, the two slices might partially overlap and read/write the @@ -506,19 +481,21 @@ absl::StatusOr CanEmitFusedDynamicUpdateSliceInPlaceForGpu( // read before it is overwritten. However if both access only a single // element, there also can be no race condition. absl::InlinedVector user_start_indices = - GetStartIndices(Cast(user)); + GetStartIndices( + Cast(&user.instruction())); absl::InlinedVector dus_start_indices = - GetStartIndices(dus); + GetStartIndices( + Cast(&dus.instruction())); if (ShapeUtil::ElementsIn(update_shape) != 1 && user_start_indices != dus_start_indices) { return false; } - } else if (user != dus && !user->IsElementwise() && - user->opcode() != HloOpcode::kBitcast && - user->opcode() != HloOpcode::kTuple) { + } else if (user != dus && !user.instruction().IsElementwise() && + user.opcode() != HloOpcode::kBitcast && + user.opcode() != HloOpcode::kTuple) { return false; } - if (visited.insert(user).second) { + if (visited.insert(&user.instruction()).second) { q.push(user); } } @@ -529,16 +506,26 @@ absl::StatusOr CanEmitFusedDynamicUpdateSliceInPlaceForGpu( // be necessary for the shape to be the same for all the dynamic slice // updates. Note that this equality check purposefully ignores the element // type. - if (dus->update()->shape() != update_shape) { + if (Cast(&dus.instruction()) + ->update() + ->shape() != update_shape) { return false; } - const HloInstruction* lhs = fusion->operand(parameter->parameter_number()); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice lhs_buffer, - get_allocation_slice(lhs, {})); - BufferAllocation::Slice rhs_buffer = output_buffers[i]; - if (lhs_buffer != rhs_buffer) { - return false; + if (fusion != nullptr) { + ShapeIndex root_index = {}; + if (fusion->IsMultiOutputFusion()) { + root_index = {i}; + } + // Get output buffer for the fusion root. + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_buffer, + get_allocation_slice(fusion, root_index)); + + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice lhs_buffer, + get_allocation_slice(&operand.instruction(), {})); + if (lhs_buffer != output_buffer) { + return false; + } } } diff --git a/xla/service/gpu/ir_emission_utils.h b/xla/service/gpu/ir_emission_utils.h index 044a3537d90282..9316ba9a655131 100644 --- a/xla/service/gpu/ir_emission_utils.h +++ b/xla/service/gpu/ir_emission_utils.h @@ -125,21 +125,25 @@ absl::StatusOr GetAllocationSlice( const BufferAssignment& buffer_assignment, const HloInstruction* instr, const ShapeIndex& index); -// Returns whether 'fusion' can be emitted with the dynamic update slice -// in-place emitter. +// Returns whether the fusion represented by 'fusion_adaptor' can be emitted +// with the dynamic update slice in-place emitter. If 'fusion_adaptor' +// represents a single fusion computation, 'fusion' should provide the fusion +// instruction corresponding to that fusion computation. 'get_allocation_slice' +// is a callback for getting the allocated buffer slice, given an instruction +// and a shape index. This is ignored in case 'fusion' is a nullptr. absl::StatusOr CanEmitFusedDynamicUpdateSliceInPlaceForGpu( - const HloFusionInstruction* fusion, + const HloFusionAdaptor& fusion_adaptor, std::function( const HloInstruction* instr, const ShapeIndex& index)> get_allocation_slice, - absl::Span roots); + const HloInstruction* fusion = nullptr); // Returns the dynamic-update-slice instructions defining the results of a // fusion node. A dynamic slice update is said to be "defining" of a result if // that result is the output of a dynamic slice update, or if that result is the // output of a bitcast of a dynamic slice update---since such bitcast may be // handled as a no-op. -std::vector GetOutputDefiningDynamicUpdateSlices( +std::vector GetOutputDefiningDynamicUpdateSlices( absl::Span roots); // Returns the first hero instruction reachable from `instr` as root. Hero diff --git a/xla/service/gpu/ir_emission_utils_test.cc b/xla/service/gpu/ir_emission_utils_test.cc index 67ffe2c723fd4f..0703f869d1457c 100644 --- a/xla/service/gpu/ir_emission_utils_test.cc +++ b/xla/service/gpu/ir_emission_utils_test.cc @@ -22,12 +22,12 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "xla/hlo/ir/backend_config.h" -#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/hlo_traversal.h" +#include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/types.h" #include "xla/util.h" @@ -703,12 +703,13 @@ ENTRY main { auto fusion = module->entry_computation()->root_instruction(); BufferAllocation alloc(/*index=*/0, /*size=*/1024, /*color=*/0); BufferAllocation::Slice slice0(&alloc, 0, 10); + auto adaptor = HloFusionAdaptor::ForInstruction(fusion); EXPECT_THAT(CanEmitFusedDynamicUpdateSliceInPlaceForGpu( - Cast(fusion), + *adaptor, [&slice0](const HloInstruction*, const ShapeIndex&) { return slice0; }, - HloFusionAdaptor::ForInstruction(fusion)->GetRoots()), + fusion), IsOkAndHolds(true)); } @@ -742,12 +743,13 @@ ENTRY main { auto fusion = module->entry_computation()->root_instruction(); BufferAllocation alloc(/*index=*/0, /*size=*/1024, /*color=*/0); BufferAllocation::Slice slice0(&alloc, 0, 10); + auto adaptor = HloFusionAdaptor::ForInstruction(fusion); EXPECT_THAT(CanEmitFusedDynamicUpdateSliceInPlaceForGpu( - Cast(fusion), + *adaptor, [&slice0](const HloInstruction*, const ShapeIndex&) { return slice0; }, - HloFusionAdaptor::ForInstruction(fusion)->GetRoots()), + fusion), IsOkAndHolds(false)); } @@ -782,8 +784,9 @@ ENTRY main { BufferAllocation alloc(/*index=*/0, /*size=*/1024, /*color=*/0); BufferAllocation::Slice slice0(&alloc, 0, 10); BufferAllocation::Slice slice1(&alloc, 10, 20); + auto adaptor = HloFusionAdaptor::ForInstruction(fusion); EXPECT_THAT(CanEmitFusedDynamicUpdateSliceInPlaceForGpu( - Cast(fusion), + *adaptor, [fusion, &slice0, &slice1](const HloInstruction* instr, const ShapeIndex&) { if (instr == fusion) { @@ -791,7 +794,7 @@ ENTRY main { } return slice1; }, - HloFusionAdaptor::ForInstruction(fusion)->GetRoots()), + fusion), IsOkAndHolds(false)); } @@ -825,12 +828,13 @@ ENTRY main { auto fusion = module->entry_computation()->root_instruction(); BufferAllocation alloc(/*index=*/0, /*size=*/1024, /*color=*/0); BufferAllocation::Slice slice0(&alloc, 0, 10); + auto adaptor = HloFusionAdaptor::ForInstruction(fusion); EXPECT_THAT(CanEmitFusedDynamicUpdateSliceInPlaceForGpu( - Cast(fusion), + *adaptor, [&slice0](const HloInstruction*, const ShapeIndex&) { return slice0; }, - HloFusionAdaptor::ForInstruction(fusion)->GetRoots()), + fusion), IsOkAndHolds(false)); } @@ -868,12 +872,13 @@ ENTRY main { auto fusion = module->entry_computation()->root_instruction(); BufferAllocation alloc(/*index=*/0, /*size=*/1024, /*color=*/0); BufferAllocation::Slice slice0(&alloc, 0, 10); + auto adaptor = HloFusionAdaptor::ForInstruction(fusion); EXPECT_THAT(CanEmitFusedDynamicUpdateSliceInPlaceForGpu( - Cast(fusion), + *adaptor, [&slice0](const HloInstruction*, const ShapeIndex&) { return slice0; }, - HloFusionAdaptor::ForInstruction(fusion)->GetRoots()), + fusion), IsOkAndHolds(true)); } @@ -913,12 +918,13 @@ ENTRY main { auto fusion = module->entry_computation()->root_instruction(); BufferAllocation alloc(/*index=*/0, /*size=*/1024, /*color=*/0); BufferAllocation::Slice slice0(&alloc, 0, 10); + auto adaptor = HloFusionAdaptor::ForInstruction(fusion); EXPECT_THAT(CanEmitFusedDynamicUpdateSliceInPlaceForGpu( - Cast(fusion), + *adaptor, [&slice0](const HloInstruction*, const ShapeIndex&) { return slice0; }, - HloFusionAdaptor::ForInstruction(fusion)->GetRoots()), + fusion), IsOkAndHolds(true)); } @@ -954,12 +960,13 @@ ENTRY main { auto fusion = module->entry_computation()->root_instruction(); BufferAllocation alloc(/*index=*/0, /*size=*/1024, /*color=*/0); BufferAllocation::Slice slice0(&alloc, 0, 10); + auto adaptor = HloFusionAdaptor::ForInstruction(fusion); EXPECT_THAT(CanEmitFusedDynamicUpdateSliceInPlaceForGpu( - Cast(fusion), + *adaptor, [&slice0](const HloInstruction*, const ShapeIndex&) { return slice0; }, - HloFusionAdaptor::ForInstruction(fusion)->GetRoots()), + fusion), IsOkAndHolds(true)); } @@ -995,12 +1002,13 @@ ENTRY main { auto fusion = module->entry_computation()->root_instruction(); BufferAllocation alloc(/*index=*/0, /*size=*/1024, /*color=*/0); BufferAllocation::Slice slice0(&alloc, 0, 10); + auto adaptor = HloFusionAdaptor::ForInstruction(fusion); EXPECT_THAT(CanEmitFusedDynamicUpdateSliceInPlaceForGpu( - Cast(fusion), + *adaptor, [&slice0](const HloInstruction*, const ShapeIndex&) { return slice0; }, - HloFusionAdaptor::ForInstruction(fusion)->GetRoots()), + fusion), IsOkAndHolds(true)); } @@ -1038,12 +1046,13 @@ ENTRY main { auto fusion = module->entry_computation()->root_instruction(); BufferAllocation alloc(/*index=*/0, /*size=*/1024, /*color=*/0); BufferAllocation::Slice slice0(&alloc, 0, 10); + auto adaptor = HloFusionAdaptor::ForInstruction(fusion); EXPECT_THAT(CanEmitFusedDynamicUpdateSliceInPlaceForGpu( - Cast(fusion), + *adaptor, [&slice0](const HloInstruction*, const ShapeIndex&) { return slice0; }, - HloFusionAdaptor::ForInstruction(fusion)->GetRoots()), + fusion), IsOkAndHolds(true)); } From ebbef62e7b71df64db05ab5ee9522b4f633c0366 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 30 Jul 2024 06:27:32 -0700 Subject: [PATCH 257/376] Fix tests depending on Triton that run in CUDA 11.0 PiperOrigin-RevId: 657568631 --- .../triton/temporary/cuda11-temporary.patch | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 third_party/triton/temporary/cuda11-temporary.patch diff --git a/third_party/triton/temporary/cuda11-temporary.patch b/third_party/triton/temporary/cuda11-temporary.patch new file mode 100644 index 00000000000000..2c97b606a0ff7a --- /dev/null +++ b/third_party/triton/temporary/cuda11-temporary.patch @@ -0,0 +1,36 @@ +# This temporary patch has already been included to the public list of Triton +# patches. It is only here temporarily to be included in the openxla version, +# but it will be removed during the next triton integration. + +diff --git a/third_party/nvidia/backend/cuda_utils.cc b/third_party/nvidia/backend/cuda_utils.cc +--- a/third_party/nvidia/backend/cuda_utils.cc ++++ b/third_party/nvidia/backend/cuda_utils.cc +@@ -587,6 +587,8 @@ static PyObject *loadBinary(PyObject *se + typedef CUresult (*cuOccupancyMaxActiveClusters_t)( + int *numClusters, CUfunction func, const CUlaunchConfig *config); + ++#if CUDA_VERSION < 12000 ++#else + typedef CUresult (*cuTensorMapEncodeTiled_t)( + CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType, + cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim, +@@ -594,6 +596,7 @@ typedef CUresult (*cuTensorMapEncodeTile + const cuuint32_t *elementStrides, CUtensorMapInterleave interleave, + CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, + CUtensorMapFloatOOBfill oobFill); ++#endif + + #define defineGetFunctionHandle(name, symbolName) \ + static symbolName##_t name() { \ +@@ -620,8 +623,11 @@ typedef CUresult (*cuTensorMapEncodeTile + defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle, + cuOccupancyMaxActiveClusters); + ++#if CUDA_VERSION < 12000 ++#else + defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle, + cuTensorMapEncodeTiled); ++#endif + + static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) { + int clusterDimX = -1, clusterDimY = -1, clusterDimZ = -1, From 33eacb20a0cdbaf0594a8e6aaeef7ded6d076369 Mon Sep 17 00:00:00 2001 From: Tori Baker Date: Tue, 30 Jul 2024 06:40:20 -0700 Subject: [PATCH 258/376] IndexingMapAttr: print everything on one line in order to be able to match to variables in mlir tests. PiperOrigin-RevId: 657571942 --- .../gpu/fusions/mlir/ir/xla_gpu_attrs.cc | 116 +++++++++----- .../fusions/mlir/tests/indexing_map_attr.mlir | 141 +++++++++--------- 2 files changed, 144 insertions(+), 113 deletions(-) diff --git a/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc b/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc index d3829056de5dc3..d38ed345a71aab 100644 --- a/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc +++ b/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/strings/str_format.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/LogicalResult.h" #include "mlir/IR/AffineExpr.h" @@ -43,8 +44,8 @@ using mlir::AffineExpr; using mlir::ArrayRef; using mlir::AsmParser; using mlir::AsmPrinter; -using mlir::failed; using mlir::failure; +using mlir::success; ParseResult ParseInterval(AsmParser& parser, Interval& interval) { // ParseResult converts to `true` if parsing failed. @@ -54,60 +55,73 @@ ParseResult ParseInterval(AsmParser& parser, Interval& interval) { } void PrintDimVars(AsmPrinter& p, ArrayRef dim_vars) { - for (int i = 0; i < dim_vars.size(); ++i) { - p << "d" << i << " in " << dim_vars[i].bounds << "\n"; - } + int index = 0; + llvm::interleaveComma(dim_vars, p, [&](const DimVar& dim_var) { + p << "d" << index++ << " in " << dim_var.bounds; + }); } -mlir::FailureOr> ParseDimVars( - AsmParser& parser, ArrayRef dim_names) { - SmallVector dim_vars; - for (const auto& dim_name : dim_names) { +ParseResult ParseDimVars(AsmParser& parser, ArrayRef dim_names, + SmallVector& dim_vars) { + dim_vars.reserve(dim_names.size()); + for (const auto& [index, dim_name] : llvm::enumerate(dim_names)) { if (parser.parseKeyword(dim_name) || parser.parseKeyword("in") || ParseInterval(parser, dim_vars.emplace_back().bounds)) { return failure(); } + if (index < dim_names.size() - 1 && parser.parseComma()) { + return failure(); + } } - return dim_vars; + return success(); } void PrintRangeVars(AsmPrinter& p, ArrayRef range_vars) { - for (int i = 0; i < range_vars.size(); ++i) { - p << "s" << i << " in " << range_vars[i].range << "\n"; - } + int index = 0; + llvm::interleaveComma(range_vars, p, [&](const RangeVar& range_var) { + p << "s" << index++ << " in " << range_var.range; + }); } -mlir::FailureOr> ParseRangeVars( - AsmParser& parser, ArrayRef range_symbol_names) { - SmallVector range_vars; - for (const auto& range_symbol_name : range_symbol_names) { +ParseResult ParseRangeVars(AsmParser& parser, + ArrayRef range_symbol_names, + SmallVector& range_vars) { + range_vars.reserve(range_symbol_names.size()); + for (const auto& [index, range_symbol_name] : + llvm::enumerate(range_symbol_names)) { if (parser.parseKeyword(range_symbol_name) || parser.parseKeyword("in") || ParseInterval(parser, range_vars.emplace_back().range)) { return failure(); } + if (index < range_symbol_names.size() - 1 && parser.parseComma()) { + return failure(); + } } - return range_vars; + return success(); } void PrintConstraints(AsmPrinter& p, ArrayRef> constraints) { - for (const auto& [constrained_expression, range] : constraints) { - p << constrained_expression << " in " << range << "\n"; - } + llvm::interleaveComma(constraints, p, [&](const auto& constraint) { + p << constraint.first << " in " << constraint.second; + }); } -mlir::FailureOr>> ParseConstraints( +ParseResult ParseConstraints( AsmParser& parser, - ArrayRef> symbolSet) { - SmallVector> constraints; - while (failed(parser.parseOptionalGreater())) { + ArrayRef> symbolSet, + SmallVector>& constraints) { + // In order for there to be any constraints, there must be at least 1 symbol + // or dimension meaning there will be commas for as long as there are + // constraints left. + while (succeeded(parser.parseOptionalComma())) { auto& constraint = constraints.emplace_back(); if (parser.parseAffineExpr(symbolSet, constraint.first) || parser.parseKeyword("in") || ParseInterval(parser, constraint.second)) { return failure(); } } - return constraints; + return success(); } mlir::Attribute IndexingMapAttr::parse(mlir::AsmParser& parser, mlir::Type) { @@ -131,35 +145,55 @@ mlir::Attribute IndexingMapAttr::parse(mlir::AsmParser& parser, mlir::Type) { symbolSet.push_back( {symbol_strings[i], mlir::getAffineSymbolExpr(i, parser.getContext())}); } - - if (parser.parseKeyword("domain") || parser.parseColon()) { - return {}; + if (map.getNumDims() + map.getNumSymbols() > 0) { + if (parser.parseComma() || parser.parseKeyword("domain") || + parser.parseColon()) { + return {}; + } } - auto maybe_dim_vars = ParseDimVars(parser, dim_strings); - if (failed(maybe_dim_vars)) { - return {}; + + SmallVector dim_vars; + if (map.getNumDims() > 0) { + if (ParseDimVars(parser, dim_strings, dim_vars)) { + return {}; + } } - auto maybe_range_vars = ParseRangeVars(parser, symbol_strings); - if (failed(maybe_range_vars)) { - return {}; + SmallVector range_vars; + if (map.getNumSymbols() > 0) { + if (!dim_vars.empty() && parser.parseComma()) { + return {}; + } + if (ParseRangeVars(parser, symbol_strings, range_vars)) { + return {}; + } } - auto maybe_constraints = ParseConstraints(parser, symbolSet); - if (failed(maybe_constraints)) { + SmallVector> constraints; + if (ParseConstraints(parser, symbolSet, constraints) || + parser.parseGreater()) { return {}; } - // ParseConstraints consumes the > to know when to stop. - return IndexingMapAttr::get(parser.getContext(), map, *maybe_dim_vars, - *maybe_range_vars, *maybe_constraints); + return IndexingMapAttr::get(parser.getContext(), map, dim_vars, range_vars, + constraints); } void IndexingMapAttr::print(mlir::AsmPrinter& printer) const { - printer << "<\n"; + printer << "<"; printer.printStrippedAttrOrType(getMap()); - printer << "\ndomain:\n"; + if (getDimVars().size() + getRangeVars().size() + getConstraints().size() > + 0) { + printer << ", domain: "; + } PrintDimVars(printer, getDimVars()); + if (!getDimVars().empty() && + getRangeVars().size() + getConstraints().size() > 0) { + printer << ", "; + } PrintRangeVars(printer, getRangeVars()); + if (!getRangeVars().empty() && !getConstraints().empty()) { + printer << ", "; + } PrintConstraints(printer, getConstraints()); printer << ">"; } diff --git a/xla/service/gpu/fusions/mlir/tests/indexing_map_attr.mlir b/xla/service/gpu/fusions/mlir/tests/indexing_map_attr.mlir index 3ea853dc8d0d19..c5cdeeb926d168 100644 --- a/xla/service/gpu/fusions/mlir/tests/indexing_map_attr.mlir +++ b/xla/service/gpu/fusions/mlir/tests/indexing_map_attr.mlir @@ -1,22 +1,22 @@ // RUN: mlir_fusions_opt %s -split-input-file | mlir_fusions_opt -split-input-file | FileCheck %s // CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< -// CHECK-NEXT: (d0, d1, d2)[s0] -> (d0) -// CHECK-NEXT: domain: -// CHECK-NEXT: d0 in [1, 2] -// CHECK-NEXT: d1 in [5, 8] -// CHECK-NEXT: d2 in [10, 12] -// CHECK-NEXT: s0 in [0, 32] -// CHECK-NEXT: d0 mod 2 in [0, 1] -// CHECK-NEXT: d0 + s0 in [1, 10] -// CHECK-NEXT: > -#map = #xla_gpu.indexing_map<(d0, d1, d2)[s0] -> (d0) +// CHECK-SAME: (d0, d1, d2)[s0] -> (d0), +// CHECK-SAME: domain: +// CHECK-SAME: d0 in [1, 2], +// CHECK-SAME: d1 in [5, 8], +// CHECK-SAME: d2 in [10, 12], +// CHECK-SAME: s0 in [0, 32], +// CHECK-SAME: d0 mod 2 in [0, 1], +// CHECK-SAME: d0 + s0 in [1, 10] +// CHECK-SAME: > +#map = #xla_gpu.indexing_map<(d0, d1, d2)[s0] -> (d0), domain: - d0 in [1, 2] - d1 in [5, 8] - d2 in [10, 12] - s0 in [0, 32] - d0 mod 2 in [0, 1] + d0 in [1, 2], + d1 in [5, 8], + d2 in [10, 12], + s0 in [0, 32], + d0 mod 2 in [0, 1], d0 + s0 in [1, 10] > @@ -27,26 +27,26 @@ func.func private @indexing_map_attr(tensor<32xf64, #map>) // ----- // CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< -// CHECK-NEXT: (d0, d1)[s0, s1, s2] -> (d0 + s0, d1 + s1, d1 + s2) -// CHECK-NEXT: domain: -// CHECK-NEXT: d0 in [1, 2] -// CHECK-NEXT: d1 in [5, 8] -// CHECK-NEXT: s0 in [0, 10] -// CHECK-NEXT: s1 in [0, 5] -// CHECK-NEXT: s2 in [0, 32] -// CHECK-NEXT: d0 mod 2 in [0, 1] -// CHECK-NEXT: d0 + s0 in [1, 10] -// CHECK-NEXT: d1 + s1 + s2 in [1, 32] -// CHECK-NEXT: > -#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1, s2] -> (d0 + s0, d1 + s1, d1 + s2) +// CHECK-SAME: (d0, d1)[s0, s1, s2] -> (d0 + s0, d1 + s1, d1 + s2) +// CHECK-SAME: domain: +// CHECK-SAME: d0 in [1, 2] +// CHECK-SAME: d1 in [5, 8] +// CHECK-SAME: s0 in [0, 10] +// CHECK-SAME: s1 in [0, 5] +// CHECK-SAME: s2 in [0, 32] +// CHECK-SAME: d0 mod 2 in [0, 1] +// CHECK-SAME: d0 + s0 in [1, 10] +// CHECK-SAME: d1 + s1 + s2 in [1, 32] +// CHECK-SAME: > +#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1, s2] -> (d0 + s0, d1 + s1, d1 + s2), domain: - d0 in [1, 2] - d1 in [5, 8] - s0 in [0, 10] - s1 in [0, 5] - s2 in [0, 32] - d0 mod 2 in [0, 1] - d0 + s0 in [1, 10] + d0 in [1, 2], + d1 in [5, 8], + s0 in [0, 10], + s1 in [0, 5], + s2 in [0, 32], + d0 mod 2 in [0, 1], + d0 + s0 in [1, 10], d1 + s1 + s2 in [1, 32] > func.func private @more_range_vars(tensor<32xf64, #map>) @@ -56,14 +56,14 @@ func.func private @more_range_vars(tensor<32xf64, #map>) // ----- // CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< -// CHECK-NEXT: (d0)[s0] -> (d0) -// CHECK-NEXT: domain: -// CHECK-NEXT: d0 in [0, 100] -// CHECK-NEXT: s0 in [-3, -1] -// CHECK-NEXT: > -#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0) +// CHECK-SAME: (d0)[s0] -> (d0) +// CHECK-SAME: domain: +// CHECK-SAME: d0 in [0, 100] +// CHECK-SAME: s0 in [-3, -1] +// CHECK-SAME: > +#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0), domain: - d0 in [0, 100] + d0 in [0, 100], s0 in [-3, -1] > func.func private @indexing_map_small(tensor<100xf64, #map>) @@ -73,18 +73,18 @@ func.func private @indexing_map_small(tensor<100xf64, #map>) // ----- // CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< -// CHECK-NEXT: (d0, d1, d2)[s0] -> (d0) -// CHECK-NEXT: domain: -// CHECK-NEXT: d0 in [1, 2] -// CHECK-NEXT: d1 in [5, 8] -// CHECK-NEXT: d2 in [10, 12] -// CHECK-NEXT: s0 in [0, 32] -// CHECK-NEXT: > -#map = #xla_gpu.indexing_map<(d0, d1, d2)[s0] -> (d0) +// CHECK-SAME: (d0, d1, d2)[s0] -> (d0) +// CHECK-SAME: domain: +// CHECK-SAME: d0 in [1, 2] +// CHECK-SAME: d1 in [5, 8] +// CHECK-SAME: d2 in [10, 12] +// CHECK-SAME: s0 in [0, 32] +// CHECK-SAME: > +#map = #xla_gpu.indexing_map<(d0, d1, d2)[s0] -> (d0), domain: - d0 in [1, 2] - d1 in [5, 8] - d2 in [10, 12] + d0 in [1, 2], + d1 in [5, 8], + d2 in [10, 12], s0 in [0, 32] > func.func private @no_constraints(tensor<32xf64, #map>) @@ -94,14 +94,14 @@ func.func private @no_constraints(tensor<32xf64, #map>) // ----- // CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< -// CHECK-NEXT: ()[s0] -> (s0) -// CHECK-NEXT: domain: -// CHECK-NEXT: s0 in [3, 5] -// CHECK-NEXT: s0 mod 2 in [0, 1] -// CHECK-NEXT: > -#map = #xla_gpu.indexing_map<()[s0] -> (s0) +// CHECK-SAME: ()[s0] -> (s0) +// CHECK-SAME: domain: +// CHECK-SAME: s0 in [3, 5] +// CHECK-SAME: s0 mod 2 in [0, 1] +// CHECK-SAME: > +#map = #xla_gpu.indexing_map<()[s0] -> (s0), domain: - s0 in [3, 5] + s0 in [3, 5], s0 mod 2 in [0, 1] > func.func private @no_dimensions(tensor<100xf64, #map>) @@ -111,14 +111,14 @@ func.func private @no_dimensions(tensor<100xf64, #map>) // ----- // CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< -// CHECK-NEXT: (d0) -> (d0) -// CHECK-NEXT: domain: -// CHECK-NEXT: d0 in [3, 5] -// CHECK-NEXT: d0 mod 2 in [0, 1] -// CHECK-NEXT: > -#map = #xla_gpu.indexing_map<(d0) -> (d0) +// CHECK-SAME: (d0) -> (d0) +// CHECK-SAME: domain: +// CHECK-SAME: d0 in [3, 5] +// CHECK-SAME: d0 mod 2 in [0, 1] +// CHECK-SAME: > +#map = #xla_gpu.indexing_map<(d0) -> (d0), domain: - d0 in [3, 5] + d0 in [3, 5], d0 mod 2 in [0, 1] > func.func private @no_symbols(tensor<100xf64, #map>) @@ -128,12 +128,9 @@ func.func private @no_symbols(tensor<100xf64, #map>) // ----- // CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< -// CHECK-NEXT: () -> () -// CHECK-NEXT: domain: -// CHECK-NEXT: > -#map = #xla_gpu.indexing_map<() -> () - domain: - > +// CHECK-SAME: () -> () +// CHECK-SAME: > +#map = #xla_gpu.indexing_map<() -> ()> func.func private @empty(tensor<100xf64, #map>) // CHECK-LABEL: @empty // CHECK: tensor<100xf64, #[[$INDEX_MAP]]> \ No newline at end of file From 397c1f2bcda421937c088ccc6be09920200e75aa Mon Sep 17 00:00:00 2001 From: "Ryan M. Lefever" Date: Tue, 30 Jul 2024 07:44:05 -0700 Subject: [PATCH 259/376] I fixed a typo. PiperOrigin-RevId: 657588208 --- xla/service/memory_space_assignment/cost_analysis.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xla/service/memory_space_assignment/cost_analysis.h b/xla/service/memory_space_assignment/cost_analysis.h index 364027c79e760a..72229fcab2d273 100644 --- a/xla/service/memory_space_assignment/cost_analysis.h +++ b/xla/service/memory_space_assignment/cost_analysis.h @@ -89,7 +89,7 @@ class BaseCosts { // The bandwidth of copies to/from alternate memory. virtual float BytesPerSecond() = 0; - // The compute cost of instruction. The compute cost assumes 0 memory transer + // The compute cost of instruction. The compute cost assumes 0 memory transfer // is required. virtual float ComputeSeconds(const HloInstruction& instruction) = 0; From 02fdea873d96c56957750b8bba150f1b30a9d0bc Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Tue, 30 Jul 2024 08:05:24 -0700 Subject: [PATCH 260/376] [XLA:GPU] Fix layout normalization for clamp and select The operations allow for limited broadcasting of scalars. PiperOrigin-RevId: 657594540 --- xla/service/layout_normalization.cc | 32 ++++++++---- xla/service/layout_normalization_test.cc | 66 +++++++++++++++++++++--- 2 files changed, 79 insertions(+), 19 deletions(-) diff --git a/xla/service/layout_normalization.cc b/xla/service/layout_normalization.cc index 2dce620c81b267..16781509e22c60 100644 --- a/xla/service/layout_normalization.cc +++ b/xla/service/layout_normalization.cc @@ -742,21 +742,31 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { Shape s = hlo->shape(); HloOpcode opcode = hlo->opcode(); TF_RET_CHECK(opcode == HloOpcode::kClamp || opcode == HloOpcode::kSelect); - HloInstruction* p = hlo->mutable_operand(0); - HloInstruction* i1 = hlo->mutable_operand(1); - HloInstruction* i2 = hlo->mutable_operand(2); - TF_RET_CHECK(p->shape().layout() == s.layout()); - TF_RET_CHECK(i1->shape().layout() == s.layout()); - TF_RET_CHECK(i2->shape().layout() == s.layout()); + HloInstruction* arg0 = hlo->mutable_operand(0); + HloInstruction* arg1 = hlo->mutable_operand(1); + HloInstruction* arg2 = hlo->mutable_operand(2); + if (opcode == HloOpcode::kClamp) { + TF_RET_CHECK(arg1->shape().layout() == s.layout()); + } else if (opcode == HloOpcode::kSelect) { + TF_RET_CHECK(arg1->shape().layout() == s.layout()); + TF_RET_CHECK(arg2->shape().layout() == s.layout()); + } else { + TF_RET_CHECK(false); + } - TF_ASSIGN_OR_RETURN(HloInstruction * p_0, GetNormalizedInput(p)); - TF_ASSIGN_OR_RETURN(HloInstruction * i1_0, GetNormalizedInput(i1)); - TF_ASSIGN_OR_RETURN(HloInstruction * i2_0, GetNormalizedInput(i2)); + TF_ASSIGN_OR_RETURN(HloInstruction * normalized_arg0, + GetNormalizedInput(arg0)); + TF_ASSIGN_OR_RETURN(HloInstruction * normalized_arg1, + GetNormalizedInput(arg1)); + TF_ASSIGN_OR_RETURN(HloInstruction * normalized_arg2, + GetNormalizedInput(arg2)); TF_ASSIGN_OR_RETURN(Shape new_shape, ShapeInference::InferTernaryOpShape( - opcode, p_0, i1_0, i2_0)); + opcode, normalized_arg0, + normalized_arg1, normalized_arg2)); HloInstruction* normalized = hlo->parent()->AddInstruction( - HloInstruction::CreateTernary(new_shape, opcode, p_0, i1_0, i2_0)); + HloInstruction::CreateTernary(new_shape, opcode, normalized_arg0, + normalized_arg1, normalized_arg2)); hlo->SetupDerivedInstruction(normalized); SetVisited(*normalized); diff --git a/xla/service/layout_normalization_test.cc b/xla/service/layout_normalization_test.cc index d2b9d92d2fb934..88ea4828ec597a 100644 --- a/xla/service/layout_normalization_test.cc +++ b/xla/service/layout_normalization_test.cc @@ -644,10 +644,26 @@ TEST_F(LayoutNormalizationTest, Select) { HloModule module ENTRY main { - p0 = f32[1,17,9,9]{1,3,2,0} parameter(0) - p1 = f32[1,17,9,9]{1,3,2,0} parameter(1) - b = pred[1,17,9,9]{1,3,2,0} parameter(2) - ROOT out = f32[1,17,9,9]{1,3,2,0} select(b, p0, p1), metadata={op_name="test"} + lhs = f32[1,17,9,9]{1,3,2,0} parameter(0) + rhs = f32[1,17,9,9]{1,3,2,0} parameter(1) + p = pred[1,17,9,9]{1,3,2,0} parameter(2) + ROOT out = f32[1,17,9,9]{1,3,2,0} select(p, lhs, rhs), metadata={op_name="test"} +} +)"; + CheckLayoutNormalization(hlo, R"( +// CHECK: f32[1,9,9,17]{3,2,1,0} select({{.*}}, {{.*}}, {{.*}}), metadata={op_name="test"} +)"); +} + +TEST_F(LayoutNormalizationTest, SelectScalarPredicate) { + const char* hlo = R"( +HloModule module + +ENTRY main { + lhs = f32[1,17,9,9]{1,3,2,0} parameter(0) + rhs = f32[1,17,9,9]{1,3,2,0} parameter(1) + p = pred[] parameter(2) + ROOT out = f32[1,17,9,9]{1,3,2,0} select(p, lhs, rhs), metadata={op_name="test"} } )"; CheckLayoutNormalization(hlo, R"( @@ -734,10 +750,44 @@ TEST_F(LayoutNormalizationTest, Clamp) { HloModule m ENTRY main { - p0 = f32[64,1,32]{1,0,2} parameter(0) - p1 = f32[64,1,32]{1,0,2} parameter(1) - p2 = f32[64,1,32]{1,0,2} parameter(2) - ROOT out = f32[64,1,32]{1,0,2} clamp(f32[64,1,32]{1,0,2} p0, f32[64,1,32]{1,0,2} p1, f32[64,1,32]{1,0,2} p2), metadata={op_name="test"} + lb = f32[64,1,32]{1,0,2} parameter(0) + in = f32[64,1,32]{1,0,2} parameter(1) + ub = f32[64,1,32]{1,0,2} parameter(2) + ROOT out = f32[64,1,32]{1,0,2} clamp(f32[64,1,32]{1,0,2} lb, f32[64,1,32]{1,0,2} in, f32[64,1,32]{1,0,2} ub), metadata={op_name="test"} +} +)"; + + CheckLayoutNormalization(hlo, R"( +// CHECK: f32[32,64,1]{2,1,0} clamp({{.*}}, {{.*}}, {{.*}}), metadata={op_name="test"} +)"); +} + +TEST_F(LayoutNormalizationTest, ClampScalarBounds) { + const char* hlo = R"( +HloModule m + +ENTRY main { + lb = f32[] parameter(0) + in = f32[64,1,32]{1,0,2} parameter(1) + ub = f32[] parameter(2) + ROOT out = f32[64,1,32]{1,0,2} clamp(f32[] lb, f32[64,1,32]{1,0,2} in, f32[] ub), metadata={op_name="test"} +} +)"; + + CheckLayoutNormalization(hlo, R"( +// CHECK: f32[32,64,1]{2,1,0} clamp({{.*}}, {{.*}}, {{.*}}), metadata={op_name="test"} +)"); +} + +TEST_F(LayoutNormalizationTest, ClampScalarLb) { + const char* hlo = R"( +HloModule m + +ENTRY main { + lb = f32[] parameter(0) + in = f32[64,1,32]{1,0,2} parameter(1) + ub = f32[64,1,32]{1,0,2} parameter(2) + ROOT out = f32[64,1,32]{1,0,2} clamp(f32[] lb, f32[64,1,32]{1,0,2} in, f32[64,1,32]{1,0,2} ub), metadata={op_name="test"} } )"; From c20a0ad41a4365a1fef26a2ae89c21a7dfc37724 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 30 Jul 2024 09:04:01 -0700 Subject: [PATCH 261/376] Use the logical way of the if statement in preprocessing PiperOrigin-RevId: 657612949 --- third_party/triton/temporary/cuda11-temporary.patch | 11 +++++------ third_party/triton/temporary/series.bzl | 1 + 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/third_party/triton/temporary/cuda11-temporary.patch b/third_party/triton/temporary/cuda11-temporary.patch index 2c97b606a0ff7a..a92166eef6df71 100644 --- a/third_party/triton/temporary/cuda11-temporary.patch +++ b/third_party/triton/temporary/cuda11-temporary.patch @@ -2,10 +2,9 @@ # patches. It is only here temporarily to be included in the openxla version, # but it will be removed during the next triton integration. -diff --git a/third_party/nvidia/backend/cuda_utils.cc b/third_party/nvidia/backend/cuda_utils.cc ---- a/third_party/nvidia/backend/cuda_utils.cc -+++ b/third_party/nvidia/backend/cuda_utils.cc -@@ -587,6 +587,8 @@ static PyObject *loadBinary(PyObject *se +--- a/third_party/nvidia/backend/driver.c ++++ b/third_party/nvidia/backend/driver.c +@@ -154,6 +154,8 @@ static PyObject *loadBinary(PyObject *se typedef CUresult (*cuOccupancyMaxActiveClusters_t)( int *numClusters, CUfunction func, const CUlaunchConfig *config); @@ -14,7 +13,7 @@ diff --git a/third_party/nvidia/backend/cuda_utils.cc b/third_party/nvidia/backe typedef CUresult (*cuTensorMapEncodeTiled_t)( CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType, cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim, -@@ -594,6 +596,7 @@ typedef CUresult (*cuTensorMapEncodeTile +@@ -161,6 +161,7 @@ typedef CUresult (*cuTensorMapEncodeTile const cuuint32_t *elementStrides, CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill); @@ -22,7 +21,7 @@ diff --git a/third_party/nvidia/backend/cuda_utils.cc b/third_party/nvidia/backe #define defineGetFunctionHandle(name, symbolName) \ static symbolName##_t name() { \ -@@ -620,8 +623,11 @@ typedef CUresult (*cuTensorMapEncodeTile +@@ -187,8 +187,11 @@ typedef CUresult (*cuTensorMapEncodeTile defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle, cuOccupancyMaxActiveClusters); diff --git a/third_party/triton/temporary/series.bzl b/third_party/triton/temporary/series.bzl index 4fa55269e3323c..9d26b42a567757 100644 --- a/third_party/triton/temporary/series.bzl +++ b/third_party/triton/temporary/series.bzl @@ -14,5 +14,6 @@ those to this list. """ temporary_patch_list = [ + "//third_party/triton/temporary:cuda11-temporary.patch", # Add new patches just above this line ] From d20fd49a63549fd6ec75b82b7e6f9acfb82857d3 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 30 Jul 2024 09:25:34 -0700 Subject: [PATCH 262/376] [xla:cpu] Optimize + fix a bug in WhileThunk Error checking must be after !IsAvailable check, or otherwise we can miss an error that arrives between two checks. Also replace CHECK with DCHECK on a hot path in ThunkExecutor. PiperOrigin-RevId: 657619625 --- xla/service/cpu/runtime/BUILD | 1 + xla/service/cpu/runtime/thunk_executor.cc | 4 +- .../cpu/runtime/thunk_executor_test.cc | 7 +-- xla/service/cpu/runtime/while_thunk.cc | 50 +++++++++++-------- 4 files changed, 36 insertions(+), 26 deletions(-) diff --git a/xla/service/cpu/runtime/BUILD b/xla/service/cpu/runtime/BUILD index f34570a81b7517..7065dc4a74bdc6 100644 --- a/xla/service/cpu/runtime/BUILD +++ b/xla/service/cpu/runtime/BUILD @@ -923,6 +923,7 @@ cc_library( srcs = ["while_thunk.cc"], hdrs = ["while_thunk.h"], deps = [ + ":buffer_allocations", ":thunk", ":thunk_executor", "//xla/runtime:buffer_use", diff --git a/xla/service/cpu/runtime/thunk_executor.cc b/xla/service/cpu/runtime/thunk_executor.cc index 4281442d5c4305..4173f687c13eb5 100644 --- a/xla/service/cpu/runtime/thunk_executor.cc +++ b/xla/service/cpu/runtime/thunk_executor.cc @@ -349,7 +349,7 @@ void ThunkExecutor::ProcessOutEdges( ExecuteState::Node& out_node = state->nodes[out_edge]; int64_t cnt = out_node.counter.fetch_sub(1, std::memory_order_release); - CHECK_GE(cnt, 1) << "Node counter can't drop below 0"; // Crash Ok + DCHECK_GE(cnt, 1) << "Node counter can't drop below 0"; if (cnt == 1) ready_queue.push_back(out_edge); } @@ -367,7 +367,7 @@ void ThunkExecutor::ProcessOutEdges( if (ABSL_PREDICT_FALSE(state->abort.load(std::memory_order_relaxed))) { auto take_error = [&] { absl::MutexLock lock(&state->abort_mutex); - CHECK(!state->abort_status.ok()) // Crash Ok + DCHECK(!state->abort_status.ok()) << "Abort status must be set if execution is aborted"; return std::move(state->abort_status); }; diff --git a/xla/service/cpu/runtime/thunk_executor_test.cc b/xla/service/cpu/runtime/thunk_executor_test.cc index 697a47d83a4e8d..26ba2e553a2b4d 100644 --- a/xla/service/cpu/runtime/thunk_executor_test.cc +++ b/xla/service/cpu/runtime/thunk_executor_test.cc @@ -15,8 +15,6 @@ limitations under the License. #include "xla/service/cpu/runtime/thunk_executor.h" -#define EIGEN_USE_THREADS - #include #include #include @@ -31,7 +29,6 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" -#include "unsupported/Eigen/CXX11/Tensor" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" #include "xla/service/cpu/runtime/buffer_allocations.h" @@ -48,6 +45,10 @@ limitations under the License. #include "tsl/platform/test_benchmark.h" #include "tsl/platform/threadpool.h" +#define EIGEN_USE_THREADS + +#include "unsupported/Eigen/CXX11/Tensor" + namespace xla::cpu { namespace { diff --git a/xla/service/cpu/runtime/while_thunk.cc b/xla/service/cpu/runtime/while_thunk.cc index 4e326b63a91706..a5aa14a419ac4c 100644 --- a/xla/service/cpu/runtime/while_thunk.cc +++ b/xla/service/cpu/runtime/while_thunk.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" +#include "xla/service/cpu/runtime/buffer_allocations.h" #include "xla/service/cpu/runtime/thunk.h" #include "xla/service/cpu/runtime/thunk_executor.h" #include "xla/stream_executor/device_memory.h" @@ -75,12 +76,6 @@ tsl::AsyncValueRef WhileThunk::ExecuteAsync( return cond_executor_.Execute(params); }); - // Immediately forward error to the caller. - if (ABSL_PREDICT_FALSE(cond_event.IsError())) { - event.SetError(cond_event.GetError()); - return; - } - // If we don't know yet wether we should execute the next iteration or // not, attach `AndThen` continuation to the `cond_event`. if (!cond_event.IsAvailable()) { @@ -89,9 +84,15 @@ tsl::AsyncValueRef WhileThunk::ExecuteAsync( return; } + // Immediately forward error to the caller. + if (ABSL_PREDICT_FALSE(cond_event.IsError())) { + event.SetError(cond_event.GetError()); + return; + } + // At this point `*condition` should have been updated and we may continue // executing the while loop in the current thread. - DCHECK(cond_event.IsAvailable()); + DCHECK(cond_event.IsConcrete()); } // Successfully completed while loop iterations. @@ -111,46 +112,53 @@ tsl::AsyncValueRef WhileThunk::Execute( const ExecuteParams& params) { tsl::profiler::TraceMe trace([&] { return TraceMeEncode(); }); - TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase cond_data, - params.buffer_allocations->GetDeviceAddress(cond_buffer_)); + const BufferAllocations* allocations = params.buffer_allocations; + + se::DeviceMemoryBase cond_data; + if (ShouldCheckBufferSlices()) { + TF_ASSIGN_OR_RETURN(cond_data, allocations->GetDeviceAddress(cond_buffer_)); + } else { + cond_data = allocations->GetDeviceAddressUnchecked(cond_buffer_); + } bool* condition = reinterpret_cast(cond_data.opaque()); // Execute `cond` thunk sequence to initialize the loop condition. auto init_event = cond_executor_.Execute(params); - // Immediately forward error to the caller. - if (ABSL_PREDICT_FALSE(init_event.IsError())) { - return init_event.GetError(); - } - // If we don't know if we should continue or not, switch to async execution // mode using `init_event` as a dependency. if (ABSL_PREDICT_FALSE(!init_event.IsAvailable())) { return ExecuteAsync(params, std::move(init_event), condition); } + // Immediately forward error to the caller. + if (ABSL_PREDICT_FALSE(init_event.IsError())) { + return init_event.GetError(); + } + + DCHECK(init_event.IsConcrete()); + while (*condition) { auto body_event = body_executor_.Execute(params); auto cond_event = body_event.FlatMap([this, ¶ms](ExecuteEvent) { return cond_executor_.Execute(params); }); - // Immediately forward error to the caller. - if (ABSL_PREDICT_FALSE(cond_event.IsError())) { - return cond_event.GetError(); - } - // If we don't know if we should continue or not, switch to async execution // mode using `cond_event` as a dependency. if (ABSL_PREDICT_FALSE(!cond_event.IsAvailable())) { return ExecuteAsync(params, std::move(cond_event), condition); } + // Immediately forward error to the caller. + if (ABSL_PREDICT_FALSE(cond_event.IsError())) { + return cond_event.GetError(); + } + // At this point `*condition` should have been updated and we may continue // executing the while loop in the current thread. - DCHECK(cond_event.IsAvailable()); + DCHECK(cond_event.IsConcrete()); } // Successfully completed while loop iterations. From f362fcdd0de238904cd5861fcfd7dd1866b71893 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Bana=C5=9B?= Date: Tue, 30 Jul 2024 09:44:50 -0700 Subject: [PATCH 263/376] [XLA:CPU] Align getting device ordinal with current runtime. PiperOrigin-RevId: 657626077 --- xla/service/cpu/BUILD | 1 + xla/service/cpu/cpu_executable.cc | 3 +-- xla/service/cpu/cpu_runtime.cc | 22 +++++++++++----------- xla/service/cpu/cpu_runtime.h | 2 ++ xla/service/cpu/cpu_runtime_test.cc | 25 +++++++++++++++++++++++++ 5 files changed, 40 insertions(+), 13 deletions(-) diff --git a/xla/service/cpu/BUILD b/xla/service/cpu/BUILD index cb2bf71ad8b440..5b0b4e2d2cd7a5 100644 --- a/xla/service/cpu/BUILD +++ b/xla/service/cpu/BUILD @@ -1359,6 +1359,7 @@ xla_cc_test( ":runtime_matmul_acl", ":runtime_single_threaded_matmul", "//xla:array2d", + "//xla:executable_run_options", "//xla:types", "//xla/client:local_client", "//xla/service:custom_call_status_internal", diff --git a/xla/service/cpu/cpu_executable.cc b/xla/service/cpu/cpu_executable.cc index a37a40e9d19acd..816ad133bf9a25 100644 --- a/xla/service/cpu/cpu_executable.cc +++ b/xla/service/cpu/cpu_executable.cc @@ -387,8 +387,7 @@ absl::Status CpuExecutable::ExecuteThunks( Thunk::ExecuteParams execute_params = { &*function_registry_, &allocations, - runtime::GetXfeedManager( - run_options->stream()->parent()->device_ordinal()), + runtime::GetXfeedManager(runtime::GetDeviceOrdinal(run_options)), run_options->intra_op_thread_pool(), &task_runner, &collective_execute_params, diff --git a/xla/service/cpu/cpu_runtime.cc b/xla/service/cpu/cpu_runtime.cc index f3ac32c04904bc..4e209e61f283c6 100644 --- a/xla/service/cpu/cpu_runtime.cc +++ b/xla/service/cpu/cpu_runtime.cc @@ -73,6 +73,17 @@ XfeedManager* GetXfeedManager(int device_ordinal) { return it->second; } +// TODO(zhangqiaorjc): Prefer to make callers set and use device_ordinal +// directly since callers may not have a Stream*. +int GetDeviceOrdinal(const xla::ExecutableRunOptions* run_options) { + if (!run_options) { + return 0; + } else if (run_options->device_ordinal() != -1) { + return run_options->device_ordinal(); + } + return run_options->stream()->parent()->device_ordinal(); +} + extern const char* const kEigenMatMulF16SymbolName = "__xla_cpu_runtime_EigenMatMulF16"; extern const char* const kEigenMatMulF32SymbolName = @@ -198,17 +209,6 @@ std::string ShapeString(const void* shape_ptr, int32_t shape_length) { return ""; } -// TODO(zhangqiaorjc): Prefer to make callers set and use device_ordinal -// directly since callers may not have a Stream*. -int GetDeviceOrdinal(const ExecutableRunOptions* run_options) { - if (!run_options) { - return 0; - } else if (run_options->device_ordinal() != -1) { - return run_options->device_ordinal(); - } - return run_options->stream()->parent()->device_ordinal(); -} - ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void* AcquireInfeedBufferForDequeueImpl(const ExecutableRunOptions* run_options, int32_t buffer_length, diff --git a/xla/service/cpu/cpu_runtime.h b/xla/service/cpu/cpu_runtime.h index c40a84caf8aced..92beff43a3c0ea 100644 --- a/xla/service/cpu/cpu_runtime.h +++ b/xla/service/cpu/cpu_runtime.h @@ -103,6 +103,8 @@ extern const char* const kXlaCpuRuntimeSymbolNamePrefix; // `device_ordinal`. Note the device ordinal does not name a CPU XfeedManager* GetXfeedManager(int device_ordinal); +int GetDeviceOrdinal(const xla::ExecutableRunOptions* run_options); + } // namespace runtime } // namespace cpu } // namespace xla diff --git a/xla/service/cpu/cpu_runtime_test.cc b/xla/service/cpu/cpu_runtime_test.cc index 78bbc8f661e311..4e4d6aa5a909b2 100644 --- a/xla/service/cpu/cpu_runtime_test.cc +++ b/xla/service/cpu/cpu_runtime_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "unsupported/Eigen/CXX11/Tensor" #include "xla/array2d.h" #include "xla/client/local_client.h" +#include "xla/executable_run_options.h" #include "xla/service/cpu/runtime_custom_call_status.h" #include "xla/service/cpu/runtime_matmul.h" #include "xla/service/cpu/runtime_matmul_acl.h" @@ -180,5 +181,29 @@ TEST_F(CpuRuntimeTest, FailureStatus) { ASSERT_FALSE(__xla_cpu_runtime_StatusIsSuccess(&success_status)); } +// When run_options is null, the process should not crash and the device ordinal +// should be 0. +TEST_F(CpuRuntimeTest, GetDeviceOrdinalWhenRunOptionsEmpty) { + EXPECT_EQ(cpu::runtime::GetDeviceOrdinal(/*run_options=*/nullptr), 0); +} + +// When the device ordinal is set directly in run options, it should be returned +// (and NOT the value from stream). +TEST_F(CpuRuntimeTest, GetDeviceOrdinalWhenSetInRunOptions) { + // GetDeviceOrdinal implementation bases on the fact that device ordinal is + // -1 by default. So we need to assert for that here to avoid crash in case + // the default value changes in the future. + ExecutableRunOptions run_options; + ASSERT_EQ(run_options.device_ordinal(), -1); + + // Actual test - set device ordinal in run options and check that it is + // returned. + run_options.set_device_ordinal(3); + EXPECT_EQ(cpu::runtime::GetDeviceOrdinal(&run_options), 3); +} + +// TODO(abanas): Add test case for the device ordinal with stream case. It +// requires mocking the stream and stream executor. + } // namespace } // namespace xla From 77aa59bf471e547956e14fe7788434dbc4e091ca Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Tue, 30 Jul 2024 10:06:00 -0700 Subject: [PATCH 264/376] [XLA:GPU] Enable tests. PiperOrigin-RevId: 657633341 --- xla/tests/BUILD | 6 ------ 1 file changed, 6 deletions(-) diff --git a/xla/tests/BUILD b/xla/tests/BUILD index 2c225af4b8dffe..2596ac963e29fc 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -2272,15 +2272,9 @@ xla_test( srcs = ["collective_ops_test.cc"], args = ["--xla_force_host_platform_device_count=4"], backend_tags = { - # This test is tagged "manual" because it requires multiple GPUs, and - # Forge only supports single-GPU tests. Guitar skips "manual" tests - # unless they're also tagged "guitar". "gpu": [ - "guitar", - "manual", "multi_gpu", "no_oss", - "notap", ], "cpu": [ "notsan", From 028e492832022ee779da5ea5f160814f34118cbd Mon Sep 17 00:00:00 2001 From: Chao Date: Tue, 30 Jul 2024 10:09:20 -0700 Subject: [PATCH 265/376] PR #15477: [ROCm] hot fix rocm build due to triton update LDS pass Imported from GitHub PR https://github.com/openxla/xla/pull/15477 This build break is introduced by https://github.com/openxla/xla/pull/15257 and ROcm has a new optimized LDS pass on openai/triton https://github.com/triton-lang/triton/pull/3730 @xla-rotation Copybara import of the project: -- 6f86fdbd090a4fc3fa2346ba6969d7ddeae773e3 by Chao Chen : updated rocm triton OptimizeLDSUsage pass due to https://github.com/triton-lang/triton/pull/3730 Merging this change closes #15477 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/15477 from ROCm:ci_hotfix_20240730 6f86fdbd090a4fc3fa2346ba6969d7ddeae773e3 PiperOrigin-RevId: 657634867 --- xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc b/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc index 2c12aafb9ac536..e31e29b0e0d29c 100644 --- a/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc +++ b/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc @@ -107,7 +107,7 @@ absl::Status CreateTritonPipeline( pm.addPass(mlir::triton::AMD::createDecomposeUnsupportedConversionsPass( ccRocm.gfx_version())); const int custom_lds_size = 0; - pm.addPass(mlir::triton::AMD::createOptimizeLdsUsagePass(ccRocm.gfx_version(), + pm.addPass(mlir::triton::AMD::createOptimizeLDSUsagePass(ccRocm.gfx_version(), custom_lds_size)); pm.addPass(mlir::createConvertSCFToCFPass()); pm.addPass(mlir::createConvertIndexToLLVMPass()); From 4caa88066d8927632f8593dd3c6dcc11ecc704ec Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 30 Jul 2024 10:16:38 -0700 Subject: [PATCH 266/376] Update Triton AMD to use the non-deprecated MCStreamer constructor. PiperOrigin-RevId: 657638017 --- .../triton/llvm_integration/cl657620552.patch | 18 ++++++++++++++++++ third_party/triton/llvm_integration/series.bzl | 1 + 2 files changed, 19 insertions(+) create mode 100644 third_party/triton/llvm_integration/cl657620552.patch diff --git a/third_party/triton/llvm_integration/cl657620552.patch b/third_party/triton/llvm_integration/cl657620552.patch new file mode 100644 index 00000000000000..4a1f47d79e6c92 --- /dev/null +++ b/third_party/triton/llvm_integration/cl657620552.patch @@ -0,0 +1,18 @@ +# Do not upstream this patch. This has been already upstreamed in +# https://github.com/triton-lang/triton/commit/de46a0ede6efe7e93c2a9ebef639e36c6177c511 +# Next integration will include it and this patch should be removed then. + +diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc +--- a/third_party/amd/python/triton_amd.cc ++++ b/third_party/amd/python/triton_amd.cc +@@ -193,9 +193,7 @@ void init_triton_amd(py::module &&m) { + target->createMCAsmBackend(*sti, *mri, mcOptions)); + mcStreamer.reset(target->createMCObjectStreamer( + triple, ctx, std::move(mab), mab->createObjectWriter(svos), +- std::move(ce), *sti, mcOptions.MCRelaxAll, +- mcOptions.MCIncrementalLinkerCompatible, +- /*DWARFMustBeAtTheEnd=*/false)); ++ std::move(ce), *sti)); + + std::unique_ptr parser( + createMCAsmParser(srcMgr, ctx, *mcStreamer, *mai)); diff --git a/third_party/triton/llvm_integration/series.bzl b/third_party/triton/llvm_integration/series.bzl index 656b9c894904d8..5348e66b34413c 100644 --- a/third_party/triton/llvm_integration/series.bzl +++ b/third_party/triton/llvm_integration/series.bzl @@ -8,5 +8,6 @@ LLVM nor MLIR integrator, please do not add any patches to this list. """ llvm_patch_list = [ + "//third_party/triton/llvm_integration:cl657620552.patch", # Add new patches just above this line ] From 03659f0335ffdb1a15e065e1051fb86c63873dad Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 30 Jul 2024 10:21:34 -0700 Subject: [PATCH 267/376] [XLA] Moved existing fuzzy matcher out of flash_attention to eventually use it more widely. Also added a test that it ignore converts PiperOrigin-RevId: 657640691 --- xla/service/BUILD | 23 +++++++ xla/service/fuzzy_matcher.h | 109 ++++++++++++++++++++++++++++++ xla/service/fuzzy_matcher_test.cc | 47 +++++++++++++ 3 files changed, 179 insertions(+) create mode 100644 xla/service/fuzzy_matcher.h create mode 100644 xla/service/fuzzy_matcher_test.cc diff --git a/xla/service/BUILD b/xla/service/BUILD index 6765f1c1ff9ae4..21f3e28807fd74 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -1029,6 +1029,29 @@ xla_cc_test( ], ) +cc_library( + name = "fuzzy_matcher", + hdrs = ["fuzzy_matcher.h"], + deps = [ + ":pattern_matcher", + "//xla/hlo/ir:hlo", + ], +) + +xla_cc_test( + name = "fuzzy_matcher_test", + srcs = ["fuzzy_matcher_test.cc"], + deps = [ + ":fuzzy_matcher", + ":pattern_matcher", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", + ], +) + xla_cc_test( name = "hlo_dfs_reachability_test", srcs = ["hlo_dfs_reachability_test.cc"], diff --git a/xla/service/fuzzy_matcher.h b/xla/service/fuzzy_matcher.h new file mode 100644 index 00000000000000..6e5cd3e09eee5e --- /dev/null +++ b/xla/service/fuzzy_matcher.h @@ -0,0 +1,109 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_FUZZY_MATCHER_H_ +#define XLA_SERVICE_FUZZY_MATCHER_H_ + +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/pattern_matcher.h" + +namespace xla { + +// Fuzzy matchers for HLOs. +namespace fm { + +// TODO(b/355972677): Extend this to support opcodes other than convert +template +auto OptConvert(Pattern pattern) { + auto shared = match::SharedSubpattern(pattern); + return match::AnyOf(match::Convert(shared), shared); +} + +#define XLA_FUZZY_UNOP_PATTERN(NAME) \ + template \ + inline auto NAME(HloInstructionType** matched_inst) { \ + return OptConvert(match::Op(matched_inst).WithOpcode(HloOpcode::k##NAME)); \ + } \ + \ + template \ + inline auto NAME(Arg&& arg) { \ + return OptConvert(match::Op() \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg))); \ + } \ + \ + template \ + inline auto NAME(HloInstructionType** matched_inst, Arg&& arg) { \ + return OptConvert(match::Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg))); \ + } +XLA_FUZZY_UNOP_PATTERN(Tanh) +XLA_FUZZY_UNOP_PATTERN(Exp) +XLA_FUZZY_UNOP_PATTERN(Broadcast) +#undef XLA_FUZZY_UNOP_PATTERN + +#define XLA_FUZZY_BINOP_PATTERN(NAME) \ + template \ + inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) { \ + return OptConvert(match::Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs))); \ + } \ + template \ + inline auto NAME(Lhs&& lhs, Rhs&& rhs) { \ + return OptConvert(match::Op() \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs))); \ + } +XLA_FUZZY_BINOP_PATTERN(Dot) +XLA_FUZZY_BINOP_PATTERN(Divide) +XLA_FUZZY_BINOP_PATTERN(Subtract) +XLA_FUZZY_BINOP_PATTERN(Multiply) +// Currently we only use binary matcher for reduce. +XLA_FUZZY_BINOP_PATTERN(Reduce) +#undef XLA_FUZZY_BINOP_PATTERN + +#define XLA_FUZZY_TERNOP_PATTERN(NAME) \ + template \ + inline auto NAME(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2) { \ + return OptConvert(match::Op() \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg0)) \ + .WithOperand(1, std::forward(arg1)) \ + .WithOperand(2, std::forward(arg2))); \ + } \ + \ + template \ + inline auto NAME(HloInstructionType** matched_inst, Arg0&& arg0, \ + Arg1&& arg1, Arg2&& arg2) { \ + return OptConvert(match::Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg0)) \ + .WithOperand(1, std::forward(arg1)) \ + .WithOperand(2, std::forward(arg2))); \ + } +XLA_FUZZY_TERNOP_PATTERN(Select); +#undef XLA_FUZZY_TERNOP_PATTERN + +} // namespace fm + +} // namespace xla + +#endif // XLA_SERVICE_FUZZY_MATCHER_H_ diff --git a/xla/service/fuzzy_matcher_test.cc b/xla/service/fuzzy_matcher_test.cc new file mode 100644 index 00000000000000..ac97d13233aa52 --- /dev/null +++ b/xla/service/fuzzy_matcher_test.cc @@ -0,0 +1,47 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/fuzzy_matcher.h" + +#include +#include "xla/service/pattern_matcher.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace { + +using FuzzyMatcherTest = HloTestBase; + +TEST_F(FuzzyMatcherTest, IgnoreConvert) { + constexpr char kModuleStr[] = R"( + HloModule test_module + ENTRY test { + x = f16[8,3] parameter(0) + y = f16[8,3] parameter(1) + div = f16[8,3] divide(x, y) + ROOT convert = f32[8,3] convert(div) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, + ParseAndReturnVerifiedModule(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + EXPECT_TRUE( + Match(root, fm::Divide(match::Parameter(0), match::Parameter(1)))); +} + +} // namespace + +} // namespace xla From 9ab1608104eb002c665f3d83f2af293c89df4d1f Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Tue, 30 Jul 2024 10:23:42 -0700 Subject: [PATCH 268/376] Move `status_test_util.h` to XLA PiperOrigin-RevId: 657642030 --- third_party/tsl/tsl/lib/core/BUILD | 3 ++ .../tsl/tsl/lib/core/status_test_util.h | 14 +------ xla/tsl/lib/core/BUILD | 41 +++++++++++++++++++ xla/tsl/lib/core/status_test_util.h | 33 +++++++++++++++ 4 files changed, 78 insertions(+), 13 deletions(-) create mode 100644 xla/tsl/lib/core/BUILD create mode 100644 xla/tsl/lib/core/status_test_util.h diff --git a/third_party/tsl/tsl/lib/core/BUILD b/third_party/tsl/tsl/lib/core/BUILD index a1227c8351de78..fd8be865417735 100644 --- a/third_party/tsl/tsl/lib/core/BUILD +++ b/third_party/tsl/tsl/lib/core/BUILD @@ -37,6 +37,7 @@ filegroup( "bitmap.h", "bits.h", "status_test_util.h", + "@xla//xla/tsl/lib/core:legacy_lib_core_status_test_util_header", ], compatible_with = get_compatible_with_portable(), visibility = internal_visibility([ @@ -68,6 +69,7 @@ filegroup( name = "legacy_lib_core_status_test_util_header", srcs = [ "status_test_util.h", + "@xla//xla/tsl/lib/core:legacy_lib_core_status_test_util_header", ], compatible_with = get_compatible_with_portable(), visibility = internal_visibility([ @@ -95,6 +97,7 @@ cc_library( deps = [ "//tsl/platform:status_matchers", "//tsl/platform:test", + "@xla//xla/tsl/lib/core:status_test_util", ], ) diff --git a/third_party/tsl/tsl/lib/core/status_test_util.h b/third_party/tsl/tsl/lib/core/status_test_util.h index 56644ba71773c4..a15aa79a181ad8 100644 --- a/third_party/tsl/tsl/lib/core/status_test_util.h +++ b/third_party/tsl/tsl/lib/core/status_test_util.h @@ -16,18 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_TSL_LIB_CORE_STATUS_TEST_UTIL_H_ #define TENSORFLOW_TSL_LIB_CORE_STATUS_TEST_UTIL_H_ -#include "tsl/platform/status_matchers.h" -#include "tsl/platform/test.h" - -// Macros for testing the results of functions that return tensorflow::Status. -#define TF_EXPECT_OK(statement) EXPECT_THAT((statement), ::tsl::testing::IsOk()) -#define TF_ASSERT_OK(statement) ASSERT_THAT((statement), ::tsl::testing::IsOk()) - -// There are no EXPECT_NOT_OK/ASSERT_NOT_OK macros since they would not -// provide much value (when they fail, they would just print the OK status -// which conveys no more information than EXPECT_FALSE(status.ok()); -// If you want to check for particular errors, a better alternative is with -// status matchers: -// EXPECT_THAT(s, tensorflow::testing::StatusIs(status.code(), "message")); +#include "xla/tsl/lib/core/status_test_util.h" #endif // TENSORFLOW_TSL_LIB_CORE_STATUS_TEST_UTIL_H_ diff --git a/xla/tsl/lib/core/BUILD b/xla/tsl/lib/core/BUILD new file mode 100644 index 00000000000000..11e199f9889be6 --- /dev/null +++ b/xla/tsl/lib/core/BUILD @@ -0,0 +1,41 @@ +# Description: +# Tensor Standard Libraries. +# +# The libraries in this package are not allowed to have ANY dependencies +# to other TF components outside of TSL. + +load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl:tsl.default.bzl", "get_compatible_with_portable") + +# TODO(rdzhabarov): Tighten visibility after migration is complete. +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], +) + +filegroup( + name = "legacy_lib_core_status_test_util_header", + srcs = [ + "status_test_util.h", + ], + compatible_with = get_compatible_with_portable(), + visibility = internal_visibility([ + "//tensorflow/core:__pkg__", + "@tsl//tsl/lib/core:__pkg__", + "//tensorflow/core/lib/core:__pkg__", + ]), +) + +cc_library( + name = "status_test_util", + testonly = 1, + hdrs = ["status_test_util.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "@tsl//tsl/platform:status_matchers", + "@tsl//tsl/platform:test", + ], +) diff --git a/xla/tsl/lib/core/status_test_util.h b/xla/tsl/lib/core/status_test_util.h new file mode 100644 index 00000000000000..0c8f5d9d50e4ea --- /dev/null +++ b/xla/tsl/lib/core/status_test_util.h @@ -0,0 +1,33 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_TSL_LIB_CORE_STATUS_TEST_UTIL_H_ +#define XLA_TSL_LIB_CORE_STATUS_TEST_UTIL_H_ + +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/test.h" + +// Macros for testing the results of functions that return tensorflow::Status. +#define TF_EXPECT_OK(statement) EXPECT_THAT((statement), ::tsl::testing::IsOk()) +#define TF_ASSERT_OK(statement) ASSERT_THAT((statement), ::tsl::testing::IsOk()) + +// There are no EXPECT_NOT_OK/ASSERT_NOT_OK macros since they would not +// provide much value (when they fail, they would just print the OK status +// which conveys no more information than EXPECT_FALSE(status.ok()); +// If you want to check for particular errors, a better alternative is with +// status matchers: +// EXPECT_THAT(s, tensorflow::testing::StatusIs(status.code(), "message")); + +#endif // XLA_TSL_LIB_CORE_STATUS_TEST_UTIL_H_ From 20c5548748822bf6c3cb2ca0384673f9598efd17 Mon Sep 17 00:00:00 2001 From: akhilgoe <114951738+akhilgoe@users.noreply.github.com> Date: Tue, 30 Jul 2024 10:24:07 -0700 Subject: [PATCH 269/376] PR #15455: [XLA:CPU][oneDNN] Fix typos in oneDNN layer norm test file Imported from GitHub PR https://github.com/openxla/xla/pull/15455 This PR addresses the typos in one of the tests of the ```onednn_layer_norm``` test file that were causing it to fail. Copybara import of the project: -- 6365054363cf11c3cb81c0573ab7c03256162f17 by Akhil Goel : Fix typos in test file Merging this change closes #15455 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/15455 from Intel-tensorflow:akhil/fix_ln_test 6365054363cf11c3cb81c0573ab7c03256162f17 PiperOrigin-RevId: 657642262 --- xla/service/cpu/tests/onednn_layer_norm_test.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/xla/service/cpu/tests/onednn_layer_norm_test.cc b/xla/service/cpu/tests/onednn_layer_norm_test.cc index 9751e207b5e5da..39913542d2d0ee 100644 --- a/xla/service/cpu/tests/onednn_layer_norm_test.cc +++ b/xla/service/cpu/tests/onednn_layer_norm_test.cc @@ -219,7 +219,7 @@ TEST_F(LayerNormTest, LayerNormTest2_F16) { ROOT add_0 = f32[] add(Arg_0, Arg_1) } ENTRY main { - Arg_2= f16[2,4,8] parameter(0), sharding={replicated} + Arg_2 = f16[2,4,8] parameter(0), sharding={replicated} convert_0 = f32[2,4,8] convert(Arg_2) constant_0 = f32[] constant(0) convert_1 = f32[] convert(constant_0) @@ -241,7 +241,7 @@ TEST_F(LayerNormTest, LayerNormTest2_F16) { constant_3 = s32[] constant(8) convert_6 = f32[] convert(constant_3) broadcast_2 = f32[2,4] broadcast(convert_6), dimensions={} - divide_1= f32[2,4] divide(reduce_1, broadcast_2) + divide_1 = f32[2,4] divide(reduce_1, broadcast_2) convert_7 = f16[2,4] convert(divide_1) reshape_2 = f16[2,4,1] reshape(convert_7) rsqrt_0 = f16[2,4,1] rsqrt(reshape_2) @@ -249,13 +249,13 @@ TEST_F(LayerNormTest, LayerNormTest2_F16) { broadcast_3 = f16[2,4,8] broadcast(reshape_3), dimensions={0,1} constant_4 = f16[8] constant({1,1,1,1,1,1,1,1}) broadcast_4 = f16[2,4,8] broadcast(constant_4), dimensions={2} - multiply_1 = f16[2,4,8] multiply(broadcast3, broadcast_4) + multiply_1 = f16[2,4,8] multiply(broadcast_3, broadcast_4) multiply_2 = f16[2,4,8] multiply(multiply_1, Arg_2) constant_5 = f16[8] constant({1,1,1,1,1,1,1,1}) broadcast_5 = f16[2,4,8] broadcast(constant_5), dimensions={2} reshape_4 = f16[2,4] reshape(reshape_0) - broadcast_5 = f16[2,4,8] broadcast(reshape_4), dimensions={0,1} - multiply_3 = f16[2,4,8] multiply(multiply_1, broadcast_5) + broadcast_6 = f16[2,4,8] broadcast(reshape_4), dimensions={0,1} + multiply_3 = f16[2,4,8] multiply(multiply_1, broadcast_6) subtract_1 = f16[2,4,8] subtract(broadcast_5, multiply_3) ROOT add_1 = f16[2,4,8] add(multiply_2, subtract_1) } From 62f4a04a28944b9c4f2f5894b4986e818803ab6c Mon Sep 17 00:00:00 2001 From: sachinmuradi Date: Tue, 30 Jul 2024 10:26:28 -0700 Subject: [PATCH 270/376] PR #15456: [XLA:CPU][]oneDNN]Add numerical correctness test for onednn softmax Imported from GitHub PR https://github.com/openxla/xla/pull/15456 This PR is follow up to https://github.com/openxla/xla/pull/12537#discussion_r1609449939 Request from Benjamin was to separate tests in 3 parts : 1) Just pattern matching test 2) Numerical correctness test ( run onednn$softmax and HLO pattern without fusing and check accuracy) 3) Test to make sure the OneDnnOpsRewriter is run when we call whole CPU compilation pipeline (Need to check with Benjamin regarding this) We already had 1 covered in previously merged softmax [PR](https://github.com/openxla/xla/pull/12537), this PR will address 2. For 3, need some feedback/guidance on how to test the pipeline. Copybara import of the project: -- 7a7bb636aa338e8e05760c18c36d66e7242cec03 by Sachin Muradi : Add numerical correctness tes for onednn softmax Merging this change closes #15456 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/15456 from Intel-tensorflow:onednn-softmax-test 7a7bb636aa338e8e05760c18c36d66e7242cec03 PiperOrigin-RevId: 657643323 --- xla/service/cpu/tests/onednn_softmax_test.cc | 111 +++++++++++++------ 1 file changed, 78 insertions(+), 33 deletions(-) diff --git a/xla/service/cpu/tests/onednn_softmax_test.cc b/xla/service/cpu/tests/onednn_softmax_test.cc index 124b4472024c17..b7e43731cfb673 100644 --- a/xla/service/cpu/tests/onednn_softmax_test.cc +++ b/xla/service/cpu/tests/onednn_softmax_test.cc @@ -52,8 +52,52 @@ class OneDnnSoftmaxTest ; CHECK: custom_call_target="__onednn$softmax" )"; + // Get raw HLO text for generic softmax pattern, after replacing $0 with + // datatype and $1 with batch size. + const std::string GetGenericSoftmaxHLORawText(PrimitiveType data_type, + int batch_size) { + const std::string softmax_hlo_template_string = R"( + HloModule softmax_module + region_max { + Arg_0 = $0[] parameter(0) + Arg_1 = $0[] parameter(1) + ROOT maximum = $0[] maximum(Arg_0, Arg_1) + } + region_add { + Arg_0 = $0[] parameter(0) + Arg_1 = $0[] parameter(1) + ROOT add = $0[] add(Arg_0, Arg_1) + } + ENTRY main { + Arg_0 = $0[$1,128,30522]{2,1,0} parameter(0) + neg_inf = $0[] constant(-inf) + reduce_max = $0[$1,128]{1,0} reduce(Arg_0, neg_inf), dimensions={2}, to_apply=region_max + reshape.0 = $0[$1,128,1]{2,1,0} reshape(reduce_max) + broadcast.0 = $0[$1,128,1]{2,1,0} broadcast(reshape.0), dimensions={0,1,2} + reshape.1 = $0[$1,128]{1,0} reshape(broadcast.0) + broadcast.1 = $0[$1,128,30522]{2,1,0} broadcast(reshape.1), dimensions={0,1} + subtract.0 = $0[$1,128,30522]{2,1,0} subtract(Arg_0, broadcast.1) + exponential = $0[$1,128,30522]{2,1,0} exponential(subtract.0) + const_zero = $0[] constant(0) + reduce_add = $0[$1,128]{1,0} reduce(exponential, const_zero), dimensions={2}, to_apply=region_add + reshape.2 = $0[$1,128,1]{2,1,0} reshape(reduce_add) + broadcast.2 = $0[$1,128,1]{2,1,0} broadcast(reshape.2), dimensions={0,1,2} + reshape.3 = $0[$1,128]{1,0} reshape(broadcast.2) + broadcast.3 = $0[$1,128,30522]{2,1,0} broadcast(reshape.3), dimensions={0,1} + ROOT divide = $0[$1,128,30522]{2,1,0} divide(exponential, broadcast.3) + } + )"; + + const std::string softmax_hlo_string = absl::Substitute( + softmax_hlo_template_string, + primitive_util::LowercasePrimitiveTypeName(data_type), batch_size); + + return softmax_hlo_string; + } + // Test pattern match with OneDnnOpsRewriter pass - void TestSoftmax(std::string input_hlo_string, int expected_softmax_axis) { + void TestSoftmaxPatternMatching(std::string input_hlo_string, + int expected_softmax_axis) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(input_hlo_string)); OneDnnOpsRewriter softmax_rewrite_pass; @@ -74,6 +118,7 @@ class OneDnnSoftmaxTest }; // Softmax test with last dimension as axis. In this case, axis = 2 +// This test is to make sure the pattern matching works as expected TEST_P(OneDnnSoftmaxTest, SoftmaxGenericTest) { PrimitiveType data_type; int batch_size; @@ -82,44 +127,44 @@ TEST_P(OneDnnSoftmaxTest, SoftmaxGenericTest) { GTEST_SKIP() << "CPU does not support " << primitive_util::LowercasePrimitiveTypeName(data_type); } + const std::string softmax_hlo_string = + GetGenericSoftmaxHLORawText(data_type, batch_size); + + TestSoftmaxPatternMatching(softmax_hlo_string, /*expected_softmax_axis*/ 2); +} + +// Generic Softmax test with last dimension as axis. In this case, axis = 2 +// This test to make sure the accuracy is fine with onednn softmax custom call +TEST_P(OneDnnSoftmaxTest, SoftmaxGenericNumericalCorrectnessTest) { + PrimitiveType data_type; + int batch_size; + std::tie(data_type, batch_size) = GetParam(); + if (!IsSupportedType(data_type)) { + GTEST_SKIP() << "CPU does not support " + << primitive_util::LowercasePrimitiveTypeName(data_type); + } - const std::string softmax_hlo_template_string = R"( + const std::string onednn_softmax_hlo_template_string = R"( HloModule softmax_module - region_max { - Arg_0 = $0[] parameter(0) - Arg_1 = $0[] parameter(1) - ROOT maximum = $0[] maximum(Arg_0, Arg_1) - } - region_add { - Arg_0 = $0[] parameter(0) - Arg_1 = $0[] parameter(1) - ROOT add = $0[] add(Arg_0, Arg_1) - } ENTRY main { Arg_0 = $0[$1,128,30522]{2,1,0} parameter(0) - neg_inf = $0[] constant(-inf) - reduce_max = $0[$1,128]{1,0} reduce(Arg_0, neg_inf), dimensions={2}, to_apply=region_max - reshape.0 = $0[$1,128,1]{2,1,0} reshape(reduce_max) - broadcast.0 = $0[$1,128,1]{2,1,0} broadcast(reshape.0), dimensions={0,1,2} - reshape.1 = $0[$1,128]{1,0} reshape(broadcast.0) - broadcast.1 = $0[$1,128,30522]{2,1,0} broadcast(reshape.1), dimensions={0,1} - subtract.0 = $0[$1,128,30522]{2,1,0} subtract(Arg_0, broadcast.1) - exponential = $0[$1,128,30522]{2,1,0} exponential(subtract.0) - const_zero = $0[] constant(0) - reduce_add = $0[$1,128]{1,0} reduce(exponential, const_zero), dimensions={2}, to_apply=region_add - reshape.2 = $0[$1,128,1]{2,1,0} reshape(reduce_add) - broadcast.2 = $0[$1,128,1]{2,1,0} broadcast(reshape.2), dimensions={0,1,2} - reshape.3 = $0[$1,128]{1,0} reshape(broadcast.2) - broadcast.3 = $0[$1,128,30522]{2,1,0} broadcast(reshape.3), dimensions={0,1} - ROOT divide = $0[$1,128,30522]{2,1,0} divide(exponential, broadcast.3) + ROOT custom-call = $0[$1,128,30522]{2,1,0} custom-call(Arg_0), custom_call_target="$2", backend_config={"onednn_softmax_config":{"softmax_axis":2}} } )"; - const std::string softmax_hlo_string = absl::Substitute( - softmax_hlo_template_string, - primitive_util::LowercasePrimitiveTypeName(data_type), batch_size); + auto onednn_softmax_hlo_string = + absl::Substitute(onednn_softmax_hlo_template_string, + primitive_util::LowercasePrimitiveTypeName(data_type), + batch_size, "__onednn$softmax"); + const std::string hlo_string_ref = + GetGenericSoftmaxHLORawText(data_type, batch_size); + + float atol = (data_type == F32) ? 1e-4 : 1e-2; + float rtol = (data_type == F32) ? 1e-4 : 1e-2; - TestSoftmax(softmax_hlo_string, /*expected_softmax_axis*/ 2); + EXPECT_TRUE(RunAndCompareTwoModules(onednn_softmax_hlo_string, hlo_string_ref, + ErrorSpec{atol, rtol}, + /*run_hlo_passes=*/false)); } INSTANTIATE_TEST_SUITE_P(OneDnnSoftmaxTestSuite, OneDnnSoftmaxTest, @@ -163,7 +208,7 @@ TEST_F(OneDnnSoftmaxTest, SoftmaxFP32OnAxisZero) { } )"; - TestSoftmax(softmax_hlo_string, /*expected_softmax_axis*/ 0); + TestSoftmaxPatternMatching(softmax_hlo_string, /*expected_softmax_axis*/ 0); } TEST_F(OneDnnSoftmaxTest, SoftmaxWithBF16ConvertOutputFP32Pattern) { @@ -204,7 +249,7 @@ TEST_F(OneDnnSoftmaxTest, SoftmaxWithBF16ConvertOutputFP32Pattern) { } )"; - TestSoftmax(softmax_hlo_string, /*expected_softmax_axis=*/2); + TestSoftmaxPatternMatching(softmax_hlo_string, /*expected_softmax_axis=*/2); } } // namespace cpu From 2820a54a6ed87304e95bd97a3d6cff4725b5a704 Mon Sep 17 00:00:00 2001 From: Jian Cai Date: Tue, 30 Jul 2024 11:10:25 -0700 Subject: [PATCH 271/376] Add original_value field to HloInstruction Add the field to track a value in an optimized graph to its corresponding value in the unoptimized graph. PiperOrigin-RevId: 657661375 --- xla/hlo/ir/BUILD | 2 + xla/hlo/ir/hlo_instruction.cc | 47 +++++++ xla/hlo/ir/hlo_instruction.h | 8 ++ xla/hlo/ir/hlo_original_value.cc | 68 ++++++++++ xla/hlo/ir/hlo_original_value.h | 37 ++++++ xla/service/hlo.proto | 5 +- xla/service/hlo_parser.cc | 216 +++++++++++++++++++++---------- xla/service/hlo_parser_test.cc | 30 +++++ xla/xla_data.proto | 10 ++ 9 files changed, 353 insertions(+), 70 deletions(-) create mode 100644 xla/hlo/ir/hlo_original_value.cc create mode 100644 xla/hlo/ir/hlo_original_value.h diff --git a/xla/hlo/ir/BUILD b/xla/hlo/ir/BUILD index b396fbee70940d..203223870f6e79 100644 --- a/xla/hlo/ir/BUILD +++ b/xla/hlo/ir/BUILD @@ -33,6 +33,7 @@ cc_library( "hlo_module_metadata.cc", "hlo_op_metadata.cc", "hlo_opcode.cc", + "hlo_original_value.cc", "hlo_schedule.cc", "hlo_sharding.cc", "hlo_sharding_metadata.cc", @@ -54,6 +55,7 @@ cc_library( "hlo_module_metadata.h", "hlo_op_metadata.h", "hlo_opcode.h", + "hlo_original_value.h", "hlo_schedule.h", "hlo_sharding.h", "hlo_sharding_metadata.h", diff --git a/xla/hlo/ir/hlo_instruction.cc b/xla/hlo/ir/hlo_instruction.cc index 6ccbfcdb63a1a5..726113457c2bea 100644 --- a/xla/hlo/ir/hlo_instruction.cc +++ b/xla/hlo/ir/hlo_instruction.cc @@ -60,6 +60,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_op_metadata.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_original_value.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/ir/hlo_sharding_metadata.h" #include "xla/hlo/ir/ptrvec.h" @@ -1225,6 +1226,20 @@ absl::StatusOr> HloInstruction::CreateFromProto( instruction->set_statistics_viz(proto.statistics_viz()); } + if (proto.has_original_value()) { + const xla::OriginalValueProto& original_value_proto = + proto.original_value(); + auto original_value = std::make_shared(shape); + std::cerr << __func__ << ", shape: " << shape.ToString() << "\n"; + + for (const auto& leaf : original_value_proto.leaves()) { + *original_value->mutable_element(ShapeIndex(leaf.leaf_shape_index())) = { + leaf.instruction_name(), ShapeIndex(leaf.shape_index())}; + } + + instruction->set_original_value(original_value); + } + return std::move(instruction); } @@ -3603,6 +3618,12 @@ void HloInstruction::PrintWithCanonicalNameMap( }); PrintExtraAttributes(attr_printer, options); + if (original_value_) { + printer->Append(", original_value={"); + printer->Append(OriginalValueToString(*original_value())); + printer->Append("}"); + } + if (options.print_metadata() && (!metadata_->op_type().empty() || !metadata_->op_name().empty() || !metadata_->source_file().empty() || @@ -3972,6 +3993,23 @@ HloInstructionProto HloInstruction::ToProto() const { *proto.mutable_statistics_viz() = statistics_viz(); + if (original_value_) { + xla::OriginalValueProto* original_value_proto = + proto.mutable_original_value(); + for (const auto& leaf : original_value_->leaves()) { + OriginalArrayProto* original_array_proto = + original_value_proto->add_leaves(); + for (const auto& index : leaf.first) { + original_array_proto->add_leaf_shape_index(index); + } + *original_array_proto->mutable_instruction_name() = + leaf.second->instruction_name; + for (const auto& index : leaf.second->shape_index) { + original_array_proto->add_shape_index(index); + } + } + } + return proto; } @@ -5479,4 +5517,13 @@ void HloInstruction::set_output_to_operand_aliasing( std::move(aliasing)); } +std::shared_ptr HloInstruction::original_value() const { + return original_value_; +} + +void HloInstruction::set_original_value( + std::shared_ptr original_value) { + original_value_ = original_value; +} + } // namespace xla diff --git a/xla/hlo/ir/hlo_instruction.h b/xla/hlo/ir/hlo_instruction.h index 337e9ff534eb84..6d8821c683692f 100644 --- a/xla/hlo/ir/hlo_instruction.h +++ b/xla/hlo/ir/hlo_instruction.h @@ -52,6 +52,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_domain_metadata.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_original_value.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/ir/ptrvec.h" #include "xla/layout.h" @@ -2549,6 +2550,9 @@ class HloInstruction { HloInstruction(const HloInstruction&) = delete; HloInstruction& operator=(const HloInstruction&) = delete; + std::shared_ptr original_value() const; + void set_original_value(std::shared_ptr original_value); + protected: // Internal constructor for a given opcode/shape, other fields must be filled // by factory methods. @@ -2799,6 +2803,10 @@ class HloInstruction { // String identifier for instruction. std::string name_; + // Original value this instruction corresponds to in the unoptimized HLO + // graph. + std::shared_ptr original_value_ = nullptr; + // Metadata for debugging. Allocate it on heap, so that it does not increase // the memory footprint of HloInstruction. std::unique_ptr metadata_ = std::make_unique(); diff --git a/xla/hlo/ir/hlo_original_value.cc b/xla/hlo/ir/hlo_original_value.cc new file mode 100644 index 00000000000000..789978d74cbf39 --- /dev/null +++ b/xla/hlo/ir/hlo_original_value.cc @@ -0,0 +1,68 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/ir/hlo_original_value.h" + +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "xla/shape.h" +#include "xla/shape_util.h" + +namespace xla { + +std::string OriginalValueToStringHelper(const OriginalValue& original_value, + const Shape& shape, + std::vector& shape_index) { + std::string result; + if (shape.IsTuple()) { + if (shape.tuple_shapes().empty()) { + return "()"; + } + absl::StrAppend(&result, "("); + shape_index.push_back(0); + absl::StrAppend(&result, + OriginalValueToStringHelper( + original_value, shape.tuple_shapes(0), shape_index)); + shape_index.pop_back(); + for (int64_t i = 1; i < shape.tuple_shapes().size(); ++i) { + absl::StrAppend(&result, ", "); + shape_index.push_back(i); + absl::StrAppend(&result, + OriginalValueToStringHelper( + original_value, shape.tuple_shapes(i), shape_index)); + shape_index.pop_back(); + } + absl::StrAppend(&result, ")"); + return result; + } + + const auto& leaf = original_value.element(shape_index); + absl::StrAppend( + &result, "{", "\"", leaf->instruction_name, "\"", + (leaf->shape_index.empty() ? "" : " " + leaf->shape_index.ToString()), + "}"); + return result; +} + +std::string OriginalValueToString(const OriginalValue& original_value) { + std::vector shape_index; + return OriginalValueToStringHelper(original_value, original_value.shape(), + shape_index); +} +} // namespace xla diff --git a/xla/hlo/ir/hlo_original_value.h b/xla/hlo/ir/hlo_original_value.h new file mode 100644 index 00000000000000..a77bc8a13460c7 --- /dev/null +++ b/xla/hlo/ir/hlo_original_value.h @@ -0,0 +1,37 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_IR_HLO_ORIGINAL_VALUE_H_ +#define XLA_HLO_IR_HLO_ORIGINAL_VALUE_H_ + +#include +#include + +#include "xla/shape_tree.h" +#include "xla/shape_util.h" + +namespace xla { +// Stores information of original values. +struct OriginalArray { + std::string instruction_name; + ShapeIndex shape_index; +}; + +using OriginalValue = ShapeTree>; + +std::string OriginalValueToString(const OriginalValue& original_value); +} // namespace xla + +#endif // XLA_HLO_IR_HLO_ORIGINAL_VALUE_H_ diff --git a/xla/service/hlo.proto b/xla/service/hlo.proto index 083fd7d2b3fac8..fdeaa68c614166 100644 --- a/xla/service/hlo.proto +++ b/xla/service/hlo.proto @@ -112,7 +112,7 @@ enum CustomCallApiVersion { } // Serialization of HloInstruction. -// Next ID: 88 +// Next ID: 89 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -382,6 +382,9 @@ message HloInstructionProto { // Represents the list of devices that participate in a collective operation. xla.CollectiveDeviceListProto collective_device_list = 87; + + // For HLO value tracking. + xla.OriginalValueProto original_value = 88; } // Serialization of HloComputation. diff --git a/xla/service/hlo_parser.cc b/xla/service/hlo_parser.cc index ab144fa6eb34da..2c9a480983afa5 100644 --- a/xla/service/hlo_parser.cc +++ b/xla/service/hlo_parser.cc @@ -54,6 +54,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_original_value.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/ir/hlo_sharding_metadata.h" @@ -311,6 +312,7 @@ class HloParserImpl : public HloParser { // enclosed in matching curly braces (returned value includes the curlies). kStringOrJsonDict, kCollectiveDeviceList, + kOriginalValue, }; struct AttrConfig { @@ -446,7 +448,7 @@ class HloParserImpl : public HloParser { // bool ParseAttributes( const absl::flat_hash_map& attrs, - bool allow_attributes = true); + bool allow_attributes = true, const std::optional& shape = {}); // sub_attributes ::= '{' (','? attribute)* '}' // @@ -460,7 +462,8 @@ class HloParserImpl : public HloParser { // Do not call this except in ParseAttributes or ParseSubAttributes. bool ParseAttributeHelper( const absl::flat_hash_map& attrs, - absl::flat_hash_set* seen_attrs); + absl::flat_hash_set* seen_attrs, + const std::optional& shape = {}); // Copy attributes from `attrs` to `message`, unless the attribute name is in // `non_proto_attrs`. @@ -564,6 +567,9 @@ class HloParserImpl : public HloParser { bool ParseBool(bool* result); bool ParseToken(TokKind kind, const std::string& msg); bool ParseUnsignedIntegerType(PrimitiveType* primitive_type); + bool ParseOriginalValue( + optional>* original_value, + const Shape& shape); using AliasingData = absl::flat_hash_map; @@ -1371,6 +1377,11 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, optional> predecessors; attrs["control-predecessors"] = {/*required=*/false, AttrTy::kInstructionList, &predecessors}; + + optional> original_value; + attrs["original_value"] = {/*required=*/false, AttrTy::kOriginalValue, + &original_value}; + optional metadata; attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata}; @@ -1440,6 +1451,9 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, if (metadata) { instruction->set_metadata(*metadata); } + if (original_value) { + instruction->set_original_value(*original_value); + } if (backend_config) { instruction->set_raw_backend_config_string(std::move(*backend_config)); } @@ -1492,7 +1506,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT return nullptr; } if (!ParseToken(TokKind::kRparen, "expects ')' after parameter number") || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } std::string param_name(name); @@ -1510,7 +1524,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT "expects '(' before constant literal") || !ParseLiteral(&literal, *shape) || !ParseToken(TokKind::kRparen, "expects ')' after constant literal") || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } return builder->AddInstruction( @@ -1522,7 +1536,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &iota_dimension}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/0)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } return builder->AddInstruction( @@ -1535,7 +1549,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["largest"] = {/*required=*/false, AttrTy::kBool, &largest}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -1582,7 +1596,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT case HloOpcode::kTanh: { if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -1613,7 +1627,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT case HloOpcode::kStochasticConvert: { if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/2)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -1630,7 +1644,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT case HloOpcode::kSelect: { if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/3)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -1646,7 +1660,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT case HloOpcode::kConvert: { if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } return builder->AddInstruction( @@ -1655,7 +1669,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT case HloOpcode::kBitcastConvert: { if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } return builder->AddInstruction( @@ -1678,7 +1692,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["use_global_device_ids"] = {/*required=*/false, AttrTy::kBool, &use_global_device_ids}; if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (opcode == HloOpcode::kAllGather) { @@ -1715,7 +1729,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &dimensions}; } if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (opcode == HloOpcode::kAllReduce) { @@ -1748,7 +1762,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["constrain_layout"] = {/*required=*/false, AttrTy::kBool, &constrain_layout}; if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes) || + !ParseAttributes(attrs, allow_attributes, shape) || (dimensions && dimensions->size() != 1)) { return nullptr; } @@ -1768,7 +1782,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT optional channel_id; attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id}; if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } return builder->AddInstruction(HloInstruction::CreateCollectiveBroadcast( @@ -1785,7 +1799,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["slice_sizes"] = {/*required=*/false, AttrTy::kBracedInt64ListList, &slice_sizes}; if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } std::vector> pairs(source_targets->size()); @@ -1941,7 +1955,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT // Attributes would have already been consumed when constructing the // async wrapped computation for async-start. if (!(async_wrapped_opcode && opcode == HloOpcode::kAsyncStart)) { - if (!ParseAttributes(attrs, allow_attributes)) { + if (!ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } } @@ -1999,7 +2013,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT /*required=*/false, AttrTy::kInt32, &cross_program_prefetch_index}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } return builder->AddInstruction(HloInstruction::CreateCopyStart( @@ -2008,7 +2022,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT case HloOpcode::kReplicaId: { if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/0)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (shape.has_value()) { @@ -2019,7 +2033,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT case HloOpcode::kPartitionId: { if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/0)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (shape.has_value()) { @@ -2030,7 +2044,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT } case HloOpcode::kDynamicReshape: { if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } return builder->AddInstruction(HloInstruction::CreateDynamicReshape( @@ -2043,7 +2057,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &inferred_dimension}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } return builder->AddInstruction(HloInstruction::CreateReshape( @@ -2051,7 +2065,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT } case HloOpcode::kAfterAll: { if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (operands.empty()) { @@ -2062,7 +2076,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT case HloOpcode::kAddDependency: { if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/2)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } return builder->AddInstruction( @@ -2078,7 +2092,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &to_apply}; if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes) || + !ParseAttributes(attrs, allow_attributes, shape) || dimensions->size() != 1) { return nullptr; } @@ -2101,7 +2115,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT !(shape.has_value() ? ParseOperands(&operands, builder, shape->tuple_shapes_size()) : ParseOperands(&operands, builder))) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -2127,7 +2141,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["body"] = {/*required=*/true, AttrTy::kHloComputation, &body}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -2149,7 +2163,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &is_host_transfer}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } // If the is_host_transfer attribute is not present then default to false. @@ -2165,7 +2179,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &is_host_transfer}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } @@ -2187,7 +2201,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &is_host_transfer}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/2)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } return builder->AddInstruction(HloInstruction::CreateSend( @@ -2202,7 +2216,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &is_host_transfer}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } @@ -2220,7 +2234,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -2237,7 +2251,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &to_apply}; if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -2261,7 +2275,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &reduce_computation}; if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!window) { @@ -2305,7 +2319,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &operand_precision}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/2)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!window) { @@ -2346,7 +2360,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &fft_length}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -2383,7 +2397,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["type"] = {/*required=*/false, AttrTy::kComparisonType, &type}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/2)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -2423,7 +2437,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT optional> broadcast_dimensions; attrs["dimensions"] = {/*required=*/!operand_is_scalar, AttrTy::kBracedInt64List, &broadcast_dimensions}; - if (!ParseAttributes(attrs, allow_attributes)) { + if (!ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (operand_is_scalar && !broadcast_dimensions.has_value()) { @@ -2444,7 +2458,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes) || + !ParseAttributes(attrs, allow_attributes, shape) || dimensions->size() != 1) { return nullptr; } @@ -2470,7 +2484,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List, &dimensions}; if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -2496,7 +2510,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions_to_reduce}; if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (operands.size() % 2) { @@ -2531,7 +2545,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &dimensions}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -2552,7 +2566,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/3)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!window) { @@ -2575,7 +2589,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["slice"] = {/*required=*/true, AttrTy::kSliceRanges, &slice_ranges}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } return builder->AddInstruction(HloInstruction::CreateSlice( @@ -2587,7 +2601,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["dynamic_slice_sizes"] = { /*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes}; if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (operands.empty()) { @@ -2606,7 +2620,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT } case HloOpcode::kDynamicUpdateSlice: { if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (operands.size() < 2) { @@ -2628,7 +2642,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &dimensions}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -2648,7 +2662,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &feature_index}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/3)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -2670,7 +2684,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &feature_index}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/5)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -2694,7 +2708,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &feature_index}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/5)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -2715,7 +2729,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["padding"] = {/*required=*/true, AttrTy::kPaddingConfig, &padding}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/2)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -2740,7 +2754,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT AttrTy::kInstructionAliasing, &output_to_operand_aliasing}; if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } auto instr = builder->AddInstruction(HloInstruction::CreateFusion( @@ -2757,7 +2771,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["infeed_config"] = {/*required=*/false, AttrTy::kString, &config}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } // We need to know the infeed data shape to construct the infeed @@ -2781,7 +2795,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &outfeed_shape}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/2)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } HloInstruction* const outfeed_input = operands[0]; @@ -2796,7 +2810,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["distribution"] = {/*required=*/true, AttrTy::kDistribution, &distribution}; if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } return builder->AddInstruction( @@ -2807,7 +2821,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["delta"] = {/*required=*/true, AttrTy::kInt64, &delta}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/0)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } return builder->AddInstruction( @@ -2818,7 +2832,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["algorithm"] = {/*required=*/true, AttrTy::kRandomAlgorithm, &algorithm}; if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } return builder->AddInstruction(HloInstruction::CreateRngBitGenerator( @@ -2833,7 +2847,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &mantissa_bits}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } return builder->AddInstruction(HloInstruction::CreateReducePrecision( @@ -2867,7 +2881,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT AttrTy::kBracedHloComputationList, &branch_computations}; } - if (!ParseAttributes(attrs, allow_attributes)) { + if (!ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (branch_index_is_bool) { @@ -2958,7 +2972,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["api_version"] = {/*required=*/false, AttrTy::kCustomCallApiVersion, &api_version}; if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } @@ -3088,7 +3102,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT LocTy loc = lexer_.GetLoc(); if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } @@ -3166,7 +3180,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/2)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } @@ -3213,7 +3227,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &unique_indices}; if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } @@ -3254,7 +3268,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["domain"] = {/*required=*/true, AttrTy::kDomain, &domain}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -3272,7 +3286,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &dimensions}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -3290,7 +3304,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &dimensions}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/2)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -4626,12 +4640,12 @@ bool HloParserImpl::ParseSubAttributes( // attributes ::= (',' attribute)* bool HloParserImpl::ParseAttributes( const absl::flat_hash_map& attrs, - bool allow_attributes) { + bool allow_attributes, const std::optional& shape) { LocTy loc = lexer_.GetLoc(); absl::flat_hash_set seen_attrs; if (allow_attributes) { while (EatIfPresent(TokKind::kComma)) { - if (!ParseAttributeHelper(attrs, &seen_attrs)) { + if (!ParseAttributeHelper(attrs, &seen_attrs, shape)) { return false; } } @@ -4645,12 +4659,14 @@ bool HloParserImpl::ParseAttributes( attr_it.first)); } } + return true; } bool HloParserImpl::ParseAttributeHelper( const absl::flat_hash_map& attrs, - absl::flat_hash_set* seen_attrs) { + absl::flat_hash_set* seen_attrs, + const std::optional& shape) { LocTy loc = lexer_.GetLoc(); std::string name; if (!ParseAttributeName(&name)) { @@ -4929,6 +4945,17 @@ bool HloParserImpl::ParseAttributeHelper( ->emplace(std::move(result)); return true; } + case AttrTy::kOriginalValue: { + // By the time this attribute is added, the instruciton shape should + // have been inferred. + if (!shape) { + return TokenError("expects instruction shape"); + } + return ParseOriginalValue( + static_cast>*>( + attr_out_ptr), + *shape); + } case AttrTy::kMetadata: { OpMetadata result; if (!ParseMetadata(&result)) { @@ -6225,6 +6252,57 @@ bool HloParserImpl::ParsePaddingConfig(PaddingConfig* padding) { return true; } +// original_value ::= original_value | '{' [shape_index] ',' original_array '}' +// [','] +bool HloParserImpl::ParseOriginalValue( + optional>* original_value, + const Shape& shape) { + VLOG(3) << "ParseOriginalValue"; + + if (!ParseToken(TokKind::kLbrace, "Expects '{'")) { + return false; + } + + *original_value = std::make_shared(shape); + + ShapeIndex leaf_shape_index; + while (lexer_.GetKind() != TokKind::kRbrace) { + if (lexer_.GetKind() == TokKind::kLparen) { + lexer_.Lex(); + leaf_shape_index.push_back(0); + } else if (lexer_.GetKind() == TokKind::kRparen) { + lexer_.Lex(); + leaf_shape_index.pop_back(); + } else if (lexer_.GetKind() == TokKind::kComma) { + lexer_.Lex(); + ++leaf_shape_index.back(); + } else if (lexer_.GetKind() == TokKind::kLbrace) { + lexer_.Lex(); + std::string instruction_name; + ShapeIndex shape_index; + if (!ParseString(&instruction_name)) { + return false; + } + if (lexer_.GetKind() != TokKind::kRbrace) { + if (!ParseShapeIndex(&shape_index)) { + return false; + } + } + *(**original_value)->mutable_element(leaf_shape_index) = { + instruction_name, shape_index}; + if (!ParseToken(TokKind::kRbrace, + "Expects '} at end of each OriginalArray'")) { + return false; + } + } else { + return false; + } + } + + lexer_.Lex(); + return true; +} + // '{' metadata_string '}' bool HloParserImpl::ParseMetadata(OpMetadata* metadata) { absl::flat_hash_map attrs; diff --git a/xla/service/hlo_parser_test.cc b/xla/service/hlo_parser_test.cc index 1f50e26133b9e3..54d24c0436a256 100644 --- a/xla/service/hlo_parser_test.cc +++ b/xla/service/hlo_parser_test.cc @@ -1390,6 +1390,21 @@ ENTRY %test (p: f32[100]) -> u32[100] { ROOT %root = u32[100]{0} bitcast-convert(f32[100]{0} %p), metadata={op_type="a" op_name="b" source_file="c" source_line=1 profile_type={1} deduplicated_name="d" preserve_layout=true} } +)" +}, + +{ +"OriginalValue", +R"(HloModule test, entry_computation_layout={(f32[], f32[3]{0}, f32[2,3]{1,0})->((f32[], f32[3]{0}), f32[2,3]{1,0})} + +ENTRY %test (v1: f32[], v2: f32[3], v3: f32[2,3]) -> ((f32[], f32[3]), f32[2,3]) { + %v1 = f32[] parameter(0), original_value={{"v1"}} + %v2 = f32[3]{0} parameter(1), original_value={{"v2"}} + %tuple = (f32[], f32[3]{0}) tuple(f32[] %v1, f32[3]{0} %v2), original_value={({"v1"}, {"v2"})} + %v3 = f32[2,3]{1,0} parameter(2), original_value={{"v3"}} + ROOT %nested_tuple = ((f32[], f32[3]{0}), f32[2,3]{1,0}) tuple((f32[], f32[3]{0}) %tuple, f32[2,3]{1,0} %v3), original_value={(({"v1"}, {"v2"}), {"v3"})} +} + )" }, }); @@ -5360,5 +5375,20 @@ TEST_F(HloParserTest, ReplicaIdWithLayout) { .empty()); } +TEST_F(HloParserTest, OriginalValueWithoutShape) { + const std::string hlo_string = R"(HloModule test + +ENTRY %test { + %a = f32[2,10]{1,0} parameter(0), original_value={{"a"}} + ROOT %v = abs(%a), original_value={{"v"}} +} + + +)"; + EXPECT_THAT(ParseAndReturnUnverifiedModule(hlo_string).status(), + tsl::testing::StatusIs(tsl::error::INVALID_ARGUMENT, + HasSubstr("expects instruction shape"))); +} + } // namespace } // namespace xla diff --git a/xla/xla_data.proto b/xla/xla_data.proto index 4c7d47b1bf66b9..d3943614332abc 100644 --- a/xla/xla_data.proto +++ b/xla/xla_data.proto @@ -1087,3 +1087,13 @@ message OutputOperandAliasing { int64 operand_index = 2; repeated int64 operand_shape_index = 3; } + +message OriginalArrayProto { + repeated int64 leaf_shape_index = 1; + string instruction_name = 2; + repeated int64 shape_index = 3; +} + +message OriginalValueProto { + repeated OriginalArrayProto leaves = 1; +} From f2c229feb7d645f76303b351a21d7cef04cd7af7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 30 Jul 2024 11:17:29 -0700 Subject: [PATCH 272/376] Generalize the computation of the default replication penalty to allow mesh shapes of any size. PiperOrigin-RevId: 657663933 --- xla/hlo/experimental/auto_sharding/BUILD | 2 ++ .../auto_sharding/auto_sharding_strategy.cc | 3 +-- .../auto_sharding/cluster_environment.h | 14 ++++++++++++++ 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/xla/hlo/experimental/auto_sharding/BUILD b/xla/hlo/experimental/auto_sharding/BUILD index 258bc53fc2d2ca..4ad33554f98adf 100644 --- a/xla/hlo/experimental/auto_sharding/BUILD +++ b/xla/hlo/experimental/auto_sharding/BUILD @@ -251,9 +251,11 @@ cc_library( ":auto_sharding_strategy", ":auto_sharding_util", ":profiling_result", + "//xla:array", "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/service/spmd:spmd_partitioner", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], ) diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc index 6c4ae8251033b9..4c6311a111f467 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc @@ -149,8 +149,7 @@ BuildStrategyAndCost( const std::vector& instructions = sequence.instructions(); // Add penalty for replicated tensors - double replicated_penalty = std::round(cluster_env.AllReduceCost(1, 0) + - cluster_env.AllReduceCost(1, 1)); + double replicated_penalty = cluster_env.GetDefaultReplicatedPenalty(); int64_t max_depth = -1; for (auto iter : depth_map) { diff --git a/xla/hlo/experimental/auto_sharding/cluster_environment.h b/xla/hlo/experimental/auto_sharding/cluster_environment.h index 19736d19e25f0a..a70570209350b2 100644 --- a/xla/hlo/experimental/auto_sharding/cluster_environment.h +++ b/xla/hlo/experimental/auto_sharding/cluster_environment.h @@ -17,16 +17,22 @@ limitations under the License. #define XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_CLUSTER_ENVIRONMENT_H_ #include +#include #include #include #include #include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" +#include "xla/array.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h" #include "xla/hlo/experimental/auto_sharding/profiling_result.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/shape.h" namespace xla { namespace spmd { @@ -121,6 +127,14 @@ class ClusterEnvironment { return tensor_dim_to_mesh_dim; } + double GetDefaultReplicatedPenalty() const { + double replicated_penalty = 0; + for (int i = 0; i < device_mesh_.num_dimensions(); ++i) { + replicated_penalty += AllReduceCost(1, i); + } + return std::round(replicated_penalty); + } + double AllGatherCost(double num_bytes, int mesh_dim) const; double AllReduceCost(double num_bytes, int32_t mesh_dim, From 07023194ce5f60ee56cdf9410dee644e8d3fb5af Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 30 Jul 2024 11:20:43 -0700 Subject: [PATCH 273/376] Simplify code for generating collective matmul strategies a little and remove two now dead functions PiperOrigin-RevId: 657665151 --- .../auto_sharding_dot_handler.cc | 74 +++++-------------- 1 file changed, 20 insertions(+), 54 deletions(-) diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc index 3c62712ab41e27..969fdf30f1bfbb 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc @@ -151,23 +151,6 @@ class HandlerBase { std::optional GetShardingFromUser(const HloSharding& lhs_spec, const HloSharding& rhs_spec); - // Enumerates combinations of the given mesh + tensor dimensions. - void Enumerate(std::function split_func, - size_t num_outer_dims = 2, size_t num_inner_dims = 2, - bool half = false) { - absl::Span mesh_shape = device_mesh_.dimensions(); - for (int64_t dim0 = 0; dim0 < mesh_shape.size(); ++dim0) { - for (int64_t dim1 = 0; dim1 < mesh_shape.size(); ++dim1) { - if (dim0 == dim1) continue; - for (int64_t i = 0; i < num_outer_dims; ++i) { - for (int64_t j = half ? i + 1 : 0; j < num_inner_dims; ++j) { - split_func({{dim0, dim1}, i, j}); - } - } - } - } - } - // Given a set of tensor dims, and a set of mesh dims, enumerates all mappings // where a subset of all tensor dims is mapped to a subset of mesh dims, such // that each tensor dim is mapped to at most mesh dim, and no two tensor dims @@ -198,12 +181,6 @@ class HandlerBase { } } - // Enumerates *half* of the combinations (if inner & outer dims are the same). - void EnumerateHalf(std::function split_func, - size_t num_outer_dims = 2, size_t num_inner_dims = 2) { - Enumerate(split_func, num_outer_dims, num_inner_dims, true); - } - // Sorts strategies in the increasing order of their memory costs. Anecdotal // experience suggests that such a sorted list of strategies works better void SortStrategies(); @@ -743,50 +720,43 @@ void DotHandler::AppendAllGatherWindowedEinsumStrategyForOperand( const Array& device_mesh, double compute_cost) { const HloInstruction* operand = ins_->operand(operand_num); const DimMap& operand_dim_map = operand_num == 0 ? lhs_dim_map : rhs_dim_map; - absl::flat_hash_set sharded_tensor_dims; absl::flat_hash_set used_mesh_dims; for (const auto [tensor_dim, mesh_dim] : operand_dim_map) { - if (device_mesh.dim(mesh_dim) == 1) { - continue; - } - sharded_tensor_dims.insert(tensor_dim); used_mesh_dims.insert(mesh_dim); } if (used_mesh_dims.size() == device_mesh_.num_dimensions() || - sharded_tensor_dims.size() == operand->shape().rank()) { + used_mesh_dims.size() == operand->shape().rank()) { return; } for (int64_t tensor_dim = 0; tensor_dim < operand->shape().rank(); ++tensor_dim) { - if (sharded_tensor_dims.contains(tensor_dim)) { + if (auto it = operand_dim_map.find(tensor_dim); + it != operand_dim_map.end() && device_mesh.dim(it->second) > 1) { continue; } for (int64_t mesh_dim = 0; mesh_dim < device_mesh_.num_dimensions(); ++mesh_dim) { - if (used_mesh_dims.contains(mesh_dim) || - (device_mesh.dim(mesh_dim) == 1)) { + if (used_mesh_dims.contains(mesh_dim)) { continue; } DimMap further_sharded_dim_map = operand_dim_map; further_sharded_dim_map[tensor_dim] = mesh_dim; - auto updated_communication_cost_fn = + auto communication_cost_fn = [](const HloSharding& output_sharding) -> double { // TODO(331684721): Model costs for windowed einsum return 100.0; }; - std::string updated_name = - absl::StrCat(absl::StrFormat("WindowedEinsum @ {%d,%d,%d}", - operand_num, tensor_dim, mesh_dim), - name); + std::string updated_name = absl::StrCat( + name, absl::StrFormat("|ag_windowed_einsum_o%dt%dm%d", operand_num, + tensor_dim, mesh_dim)); MaybeAppendInternal( updated_name, operand_num == 0 ? further_sharded_dim_map : lhs_dim_map, operand_num == 1 ? further_sharded_dim_map : rhs_dim_map, - output_dim_map, device_mesh, compute_cost, - updated_communication_cost_fn); + output_dim_map, device_mesh, compute_cost, communication_cost_fn); } } } @@ -795,46 +765,42 @@ void DotHandler::AppendReduceScatterWindowedEinsumStrategy( const std::string& name, const DimMap& lhs_dim_map, const DimMap& rhs_dim_map, const DimMap& output_dim_map, const Array& device_mesh, double compute_cost) { - absl::flat_hash_set sharded_tensor_dims; absl::flat_hash_set used_mesh_dims; for (const auto [tensor_dim, mesh_dim] : output_dim_map) { - if (device_mesh.dim(mesh_dim) == 1) { - continue; - } - sharded_tensor_dims.insert(tensor_dim); used_mesh_dims.insert(mesh_dim); } + if (used_mesh_dims.size() == device_mesh_.num_dimensions() || - sharded_tensor_dims.size() == ins_->shape().rank()) { + used_mesh_dims.size() == ins_->shape().rank()) { return; } for (int64_t tensor_dim = 0; tensor_dim < ins_->shape().rank(); ++tensor_dim) { - if (sharded_tensor_dims.contains(tensor_dim)) { + if (auto it = output_dim_map.find(tensor_dim); + it != output_dim_map.end() && device_mesh.dim(it->second) > 1) { continue; } for (int64_t mesh_dim = 0; mesh_dim < device_mesh_.num_dimensions(); ++mesh_dim) { - if (used_mesh_dims.contains(mesh_dim) || - (device_mesh.dim(mesh_dim) == 1)) { + if (used_mesh_dims.contains(mesh_dim)) { continue; } DimMap further_sharded_dim_map = output_dim_map; further_sharded_dim_map[tensor_dim] = mesh_dim; - auto updated_communication_cost_fn = + auto communication_cost_fn = [](const HloSharding& output_sharding) -> double { // TODO(331684721): Model costs for windowed einsum return 100.0; }; std::string updated_name = absl::StrCat( - absl::StrFormat("WindowedEinsum @ {%d,%d}", tensor_dim, mesh_dim), - name); + name, + absl::StrFormat("|rs_windowed_einsum_t%dm%d", tensor_dim, mesh_dim)); MaybeAppendInternal(updated_name, lhs_dim_map, rhs_dim_map, further_sharded_dim_map, device_mesh, compute_cost, - updated_communication_cost_fn); + communication_cost_fn); } } } @@ -1014,7 +980,7 @@ void ConvHandler::SplitDepthwise(bool forward) { rhs_dim_map[rhs_out_channel_dim_] = out_out_channel_mesh_dim; MaybeAppend(absl::StrCat("b", out_batch_mesh_dim, "oc", - out_out_channel_mesh_dim, "@depthwise"), + out_out_channel_mesh_dim, "|depthwise"), lhs_dim_map, rhs_dim_map, output_dim_map, device_mesh_); }; absl::flat_hash_set all_mesh_dims; @@ -1062,7 +1028,7 @@ absl::Status HandleConv(std::unique_ptr& strategy_group, strategy_group = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, strategy_groups); - auto conv_as_dot_dims = + const dot_as_convolution_util::DotConvolutionDimsInfo& conv_as_dot_dims = dot_as_convolution_util::ParseConvolutionDimsInfo(ins); if (conv_as_dot_dims.conv_spatial_dims.empty()) { DotHandler handler( From 39ac14eb812889131577a5d08ef3989a2bacab25 Mon Sep 17 00:00:00 2001 From: Gregory Pataky Date: Tue, 30 Jul 2024 11:32:25 -0700 Subject: [PATCH 274/376] Cleanup for exhaustive_unary_f32_or_smaller_test Changes: - `std::is_same_v` instead of `std::is_same()` - Fixed all lints about missing/unused includes. - Fixed fn name style. - Split ReciprocalTpuAbsErr into one fn for TPUs and one for CPU/GPU. PiperOrigin-RevId: 657669963 --- xla/tests/exhaustive/BUILD | 1 + .../exhaustive/exhaustive_op_test_utils.cc | 1 + .../exhaustive_unary_f32_or_smaller_test.cc | 94 +++++++++++-------- 3 files changed, 57 insertions(+), 39 deletions(-) diff --git a/xla/tests/exhaustive/BUILD b/xla/tests/exhaustive/BUILD index 7dea3db1d7ee5e..89220e327497c1 100644 --- a/xla/tests/exhaustive/BUILD +++ b/xla/tests/exhaustive/BUILD @@ -49,6 +49,7 @@ cc_library( "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test", ], ) diff --git a/xla/tests/exhaustive/exhaustive_op_test_utils.cc b/xla/tests/exhaustive/exhaustive_op_test_utils.cc index 991c550e8b925f..17964eccfa32fd 100644 --- a/xla/tests/exhaustive/exhaustive_op_test_utils.cc +++ b/xla/tests/exhaustive/exhaustive_op_test_utils.cc @@ -35,6 +35,7 @@ limitations under the License. #include "Eigen/Core" #include "xla/literal.h" #include "xla/types.h" +#include "tsl/platform/test.h" namespace xla { namespace exhaustive_op_test { diff --git a/xla/tests/exhaustive/exhaustive_unary_f32_or_smaller_test.cc b/xla/tests/exhaustive/exhaustive_unary_f32_or_smaller_test.cc index 3e175ea054139c..70511ecbb46363 100644 --- a/xla/tests/exhaustive/exhaustive_unary_f32_or_smaller_test.cc +++ b/xla/tests/exhaustive/exhaustive_unary_f32_or_smaller_test.cc @@ -23,7 +23,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -635,76 +634,93 @@ UNARY_TEST_FLOAT_32_BITS_OR_LESS(RoundNearestEven, { fesetround(curr_direction); }) +// Can be thought of as an absolute error of `<= +// |std::numeric_limits::min()|`. template -double reciprocal_abs_error(NativeT val) { - double abs_err = 0.0; +double ReciprocalCpuGpuAbsError(NativeT val) { + float output = 1.0f / static_cast(val); - // For subnormals, we need to set absolute error to the smallest positive - // representable value due to hardware implementations that truncate - // subnormals to zero. - bool is_subnormal_output = - std::numeric_limits::denorm_min() <= std::abs(1 / val) && - std::abs(1 / val) <= std::numeric_limits::min(); - if (is_subnormal_output) { - abs_err = std::numeric_limits::min(); + if (IsSubnormal(output)) { + return std::numeric_limits::min(); } - return abs_err; + return 0.0; +} + +// Can be thought of as an absolute error of `<= +// |std::numeric_limits::min()|`. +template +double ReciprocalTpuAbsError(NativeT val) { + float output = 1.0f / static_cast(val); + + // TPU seems to flush subnormals or minimum normal to 0. We set the error to + // the minimum normal in these cases. + if (IsSubnormalOrMinNormal(output)) { + return std::numeric_limits::min(); + } + + return 0.0; } UNARY_TEST_FLOAT_32_BITS_OR_LESS(Reciprocal, { ErrorSpecGen error_spec_gen = - +[](NativeT) { return ErrorSpec{.abs_err = 0.0, .rel_err = 0.0}; }; + +[](NativeT) { return ErrorSpec{.strict_signed_zeros = true}; }; if (IsCpu(platform_)) { error_spec_gen = +[](NativeT val) { - return ErrorSpec{.abs_err = reciprocal_abs_error(val), .rel_err = 0.0}; + return ErrorSpec{.abs_err = ReciprocalCpuGpuAbsError(val), + .strict_signed_zeros = true}; }; } if (IsGpu(platform_)) { error_spec_gen = +[](NativeT val) { NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec{.abs_err = reciprocal_abs_error(val), .rel_err = eps}; + return ErrorSpec{.abs_err = ReciprocalCpuGpuAbsError(val), + .rel_err = eps, + .strict_signed_zeros = true}; }; } if (IsTpu(platform_)) { error_spec_gen = +[](NativeT val) { - auto abs_err = reciprocal_abs_error(val); - if constexpr (std::is_same()) { - return ErrorSpec{.abs_err = abs_err, .rel_err = 0.0}; - } else if constexpr (std::is_same()) { - // N.B.: Does not require absolute error. - return ErrorSpec{.abs_err = 0.0, .rel_err = 0.0}; - } else if constexpr (std::is_same()) { + if constexpr (std::is_same_v) { + return ErrorSpec{.abs_err = ReciprocalTpuAbsError(val), + .strict_signed_zeros = true}; + } else if constexpr (std::is_same_v) { + return ErrorSpec{.strict_signed_zeros = true}; + } else if constexpr (std::is_same_v) { NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec{.abs_err = abs_err, .rel_err = eps}; + return ErrorSpec{.abs_err = ReciprocalTpuAbsError(val), + .rel_err = eps, + .strict_signed_zeros = true}; } }; } if (IsPreV6Tpu(platform_)) { error_spec_gen = +[](NativeT val) { - auto abs_err = reciprocal_abs_error(val); - if constexpr (std::is_same()) { - return ErrorSpec{.abs_err = abs_err, .rel_err = 0.0}; - } else if constexpr (std::is_same()) { - // N.B.: Does not require absolute error. - return ErrorSpec{.abs_err = 0.0, .rel_err = 0.0}; - } else if constexpr (std::is_same()) { + if constexpr (std::is_same_v) { + return ErrorSpec{.abs_err = ReciprocalTpuAbsError(val), + .strict_signed_zeros = true}; + } else if constexpr (std::is_same_v) { + return ErrorSpec{.strict_signed_zeros = true}; + } else if constexpr (std::is_same_v) { NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec{.abs_err = abs_err, .rel_err = 34 * eps}; + return ErrorSpec{.abs_err = ReciprocalTpuAbsError(val), + .rel_err = 34 * eps, + .strict_signed_zeros = true}; } }; } if (IsPreV5Tpu(platform_)) { error_spec_gen = +[](NativeT val) { - auto abs_err = reciprocal_abs_error(val); - if constexpr (std::is_same()) { - return ErrorSpec{.abs_err = abs_err, .rel_err = 0.0}; - } else if constexpr (std::is_same()) { - // N.B.: Does not require absolute error. - return ErrorSpec{.abs_err = 0.0, .rel_err = 0.0}; - } else if constexpr (std::is_same()) { + if constexpr (std::is_same_v) { + return ErrorSpec{.abs_err = ReciprocalTpuAbsError(val), + .strict_signed_zeros = true}; + } else if constexpr (std::is_same_v) { + return ErrorSpec{.strict_signed_zeros = true}; + } else if constexpr (std::is_same_v) { NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec{.abs_err = abs_err, .rel_err = 136 * eps}; + return ErrorSpec{.abs_err = ReciprocalTpuAbsError(val), + .rel_err = 136 * eps, + .strict_signed_zeros = true}; } }; } From 99353d88f96cb378562e396c9829ceb3032ff848 Mon Sep 17 00:00:00 2001 From: Olli Lupton Date: Tue, 30 Jul 2024 11:39:53 -0700 Subject: [PATCH 275/376] PR #15216: Make GpuExecutor::HostMemoryAllocate NUMA aware Imported from GitHub PR https://github.com/openxla/xla/pull/15216 This improves the achieved throughput of D2H transfers, for example when checkpointing. For example, there is a ~2x improvement in throughput of overlapped D2H copies from 8xH100 on a DGX node. Notes: - `TENSORFLOW_USE_NUMA` is set unconditionally instead of being hidden behind an option; it's not clear from OSS-world if this is an important handle for Google internally. - `stream_executor::StreamExecutor::HostMemoryDeallocate` now takes the allocation size; all call sites updated. This is required by the `tsl::port::NUMAFree` API. Copybara import of the project: -- c8bc494a46a5bc192689c0754428c83d3d951bf3 by Olli Lupton : stream_executor::StreamExecutor::HostMemoryDeallocate: pass size -- f50c9acce27aae4931c41ea1a3c1e9fb866c2d14 by Olli Lupton : GpuExecutor::HostMemory[De]Allocate: NUMA-aware In the CUDA executor allocate host memory that is close to the device. -- b8d9927e16ddb249fe35e5ef19e48464726c23a5 by Olli Lupton : bazel: enable numa-aware by default (FIXME?) -- d07182fc45647ad50f3480587ea0a70f8895e423 by Olli Lupton : GpuExecutor::HostMemory[De]Allocate: improve error handling -- 42f930bdc807a5947cf928afb15a44dea9b81f3f by Olli Lupton : Add unit test for NUMA-aware allocation -- dc6c68e1252fb98cb4d8b4be4dc3833a6d78494c by Olli Lupton : workaround failure on platforms that cannot detect numa domains Merging this change closes #15216 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/15216 from olupton:numa-pinning dc6c68e1252fb98cb4d8b4be4dc3833a6d78494c PiperOrigin-RevId: 657672722 --- xla/backends/interpreter/executor.h | 2 +- xla/stream_executor/cuda/BUILD | 1 + xla/stream_executor/cuda/cuda_executor.cc | 103 ++++++++++++++---- xla/stream_executor/gpu/BUILD | 2 + xla/stream_executor/gpu/gpu_executor.h | 24 ++-- xla/stream_executor/gpu/gpu_executor_test.cc | 29 +++++ xla/stream_executor/host/host_executor.h | 2 +- xla/stream_executor/host_memory_allocation.cc | 2 +- .../integrations/device_mem_allocator.h | 2 +- xla/stream_executor/mock_stream_executor.h | 3 +- xla/stream_executor/rocm/rocm_executor.cc | 14 +++ xla/stream_executor/stream_executor.h | 2 +- xla/stream_executor/tpu/tpu_executor.h | 2 +- 13 files changed, 144 insertions(+), 44 deletions(-) diff --git a/xla/backends/interpreter/executor.h b/xla/backends/interpreter/executor.h index c653fc7317b595..822537638b09cb 100644 --- a/xla/backends/interpreter/executor.h +++ b/xla/backends/interpreter/executor.h @@ -103,7 +103,7 @@ class XlaInterpreterExecutor : public StreamExecutorCommon { uint64_t size) override { return std::make_unique(new char[size], size, this); } - void HostMemoryDeallocate(void *mem) override { + void HostMemoryDeallocate(void *mem, uint64_t size) override { delete[] static_cast(mem); } diff --git a/xla/stream_executor/cuda/BUILD b/xla/stream_executor/cuda/BUILD index 48592a1a92656b..ea7dc5431f865f 100644 --- a/xla/stream_executor/cuda/BUILD +++ b/xla/stream_executor/cuda/BUILD @@ -824,6 +824,7 @@ cuda_only_cc_library( "@tsl//tsl/platform:errors", "@tsl//tsl/platform:fingerprint", "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:platform_port", "@tsl//tsl/platform:statusor", ] + if_cuda_is_configured([":delay_kernel_cuda"]), alwayslink = True, diff --git a/xla/stream_executor/cuda/cuda_executor.cc b/xla/stream_executor/cuda/cuda_executor.cc index 7f478df047be84..1b7df257874205 100644 --- a/xla/stream_executor/cuda/cuda_executor.cc +++ b/xla/stream_executor/cuda/cuda_executor.cc @@ -86,6 +86,7 @@ limitations under the License. #include "tsl/platform/errors.h" #include "tsl/platform/fingerprint.h" #include "tsl/platform/logging.h" +#include "tsl/platform/numa.h" #include "tsl/platform/statusor.h" // LOG(ERROR) uses a const named ERROR, so a macro with the same name is @@ -153,6 +154,9 @@ GpuExecutor::~GpuExecutor() { } } +static std::optional TryToReadNumaNode(const std::string& pci_bus_id, + int device_ordinal); + absl::Status GpuExecutor::Init() { TF_RETURN_IF_ERROR(GpuDriver::Init()); TF_RETURN_IF_ERROR(GpuDriver::GetDevice(device_ordinal_, &device_)); @@ -160,6 +164,17 @@ absl::Status GpuExecutor::Init() { GpuDriver::CreateContext(device_ordinal_, device_, &context_)); TF_RETURN_IF_ERROR( GpuDriver::GetComputeCapability(&cc_major_, &cc_minor_, device_)); + std::optional numa_node = TryToReadNumaNode( + absl::AsciiStrToLower(GpuDriver::GetPCIBusID(device_ordinal_)), + device_ordinal_); + if (!numa_node || *numa_node < 0) { + LOG(WARNING) << "NUMA node could not be determined for device " + << device_ordinal_ + << ", host memory allocations will not be NUMA-pinned"; + numa_node_ = tsl::port::kNUMANoAffinity; + } else { + numa_node_ = *numa_node; + } return absl::OkStatus(); } @@ -586,6 +601,47 @@ void GpuExecutor::Deallocate(DeviceMemoryBase* mem) { GpuDriver::DeviceDeallocate(context_, mem->opaque()); } +// CUDA allocation/registration functions are necessary because the driver +// internally sets up buffers for DMA operations (and page locks them). There's +// no external interface for us to otherwise control these DMA settings. +absl::StatusOr> +GpuExecutor::HostMemoryAllocate(uint64_t size) { + if (numa_node_ != tsl::port::kNUMANoAffinity) { + auto* buffer = + tsl::port::NUMAMalloc(numa_node_, size, /* minimum_alignment=*/16); + if (buffer == nullptr && size > 0) { + return absl::InternalError(absl::StrFormat( + "Failed to allocate host memory of size %d pinned to NUMA node %d", + size, numa_node_)); + } + if (size > 0 && !GpuDriver::HostRegister(context_, buffer, size)) { + return absl::InternalError( + absl::StrFormat("Failed to register host memory of size %d pinned to " + "NUMA node %d with the GPU driver", + size, numa_node_)); + } + return std::make_unique(buffer, size, this); + } else { + auto* buffer = GpuDriver::HostAllocate(context_, size); + if (buffer == nullptr && size > 0) { + return absl::InternalError( + absl::StrFormat("Failed to allocate HostMemory of size %d", size)); + } + return std::make_unique(buffer, size, this); + } +} + +void GpuExecutor::HostMemoryDeallocate(void* location, uint64_t size) { + if (numa_node_ != tsl::port::kNUMANoAffinity) { + if (size > 0) { + GpuDriver::HostUnregister(context_, location); + } + tsl::port::NUMAFree(location, size); + } else { + GpuDriver::HostDeallocate(context_, location); + } +} + bool GpuExecutor::SynchronizeAllActivity() { return GpuDriver::SynchronizeContext(context_); } @@ -810,22 +866,22 @@ std::unique_ptr GpuExecutor::CreateCommandBuffer( GpuContext* GpuExecutor::gpu_context() { return context_; } // Attempts to read the NUMA node corresponding to the GPU device's PCI bus out -// of SysFS. Returns -1 if it cannot. +// of SysFS. // // For anything more complicated/prod-focused than this, you'll likely want to -// turn to gsys' topology modeling. -static int TryToReadNumaNode(const std::string& pci_bus_id, - int device_ordinal) { +// turn to gsys' topology modeling. nvmlDeviceGetMemoryAffinity could also be +// used. +static std::optional TryToReadNumaNode(const std::string& pci_bus_id, + int device_ordinal) { #if defined(PLATFORM_WINDOWS) // Windows support for NUMA is not currently implemented. Return node 0. return 0; #else VLOG(2) << "trying to read NUMA node for device ordinal: " << device_ordinal; - static const int kUnknownNumaNode = -1; if (pci_bus_id.empty()) { LOG(INFO) << "no PCI bus ID for device ordinal: " << device_ordinal; - return kUnknownNumaNode; + return std::nullopt; } std::string filename = @@ -838,7 +894,7 @@ static int TryToReadNumaNode(const std::string& pci_bus_id, if (file == nullptr) { LOG(INFO) << "could not open file to read NUMA node: " << filename << "\nYour kernel may have been built without NUMA support."; - return kUnknownNumaNode; + return std::nullopt; } std::string content; @@ -849,17 +905,6 @@ static int TryToReadNumaNode(const std::string& pci_bus_id, int32_t value; if (absl::SimpleAtoi(content, &value)) { - if (value < 0) { // See http://b/18228951 for details on this path. - LOG(INFO) << "successful NUMA node read from SysFS had negative value (" - << value - << "), but there must be at least one NUMA node" - ", so returning NUMA node zero." - " See more at " - "https://github.com/torvalds/linux/blob/v6.0/Documentation/" - "ABI/testing/sysfs-bus-pci#L344-L355"; - fclose(file); - return 0; - } fclose(file); return value; } @@ -869,7 +914,7 @@ static int TryToReadNumaNode(const std::string& pci_bus_id, << content; fclose(file); - return kUnknownNumaNode; + return std::nullopt; #endif } @@ -901,8 +946,24 @@ GpuExecutor::CreateDeviceDescription(int device_ordinal) { builder.set_pci_bus_id(pci_bus_id); // Read the NUMA node corresponding to the PCI bus ID out of sysfs. - int numa_node = TryToReadNumaNode(pci_bus_id, device_ordinal); - builder.set_numa_node(numa_node); + std::optional numa_node = + TryToReadNumaNode(pci_bus_id, device_ordinal); + if (numa_node.has_value()) { + if (*numa_node < 0) { // See http://b/18228951 for details on this path. + LOG(INFO) + << "successful NUMA node read from SysFS had negative value (" + << *numa_node + << "), but there must be at least one NUMA node" + ", so returning NUMA node zero." + " See more at " + "https://github.com/torvalds/linux/blob/v6.0/Documentation/" + "ABI/testing/sysfs-bus-pci#L344-L355"; + numa_node = 0; + } + } else { + numa_node = -1; + } + builder.set_numa_node(*numa_node); } { diff --git a/xla/stream_executor/gpu/BUILD b/xla/stream_executor/gpu/BUILD index 1da212056f868e..1eddab2a7426fe 100644 --- a/xla/stream_executor/gpu/BUILD +++ b/xla/stream_executor/gpu/BUILD @@ -235,6 +235,7 @@ gpu_only_cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:platform_port", "@tsl//tsl/platform:thread_annotations", ], ) @@ -797,6 +798,7 @@ xla_test( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:platform_port", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", ] + if_cuda([ diff --git a/xla/stream_executor/gpu/gpu_executor.h b/xla/stream_executor/gpu/gpu_executor.h index 13b9b944d1beb2..b15cd52b7461db 100644 --- a/xla/stream_executor/gpu/gpu_executor.h +++ b/xla/stream_executor/gpu/gpu_executor.h @@ -58,6 +58,7 @@ limitations under the License. #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_common.h" +#include "tsl/platform/numa.h" #include "tsl/platform/thread_annotations.h" namespace stream_executor { @@ -111,7 +112,8 @@ class GpuExecutor : public StreamExecutorCommon { device_ordinal_(device_ordinal), cc_major_(0), cc_minor_(0), - version_(0) {} + version_(0), + numa_node_(tsl::port::kNUMANoAffinity) {} // See the corresponding StreamExecutor methods for method comments on the // following overrides. @@ -169,23 +171,10 @@ class GpuExecutor : public StreamExecutorCommon { return GpuCollectives::CollectiveMemoryDeallocate(context_, location); } - // CUDA allocation/registration functions are necessary because the driver - // internally sets up buffers for DMA operations (and page locks them). - // There's no external interface for us to otherwise control these DMA - // settings. absl::StatusOr> HostMemoryAllocate( - uint64_t size) override { - auto* buffer = GpuDriver::HostAllocate(context_, size); - if (buffer == nullptr && size > 0) { - return absl::InternalError( - absl::StrFormat("Failed to allocate HostMemory of size %d", size)); - } - return std::make_unique(buffer, size, this); - } + uint64_t size) override; - void HostMemoryDeallocate(void* location) override { - return GpuDriver::HostDeallocate(context_, location); - } + void HostMemoryDeallocate(void* location, uint64_t size) override; absl::StatusOr GetPointerMemorySpace(const void* ptr) override { return GpuDriver::GetPointerMemorySpace( @@ -386,6 +375,9 @@ class GpuExecutor : public StreamExecutorCommon { // GPU ISA version for device_. int version_; + // NUMA node for device_. + int numa_node_; + // Type erased XLA specific state attached to GpuExecutor. Object xla_state_; diff --git a/xla/stream_executor/gpu/gpu_executor_test.cc b/xla/stream_executor/gpu/gpu_executor_test.cc index c3c67bc03d8884..9ac7be1a2c2210 100644 --- a/xla/stream_executor/gpu/gpu_executor_test.cc +++ b/xla/stream_executor/gpu/gpu_executor_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" +#include "tsl/platform/numa.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" @@ -54,4 +55,32 @@ TEST_F(GetPointerMemorySpaceTest, Device) { executor->Deallocate(&mem); } +using HostMemoryAllocateTest = GpuExecutorTest; + +TEST_F(HostMemoryAllocateTest, Numa) { + Platform* platform = GetPlatform(); + const uint64_t kSize = 1024; + const int num_devices = platform->VisibleDeviceCount(); + for (int device = 0; device < num_devices; ++device) { + TF_ASSERT_OK_AND_ASSIGN(StreamExecutor * executor, + platform->ExecutorForDevice(device)); + ASSERT_TRUE(executor); + TF_ASSERT_OK_AND_ASSIGN(auto device_desc, + executor->CreateDeviceDescription()); + ASSERT_TRUE(device_desc); + TF_ASSERT_OK_AND_ASSIGN(auto host_ptr, executor->HostMemoryAllocate(kSize)); + ASSERT_TRUE(host_ptr); + EXPECT_NE(host_ptr->opaque(), nullptr); + const int numa_node = tsl::port::NUMAGetMemAffinity(host_ptr->opaque()); + if (numa_node == tsl::port::kNUMANoAffinity) { + // Could be because `executor` could not determine its own NUMA node, in + // which case numa_node() will be -1 or 0, depending on the failure mode. + EXPECT_LE(device_desc->numa_node(), 0); + EXPECT_GE(device_desc->numa_node(), -1); + } else { + EXPECT_EQ(device_desc->numa_node(), numa_node); + } + } +} + } // namespace stream_executor diff --git a/xla/stream_executor/host/host_executor.h b/xla/stream_executor/host/host_executor.h index 4e2a2230ffbd4c..478ab2778cbe9e 100644 --- a/xla/stream_executor/host/host_executor.h +++ b/xla/stream_executor/host/host_executor.h @@ -84,7 +84,7 @@ class HostExecutor : public StreamExecutorCommon { uint64_t size) override { return std::make_unique(new char[size], size, this); } - void HostMemoryDeallocate(void* mem) override { + void HostMemoryDeallocate(void* mem, uint64_t size) override { delete[] static_cast(mem); } diff --git a/xla/stream_executor/host_memory_allocation.cc b/xla/stream_executor/host_memory_allocation.cc index e77c5e8c69475c..9772396b9cc61e 100644 --- a/xla/stream_executor/host_memory_allocation.cc +++ b/xla/stream_executor/host_memory_allocation.cc @@ -27,7 +27,7 @@ HostMemoryAllocation::HostMemoryAllocation(void* ptr, uint64_t size, HostMemoryAllocation::~HostMemoryAllocation() { if (ptr_ != nullptr && executor_ != nullptr) { - executor_->HostMemoryDeallocate(ptr_); + executor_->HostMemoryDeallocate(ptr_, size_); } } diff --git a/xla/stream_executor/integrations/device_mem_allocator.h b/xla/stream_executor/integrations/device_mem_allocator.h index 736b62e051314a..8b31f8b6e5b291 100644 --- a/xla/stream_executor/integrations/device_mem_allocator.h +++ b/xla/stream_executor/integrations/device_mem_allocator.h @@ -82,7 +82,7 @@ class DeviceMemAllocator : public tsl::SubAllocator { auto status = stream_exec_->CollectiveMemoryDeallocate(ptr); CHECK(status.ok()) << status.message(); } else if (memory_type_ == MemoryType::kHost) { - stream_exec_->HostMemoryDeallocate(ptr); + stream_exec_->HostMemoryDeallocate(ptr, num_bytes); } else { DeviceMemoryBase device_ptr(ptr); stream_exec_->Deallocate(&device_ptr); diff --git a/xla/stream_executor/mock_stream_executor.h b/xla/stream_executor/mock_stream_executor.h index f58a553f9ebdd8..9748dcbf4e8abb 100644 --- a/xla/stream_executor/mock_stream_executor.h +++ b/xla/stream_executor/mock_stream_executor.h @@ -92,7 +92,8 @@ class MockStreamExecutor : public StreamExecutor { (override)); MOCK_METHOD(absl::StatusOr>, HostMemoryAllocate, (uint64_t size), (override)); - MOCK_METHOD(void, HostMemoryDeallocate, (void* mem), (override)); + MOCK_METHOD(void, HostMemoryDeallocate, (void* mem, uint64_t size), + (override)); MOCK_METHOD(bool, SynchronizeAllActivity, (), (override)); MOCK_METHOD(absl::Status, SynchronousMemZero, (DeviceMemoryBase * location, uint64_t size), (override)); diff --git a/xla/stream_executor/rocm/rocm_executor.cc b/xla/stream_executor/rocm/rocm_executor.cc index 19a367a37ec27a..45f08edbac9abc 100644 --- a/xla/stream_executor/rocm/rocm_executor.cc +++ b/xla/stream_executor/rocm/rocm_executor.cc @@ -459,6 +459,20 @@ void GpuExecutor::Deallocate(DeviceMemoryBase* mem) { GpuDriver::DeviceDeallocate(context_, mem->opaque()); } +absl::StatusOr> +GpuExecutor::HostMemoryAllocate(uint64_t size) { + auto* buffer = GpuDriver::HostAllocate(context_, size); + if (buffer == nullptr && size > 0) { + return absl::InternalError( + absl::StrFormat("Failed to allocate HostMemory of size %d", size)); + } + return std::make_unique(buffer, size, this); +} + +void GpuExecutor::HostMemoryDeallocate(void* location, uint64_t size) { + return GpuDriver::HostDeallocate(context_, location); +} + bool GpuExecutor::SynchronizeAllActivity() { return GpuDriver::SynchronizeContext(context_); } diff --git a/xla/stream_executor/stream_executor.h b/xla/stream_executor/stream_executor.h index 53c7ab9d33a08a..9b69a303e92f1b 100644 --- a/xla/stream_executor/stream_executor.h +++ b/xla/stream_executor/stream_executor.h @@ -207,7 +207,7 @@ class StreamExecutor { uint64_t size) = 0; // Deallocates a region of host memory allocated by HostMemoryAllocate(). - virtual void HostMemoryDeallocate(void* mem) = 0; + virtual void HostMemoryDeallocate(void* mem, uint64_t size) = 0; // Returns the memory space of the given pointer. virtual absl::StatusOr GetPointerMemorySpace(const void* ptr) { diff --git a/xla/stream_executor/tpu/tpu_executor.h b/xla/stream_executor/tpu/tpu_executor.h index d5b719787f4e6a..88ef4618424732 100644 --- a/xla/stream_executor/tpu/tpu_executor.h +++ b/xla/stream_executor/tpu/tpu_executor.h @@ -137,7 +137,7 @@ class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface { uint64_t size) override { LOG(FATAL) << "not yet implemented"; } - void HostMemoryDeallocate(void* mem) override { + void HostMemoryDeallocate(void* mem, uint64_t size) override { LOG(FATAL) << "not yet implemented"; } absl::Status SynchronousMemZero(DeviceMemoryBase* location, From e0fe89830ca8a499cb3a1d4d808f8a21d7a92d49 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 30 Jul 2024 11:45:53 -0700 Subject: [PATCH 276/376] Add the name of the first task to the barrier time out error message. PiperOrigin-RevId: 657674687 --- .../coordination/coordination_service.cc | 27 ++++++++++++++++--- .../coordination/coordination_service_test.cc | 21 +++++++++++---- 2 files changed, 40 insertions(+), 8 deletions(-) diff --git a/xla/tsl/distributed_runtime/coordination/coordination_service.cc b/xla/tsl/distributed_runtime/coordination/coordination_service.cc index c70dde5e12d3b3..dd53c8018a6186 100644 --- a/xla/tsl/distributed_runtime/coordination/coordination_service.cc +++ b/xla/tsl/distributed_runtime/coordination/coordination_service.cc @@ -179,6 +179,9 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { CoordinatedTaskEqual> tasks_at_barrier; std::vector done_callbacks; + // Specifies the task that initiated the barrier (the first task to call the + // barrier). + CoordinatedTask initiating_task; }; void PassBarrier(std::string_view barrier_id, absl::Status result, BarrierState* barrier) @@ -532,10 +535,27 @@ void CoordinationServiceStandaloneImpl::StartCheckStaleness() { } } } + std::string error_message = absl::StrFormat( + "Barrier timed out. This usually happens because a task " + "triggered the barrier unexpectedly early, or some tasks are " + "too slow. Please look at the other task logs to debug " + "further. Barrier_id: %s. The first task at the barrier: " + "%s. ", + barrier_id, GetTaskName(barrier->initiating_task)); + if (pending_task_count > kPendingTaskLogLimit) { + absl::StrAppend(&error_message, + "Too many tasks have timed out. The first ", + kPendingTaskLogLimit, + " timed out task names:\n", pending_tasks); + } else { + absl::StrAppend( + &error_message, + "Total Number of tasks already at the barrier: ", + barrier->tasks_at_barrier.size() - pending_task_count, + ". Timed out task names:\n%s", pending_tasks); + } const absl::Status error = MakeCoordinationError( - absl::DeadlineExceededError(absl::StrCat( - "Barrier timed out. Barrier_id: ", barrier_id, - ". Timed out task names:\n", pending_tasks))); + absl::DeadlineExceededError(error_message)); PassBarrier(barrier_id, error, barrier); } } @@ -1248,6 +1268,7 @@ void CoordinationServiceStandaloneImpl::BarrierAsync( if (inserted) { // Initialize barrier state. barrier->passed = false; + barrier->initiating_task = task; // Assume barrier is for entire cluster if no tasks are specified. if (participating_tasks.empty()) { for (const auto& task_state : cluster_state_) { diff --git a/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc b/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc index 6133d19ef72380..b9b9bbf75215f4 100644 --- a/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc +++ b/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc @@ -1169,9 +1169,15 @@ TEST_F(CoordinationBarrierTest, BarrierByNonClusterTask) { TEST_F(CoordinationBarrierTest, BarrierTimeout) { const std::string barrier_id = "barrier_id"; absl::Duration timeout = absl::Seconds(1); - absl::Status barrier_status_0; - absl::Notification n_0; + absl::Status barrier_status_0, barrier_status_1; + absl::Notification n_0, n_1; + GetCoordinationService()->BarrierAsync( + barrier_id, timeout, GetTask(1), + /*participating_tasks=*/{}, [&barrier_status_1, &n_1](absl::Status s) { + barrier_status_1 = s; + n_1.Notify(); + }); GetCoordinationService()->BarrierAsync( barrier_id, timeout, GetTask(0), /*participating_tasks=*/{}, [&barrier_status_0, &n_0](absl::Status s) { @@ -1181,13 +1187,18 @@ TEST_F(CoordinationBarrierTest, BarrierTimeout) { // Block until user-specified timeout. n_0.WaitForNotification(); + n_1.WaitForNotification(); + + // All barrier calls should fail with the same error. + EXPECT_EQ(barrier_status_0, barrier_status_1); EXPECT_TRUE(absl::IsDeadlineExceeded(barrier_status_0)); EXPECT_FALSE( absl::StrContains(barrier_status_0.message(), GetTaskName(GetTask(0)))); EXPECT_TRUE( - absl::StrContains(barrier_status_0.message(), GetTaskName(GetTask(1)))); - EXPECT_TRUE( - absl::StrContains(barrier_status_0.message(), GetTaskName(GetTask(2)))); + absl::StrContains(barrier_status_0.message(), + GetTaskName(GetTask(1)))); // First task at barrier. + EXPECT_TRUE(absl::StrContains(barrier_status_0.message(), + GetTaskName(GetTask(2)))); // Timed-out task. } TEST_F(CoordinationBarrierTest, BarrierReturnsPreviousError) { From 2460a7a9a654dc0ea266d56137a8544e20ae7c9b Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Tue, 30 Jul 2024 12:39:40 -0700 Subject: [PATCH 277/376] Move StreamExecutor::Submit method to GpuCommandBuffer. The Submit method didn't need anything from any StreamExecutor class, and is fully-implementable as part of GpuCommandBuffer. PiperOrigin-RevId: 657694166 --- .../gpu/runtime/command_buffer_cmd_test.cc | 6 +- .../gpu/runtime/command_buffer_thunk.cc | 2 +- xla/stream_executor/command_buffer.h | 5 ++ xla/stream_executor/cuda/cuda_executor.cc | 13 --- xla/stream_executor/gpu/gpu_command_buffer.cc | 11 +++ xla/stream_executor/gpu/gpu_command_buffer.h | 1 + .../gpu/gpu_command_buffer_test.cc | 81 ++++++++++--------- xla/stream_executor/gpu/gpu_executor.h | 3 - xla/stream_executor/mock_stream_executor.h | 2 - xla/stream_executor/rocm/rocm_executor.cc | 13 --- xla/stream_executor/stream_executor.h | 6 -- 11 files changed, 62 insertions(+), 81 deletions(-) diff --git a/xla/service/gpu/runtime/command_buffer_cmd_test.cc b/xla/service/gpu/runtime/command_buffer_cmd_test.cc index 22d586775e5ab7..8bdaf6a159eb51 100644 --- a/xla/service/gpu/runtime/command_buffer_cmd_test.cc +++ b/xla/service/gpu/runtime/command_buffer_cmd_test.cc @@ -235,7 +235,7 @@ TEST(CommandBufferCmdTest, MemcpyCmd) { TF_ASSERT_OK(commands.Record(params, record_params, command_buffer.get())); // Execute command buffer and verify that it copied the memory. - TF_ASSERT_OK(executor->Submit(stream.get(), *command_buffer)); + TF_ASSERT_OK(command_buffer->Submit(stream.get())); // Copy `b` data back to host. std::vector dst(4, 0); @@ -306,7 +306,7 @@ TEST(CommandBufferCmdTest, BarrierCmd) { TF_ASSERT_OK(commands.Record(params, record_params, command_buffer.get())); // Execute command buffer and verify that it copied the memory. - TF_ASSERT_OK(executor->Submit(stream.get(), *command_buffer)); + TF_ASSERT_OK(command_buffer->Submit(stream.get())); // Copy data back to host, correct executor order should populate all buffers // with expected value. @@ -384,7 +384,7 @@ TEST(CommandBufferCmdTest, LaunchCmd) { TF_ASSERT_OK(commands.Record(params, record_params, command_buffer.get())); // Execute command buffer and verify that it copied the memory. - TF_ASSERT_OK(executor->Submit(stream.get(), *command_buffer)); + TF_ASSERT_OK(command_buffer->Submit(stream.get())); // Copy `b` data back to host. std::vector dst(4, 0); diff --git a/xla/service/gpu/runtime/command_buffer_thunk.cc b/xla/service/gpu/runtime/command_buffer_thunk.cc index 42d14071fcf4e1..c7a7117a86c6b4 100644 --- a/xla/service/gpu/runtime/command_buffer_thunk.cc +++ b/xla/service/gpu/runtime/command_buffer_thunk.cc @@ -256,7 +256,7 @@ absl::Status CommandBufferThunk::ExecuteOnStream(const ExecuteParams& params) { {"num_executions", cmd_buffer->num_executions}}); }); - return executor->Submit(params.stream, *cmd_buffer->command_buffer); + return cmd_buffer->command_buffer->Submit(params.stream); } absl::StatusOr> diff --git a/xla/stream_executor/command_buffer.h b/xla/stream_executor/command_buffer.h index 5cb39e857f7fbb..2b92b504f2059a 100644 --- a/xla/stream_executor/command_buffer.h +++ b/xla/stream_executor/command_buffer.h @@ -327,6 +327,11 @@ class CommandBuffer { return While(kDefaulExecutionScope, pred, cond_builder, body_builder); } + // Submits the command buffer for execution. + virtual absl::Status Submit(Stream* stream) { + return absl::UnimplementedError("Not implemented for this command buffer."); + } + //--------------------------------------------------------------------------// // Command buffer state management API //--------------------------------------------------------------------------// diff --git a/xla/stream_executor/cuda/cuda_executor.cc b/xla/stream_executor/cuda/cuda_executor.cc index 1b7df257874205..a8118b27b83c98 100644 --- a/xla/stream_executor/cuda/cuda_executor.cc +++ b/xla/stream_executor/cuda/cuda_executor.cc @@ -569,19 +569,6 @@ absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, return absl::InternalError("Unsupported kernel arguments type"); } -absl::Status GpuExecutor::Submit(Stream* stream, - const CommandBuffer& command_buffer) { - if (command_buffer.mode() != CommandBuffer::Mode::kPrimary) { - return absl::InvalidArgumentError( - "Can't submit non-primary command buffer for execution"); - } - - auto exec = GpuCommandBuffer::Cast(&command_buffer)->executable(); - VLOG(3) << "Launch command buffer executable graph " << exec - << " on a stream: " << stream; - return GpuDriver::GraphLaunch(exec, AsGpuStreamValue(stream)); -} - DeviceMemoryBase GpuExecutor::Allocate(uint64_t size, int64_t memory_space) { if (memory_space == 1) { auto result = GpuCollectives::CollectiveMemoryAllocate(context_, size); diff --git a/xla/stream_executor/gpu/gpu_command_buffer.cc b/xla/stream_executor/gpu/gpu_command_buffer.cc index a0334695552915..2fdc8118bcd787 100644 --- a/xla/stream_executor/gpu/gpu_command_buffer.cc +++ b/xla/stream_executor/gpu/gpu_command_buffer.cc @@ -1073,4 +1073,15 @@ GpuCommandBuffer::barriers(ExecutionScopeId id) const { return {}; } +absl::Status GpuCommandBuffer::Submit(Stream* stream) { + if (mode_ != CommandBuffer::Mode::kPrimary) { + return absl::InvalidArgumentError( + "Can't submit non-primary command buffer for execution"); + } + + VLOG(3) << "Launch command buffer executable graph " << exec_ + << " on a stream: " << stream; + return GpuDriver::GraphLaunch(exec_, AsGpuStreamValue(stream)); +} + } // namespace stream_executor::gpu diff --git a/xla/stream_executor/gpu/gpu_command_buffer.h b/xla/stream_executor/gpu/gpu_command_buffer.h index 2808fe6364c047..0b33d340363e24 100644 --- a/xla/stream_executor/gpu/gpu_command_buffer.h +++ b/xla/stream_executor/gpu/gpu_command_buffer.h @@ -123,6 +123,7 @@ class GpuCommandBuffer : public CommandBuffer { absl::Status Finalize() override; absl::Status Update() override; + absl::Status Submit(Stream* stream) override; GpuGraphExecHandle executable() const { return exec_; } GpuGraphHandle graph() const { return graph_; } diff --git a/xla/stream_executor/gpu/gpu_command_buffer_test.cc b/xla/stream_executor/gpu/gpu_command_buffer_test.cc index ef31559eefc5bd..306556c2d736e8 100644 --- a/xla/stream_executor/gpu/gpu_command_buffer_test.cc +++ b/xla/stream_executor/gpu/gpu_command_buffer_test.cc @@ -133,7 +133,7 @@ TEST(GpuCommandBufferTest, LaunchSingleKernel) { TF_ASSERT_OK(cmd_buffer->Launch(add, ThreadDim(), BlockDim(4), a, b, c)); TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); // Copy `c` data back to host. std::vector dst(4, 42); @@ -151,7 +151,7 @@ TEST(GpuCommandBufferTest, LaunchSingleKernel) { TF_ASSERT_OK(cmd_buffer->Launch(add, ThreadDim(), BlockDim(4), a, b, d)); TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); // Copy `d` data back to host. std::fill(dst.begin(), dst.end(), 42); @@ -203,15 +203,16 @@ TEST(CudaCommandBufferTest, TraceSingleKernel) { KernelArgsDeviceMemoryArray args({a, b, c}, 0); // Create a command buffer by tracing kernel launch operations. - auto cmd_buffer = TraceCommandBufferFactory::Create( - executor, - [&](Stream* stream) { - return stream->Launch(ThreadDim(), BlockDim(4), *add, args); - }, - primary); + TF_ASSERT_OK_AND_ASSIGN(auto cmd_buffer, TraceCommandBufferFactory::Create( + executor, + [&](Stream* stream) { + return stream->Launch( + ThreadDim(), BlockDim(4), + *add, args); + }, + primary)); - TF_ASSERT_OK(cmd_buffer.status()); - TF_ASSERT_OK(executor->Submit(stream.get(), **cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); // Copy data back to host. std::vector dst(4, 42); @@ -249,7 +250,7 @@ TEST(GpuCommandBufferTest, LaunchNestedCommandBuffer) { TF_ASSERT_OK(primary_cmd->AddNestedCommandBuffer(*nested_cmd)); TF_ASSERT_OK(primary_cmd->Finalize()); - TF_ASSERT_OK(executor->Submit(stream.get(), *primary_cmd)); + TF_ASSERT_OK(primary_cmd->Submit(stream.get())); // Copy `c` data back to host. std::vector dst(4, 42); @@ -270,7 +271,7 @@ TEST(GpuCommandBufferTest, LaunchNestedCommandBuffer) { TF_ASSERT_OK(primary_cmd->AddNestedCommandBuffer(*nested_cmd)); TF_ASSERT_OK(primary_cmd->Finalize()); - TF_ASSERT_OK(executor->Submit(stream.get(), *primary_cmd)); + TF_ASSERT_OK(primary_cmd->Submit(stream.get())); // Copy `d` data back to host. std::fill(dst.begin(), dst.end(), 42); @@ -298,7 +299,7 @@ TEST(GpuCommandBufferTest, MemcpyDeviceToDevice) { TF_ASSERT_OK(cmd_buffer->MemcpyDeviceToDevice(&b, a, byte_length)); TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); // Copy `b` data back to host. std::vector dst(4, 0); @@ -315,7 +316,7 @@ TEST(GpuCommandBufferTest, MemcpyDeviceToDevice) { // Clear destination to test that command buffer actually copied memory. TF_ASSERT_OK(stream->Memset32(&a, 0, byte_length)); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); // Copy `a` data back to host. std::fill(dst.begin(), dst.end(), 0); @@ -339,7 +340,7 @@ TEST(GpuCommandBufferTest, Memset) { TF_ASSERT_OK(cmd_buffer->Memset(&a, uint32_t{42}, length)); TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); // Copy `a` data back to host. std::vector dst(4, 0); @@ -353,7 +354,7 @@ TEST(GpuCommandBufferTest, Memset) { TF_ASSERT_OK(cmd_buffer->Memset(&a, uint32_t{43}, length)); TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); // Copy `d` data back to host. std::fill(dst.begin(), dst.end(), 0); @@ -408,7 +409,7 @@ TEST(GpuCommandBufferTest, Barriers) { // Create a command buffer with a DAG of memset commands. auto cmd_buffer = executor->CreateCommandBuffer(primary).value(); TF_ASSERT_OK(record(cmd_buffer.get(), 42)); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); std::vector expected = {42, 43, 44, 45, 46, 47}; ASSERT_EQ(transfer_buffers(), expected); @@ -445,7 +446,7 @@ TEST(GpuCommandBufferTest, Barriers) { // Update command buffer to use a new bit pattern. TF_ASSERT_OK(cmd_buffer->Update()); TF_ASSERT_OK(record(cmd_buffer.get(), 43)); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); expected = {43, 44, 45, 46, 47, 48}; ASSERT_EQ(transfer_buffers(), expected); @@ -488,7 +489,7 @@ TEST(GpuCommandBufferTest, IndependentExecutionScopes) { // Create a command buffer with a DAG of memset commands. auto cmd_buffer = executor->CreateCommandBuffer(primary).value(); TF_ASSERT_OK(record(cmd_buffer.get(), 42)); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); std::vector expected = {42, 43, 44, 45}; ASSERT_EQ(transfer_buffers(), expected); @@ -515,7 +516,7 @@ TEST(GpuCommandBufferTest, IndependentExecutionScopes) { // Update command buffer to use a new bit pattern. TF_ASSERT_OK(cmd_buffer->Update()); TF_ASSERT_OK(record(cmd_buffer.get(), 43)); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); expected = {43, 44, 45, 46}; ASSERT_EQ(transfer_buffers(), expected); @@ -562,7 +563,7 @@ TEST(GpuCommandBufferTest, ExecutionScopeBarriers) { // Create a command buffer with a DAG of memset commands. auto cmd_buffer = executor->CreateCommandBuffer(primary).value(); TF_ASSERT_OK(record(cmd_buffer.get(), 42)); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); std::vector expected = {42, 43, 44, 45, 46, 47, 48}; ASSERT_EQ(transfer_buffers(), expected); @@ -607,7 +608,7 @@ TEST(GpuCommandBufferTest, ExecutionScopeBarriers) { // Update command buffer to use a new bit pattern. TF_ASSERT_OK(cmd_buffer->Update()); TF_ASSERT_OK(record(cmd_buffer.get(), 43)); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); expected = {43, 44, 45, 46, 47, 48, 49}; ASSERT_EQ(transfer_buffers(), expected); @@ -652,7 +653,7 @@ TEST(GpuCommandBufferTest, ExecutionScopeOneDirectionalBarriers) { // Create a command buffer with a DAG of memset commands. auto cmd_buffer = executor->CreateCommandBuffer(primary).value(); TF_ASSERT_OK(record(cmd_buffer.get(), 42)); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); std::vector expected = {42, 43, 44, 45, 46, 47}; ASSERT_EQ(transfer_buffers(), expected); @@ -683,7 +684,7 @@ TEST(GpuCommandBufferTest, ExecutionScopeOneDirectionalBarriers) { // Update command buffer to use a new bit pattern. TF_ASSERT_OK(cmd_buffer->Update()); TF_ASSERT_OK(record(cmd_buffer.get(), 43)); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); expected = {43, 44, 45, 46, 47, 48}; ASSERT_EQ(transfer_buffers(), expected); @@ -728,7 +729,7 @@ TEST(GpuCommandBufferTest, ConditionalIf) { TF_ASSERT_OK(cmd_buffer->If(pred, then_builder)); TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); // Copy `c` data back to host. std::vector dst(4, 42); @@ -744,7 +745,7 @@ TEST(GpuCommandBufferTest, ConditionalIf) { // Submit the same command buffer, but this time it should not execute // conditional branch as conditional handle should be updated to false. - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length)); std::vector zeroes = {0, 0, 0, 0}; @@ -767,7 +768,7 @@ TEST(GpuCommandBufferTest, ConditionalIf) { TF_ASSERT_OK(cmd_buffer->If(pred, then_builder)); TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); // Copy `d` data back to host. std::fill(dst.begin(), dst.end(), 42); @@ -825,7 +826,7 @@ TEST(GpuCommandBufferTest, ConditionalIfElse) { TF_ASSERT_OK(cmd_buffer->IfElse(pred, then_builder, else_builder)); TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); TF_ASSERT_OK(stream->BlockHostUntilDone()); // Copy `c` data back to host. @@ -841,7 +842,7 @@ TEST(GpuCommandBufferTest, ConditionalIfElse) { // Submit the same command buffer, but this time it should execute `else` // branch and multiply inputs. - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); TF_ASSERT_OK(stream->BlockHostUntilDone()); TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length)); @@ -862,7 +863,7 @@ TEST(GpuCommandBufferTest, ConditionalIfElse) { TF_ASSERT_OK(cmd_buffer->IfElse(pred, then_builder, else_builder)); TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); TF_ASSERT_OK(stream->BlockHostUntilDone()); // Copy `d` data back to host. @@ -920,7 +921,7 @@ TEST(GpuCommandBufferTest, ConditionalCase) { TF_ASSERT_OK(cmd_buffer->Case(index, {branch0, branch1})); TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); TF_ASSERT_OK(stream->BlockHostUntilDone()); // Copy `c` data back to host. @@ -934,7 +935,7 @@ TEST(GpuCommandBufferTest, ConditionalCase) { TF_ASSERT_OK(stream->Memset32(&index, 1, sizeof(int32_t))); // Submit the same command buffer, but this time it should multiply inputs. - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); TF_ASSERT_OK(stream->BlockHostUntilDone()); TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length)); @@ -944,7 +945,7 @@ TEST(GpuCommandBufferTest, ConditionalCase) { // Set index to `-1` (out of bound index value). TF_ASSERT_OK(stream->Memset32(&index, -1, sizeof(int32_t))); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); TF_ASSERT_OK(stream->BlockHostUntilDone()); TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length)); @@ -953,7 +954,7 @@ TEST(GpuCommandBufferTest, ConditionalCase) { // Set index to `2` (out of bound index value). TF_ASSERT_OK(stream->Memset32(&index, 2, sizeof(int32_t))); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); TF_ASSERT_OK(stream->BlockHostUntilDone()); TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length)); @@ -999,7 +1000,7 @@ TEST(GpuCommandBufferTest, ConditionalFor) { TF_ASSERT_OK(cmd_buffer->For(num_iters, loop_counter, body_builder)); TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); // Copy `b` data back to host. std::vector dst(4, 42); @@ -1066,7 +1067,7 @@ TEST(GpuCommandBufferTest, ConditionalWhile) { TF_ASSERT_OK(cmd_buffer->While(pred, cond_builder, body_builder)); TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); // Copy `b` data back to host. std::vector dst(4, 42); @@ -1131,7 +1132,7 @@ TEST(GpuCommandBufferTest, ConditionalIfInExecutionScope) { // Create a command buffer with a DAG of memset commands. auto cmd_buffer = executor->CreateCommandBuffer(primary).value(); TF_ASSERT_OK(record(cmd_buffer.get(), 42)); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); std::vector expected = {42, 43, 44}; ASSERT_EQ(transfer_buffers(), expected); @@ -1165,7 +1166,7 @@ TEST(GpuCommandBufferTest, ConditionalIfInExecutionScope) { constexpr bool kFalse = false; TF_ASSERT_OK(stream->Memcpy(&pred, &kFalse, 1)); TF_ASSERT_OK(stream->MemZero(&buffers[2], sizeof(int32_t))); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); expected = {42, 43, 0}; ASSERT_EQ(transfer_buffers(), expected); @@ -1232,7 +1233,7 @@ TEST(GpuCommandBufferTest, ConditionalWhileInExecutionScope) { // Create a command buffer with a single conditional operation. auto cmd_buffer = executor->CreateCommandBuffer(primary).value(); TF_ASSERT_OK(record(cmd_buffer.get(), 42, 10)); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); // Copy `b` and `c` data back to host. int32_t b_dst, c_dst; @@ -1265,7 +1266,7 @@ TEST(GpuCommandBufferTest, ConditionalWhileInExecutionScope) { TF_ASSERT_OK(stream->MemZero(&loop_counter, sizeof(int32_t))); TF_ASSERT_OK(stream->MemZero(&b, sizeof(int32_t))); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); TF_ASSERT_OK(stream->Memcpy(&b_dst, b, sizeof(int32_t))); TF_ASSERT_OK(stream->Memcpy(&c_dst, c, sizeof(int32_t))); diff --git a/xla/stream_executor/gpu/gpu_executor.h b/xla/stream_executor/gpu/gpu_executor.h index b15cd52b7461db..bf1027d249ab58 100644 --- a/xla/stream_executor/gpu/gpu_executor.h +++ b/xla/stream_executor/gpu/gpu_executor.h @@ -148,9 +148,6 @@ class GpuExecutor : public StreamExecutorCommon { const ClusterDim& cluster_dims, const Kernel& kernel, const KernelArgs& args) override; - absl::Status Submit(Stream* stream, - const CommandBuffer& command_buffer) override; - DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) override; void Deallocate(DeviceMemoryBase* mem) override; diff --git a/xla/stream_executor/mock_stream_executor.h b/xla/stream_executor/mock_stream_executor.h index 9748dcbf4e8abb..2655d48833b0ec 100644 --- a/xla/stream_executor/mock_stream_executor.h +++ b/xla/stream_executor/mock_stream_executor.h @@ -78,8 +78,6 @@ class MockStreamExecutor : public StreamExecutor { const BlockDim& block_dims, const ClusterDim& cluster_dims, const Kernel& k, const KernelArgs& args), (override)); - MOCK_METHOD(absl::Status, Submit, - (Stream * stream, const CommandBuffer& command_buffer)); MOCK_METHOD(void, UnloadKernel, (const Kernel* kernel), (override)); MOCK_METHOD(DeviceMemoryBase, Allocate, (uint64_t size, int64_t memory_space), (override)); diff --git a/xla/stream_executor/rocm/rocm_executor.cc b/xla/stream_executor/rocm/rocm_executor.cc index 45f08edbac9abc..d879fbbbb0aae4 100644 --- a/xla/stream_executor/rocm/rocm_executor.cc +++ b/xla/stream_executor/rocm/rocm_executor.cc @@ -384,19 +384,6 @@ absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, return Launch(stream, thread_dims, block_dims, kernel, args); } -absl::Status GpuExecutor::Submit(Stream* stream, - const CommandBuffer& command_buffer) { - if (command_buffer.mode() != CommandBuffer::Mode::kPrimary) { - return absl::InvalidArgumentError( - "Can't submit non-primary command buffer for execution"); - } - - auto exec = GpuCommandBuffer::Cast(&command_buffer)->executable(); - VLOG(3) << "Launch command buffer execuable graph " << exec - << " on a stream: " << stream; - return GpuDriver::GraphLaunch(exec, AsGpuStreamValue(stream)); -} - absl::Status GpuExecutor::LoadModule(const MultiModuleLoaderSpec& spec, ModuleHandle* module_handle) { // In GpuExecutor we store the pointer to the HSACO binary as diff --git a/xla/stream_executor/stream_executor.h b/xla/stream_executor/stream_executor.h index 9b69a303e92f1b..a634dc0f901b74 100644 --- a/xla/stream_executor/stream_executor.h +++ b/xla/stream_executor/stream_executor.h @@ -156,12 +156,6 @@ class StreamExecutor { return absl::UnimplementedError("Not Implemented"); } - // Submits command buffer for execution to the underlying platform driver. - virtual absl::Status Submit(Stream* stream, - const CommandBuffer& command_buffer) { - return absl::UnimplementedError("Not Implemented"); - } - // Releases any state associated with the previously loaded kernel. virtual void UnloadKernel(const Kernel* kernel) {} From d284db02bae6533f4c3adb64b4a35797a2b651a1 Mon Sep 17 00:00:00 2001 From: Gregory Pataky Date: Tue, 30 Jul 2024 12:40:28 -0700 Subject: [PATCH 278/376] Move `eup_version_` to `ExhaustiveOpTestBase` from `Exhaustive32BitOrLessUnaryTest` PiperOrigin-RevId: 657694417 --- .../exhaustive/exhaustive_op_test_utils.cc | 4 +++ .../exhaustive/exhaustive_op_test_utils.h | 34 ++++++++++++++++--- xla/tests/exhaustive/exhaustive_test_main.cc | 9 ----- .../exhaustive_unary_f32_or_smaller_test.cc | 25 ++------------ 4 files changed, 35 insertions(+), 37 deletions(-) diff --git a/xla/tests/exhaustive/exhaustive_op_test_utils.cc b/xla/tests/exhaustive/exhaustive_op_test_utils.cc index 17964eccfa32fd..52bdda0c6a3278 100644 --- a/xla/tests/exhaustive/exhaustive_op_test_utils.cc +++ b/xla/tests/exhaustive/exhaustive_op_test_utils.cc @@ -40,6 +40,10 @@ limitations under the License. namespace xla { namespace exhaustive_op_test { +int eup_version = 0; + +int GetEupVersion() { return eup_version; } + bool IsSubnormalReal(xla::complex64 value) { return IsSubnormal(value.real()); } bool IsSubnormalReal(xla::complex128 value) { diff --git a/xla/tests/exhaustive/exhaustive_op_test_utils.h b/xla/tests/exhaustive/exhaustive_op_test_utils.h index 5504f4f9e3f5ba..223ae2704e6b68 100644 --- a/xla/tests/exhaustive/exhaustive_op_test_utils.h +++ b/xla/tests/exhaustive/exhaustive_op_test_utils.h @@ -55,9 +55,14 @@ limitations under the License. namespace xla { namespace exhaustive_op_test { +// Access this through GetEupVersion. +extern int eup_version; + +// Get the TPU EUP version (if it was provided). +int GetEupVersion(); + // Determines if the real component of the complex number is subnormal (either // sign). -// Determines if the real component of the complex number is subnormal. // // See also IsSubnormal to check if either component is subnormal. bool IsSubnormalReal(xla::complex64); @@ -246,7 +251,9 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { using OutputRangeCheck = std::function; explicit ExhaustiveOpTestBase() - : ty_(T), platform_(client_->platform()->Name()) { + : ty_(T), + platform_(client_->platform()->Name()), + eup_version_(xla::exhaustive_op_test::GetEupVersion()) { SetFastMathDisabled(true); // Run all HLO passes. In particular, constant folding is disabled by @@ -366,6 +373,20 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { const std::string& Platform() { return platform_; } + bool IsGpu(const std::string& platform) const { return platform == "CUDA"; } + bool IsCpu(const std::string& platform) const { return platform == "Host"; } + bool IsTpu(const std::string& platform) const { + return !IsGpu(platform) && !IsCpu(platform); + } + + int EupVersion() const { return eup_version_; } + bool IsPreV5Tpu(const std::string& platform) const { + return IsTpu(platform) && eup_version_ < 2; + } + bool IsPreV6Tpu(const std::string& platform) const { + return IsTpu(platform) && eup_version_ < 3; + } + // Returns the number of elements in each input literal. virtual int64_t GetInputSize() = 0; @@ -590,9 +611,12 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { // The platform under test. const std::string platform_; - // Testing will ignore inputs for which known_incorrect_fn_ returns true. The - // argument to the function is the raw bits for the data being test, zero - // extended to 64 bits if the data type is less than 64 bits. + // Version of the EUP for a TPU target. Only relevant for TPU platforms. + const int eup_version_; + + // Testing will ignore inputs for which known_incorrect_fn_ returns true. + // The argument to the function is the raw bits for the data being test, + // zero extended to 64 bits if the data type is less than 64 bits. std::function known_incorrect_fn_; // If true, allows denormals to be flushed to non-sign-preserving 0. diff --git a/xla/tests/exhaustive/exhaustive_test_main.cc b/xla/tests/exhaustive/exhaustive_test_main.cc index 88a9befba9c74e..70588bc8e8a120 100644 --- a/xla/tests/exhaustive/exhaustive_test_main.cc +++ b/xla/tests/exhaustive/exhaustive_test_main.cc @@ -20,15 +20,6 @@ limitations under the License. #include "tsl/platform/test.h" -namespace xla { -namespace exhaustive_op_test { - -static int eup_version = 0; -int GetEupVersion() { return eup_version; } - -} // namespace exhaustive_op_test -} // namespace xla - GTEST_API_ int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); diff --git a/xla/tests/exhaustive/exhaustive_unary_f32_or_smaller_test.cc b/xla/tests/exhaustive/exhaustive_unary_f32_or_smaller_test.cc index 70511ecbb46363..7a9857f927d7bc 100644 --- a/xla/tests/exhaustive/exhaustive_unary_f32_or_smaller_test.cc +++ b/xla/tests/exhaustive/exhaustive_unary_f32_or_smaller_test.cc @@ -22,7 +22,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -44,8 +43,7 @@ limitations under the License. namespace xla { namespace exhaustive_op_test { - -extern int GetEupVersion(); +namespace { using Eigen::half; @@ -194,28 +192,10 @@ template class Exhaustive32BitOrLessUnaryTest : public ExhaustiveUnaryTest, public ::testing::WithParamInterface> { - public: - public: - Exhaustive32BitOrLessUnaryTest() - : eup_version_(xla::exhaustive_op_test::GetEupVersion()) {} - public: // Sets error parameters appropriately for testing tan. void SetParamsForTan(); - bool IsGpu(const std::string& platform) const { return platform == "CUDA"; } - bool IsCpu(const std::string& platform) const { return platform == "Host"; } - bool IsTpu(const std::string& platform) const { - return !IsGpu(platform) && !IsCpu(platform); - } - int EupVersion() const { return eup_version_; } - bool IsPreV5Tpu(const std::string& platform) const { - return IsTpu(platform) && eup_version_ < 2; - } - bool IsPreV6Tpu(const std::string& platform) const { - return IsTpu(platform) && eup_version_ < 3; - } - protected: using typename ExhaustiveUnaryTest::NativeT; @@ -248,8 +228,6 @@ class Exhaustive32BitOrLessUnaryTest this->ConvertAndReplaceKnownIncorrectValueWith(input_val, 0); } } - - const int eup_version_; }; using ExhaustiveF32UnaryTest = Exhaustive32BitOrLessUnaryTest; @@ -740,5 +718,6 @@ INSTANTIATE_TEST_SUITE_P(BF16, ExhaustiveBF16UnaryTest, ::testing::Values(std::make_pair(0, 1 << 16))); #endif +} // namespace } // namespace exhaustive_op_test } // namespace xla From ead40bb4785d7a5ba19feaa1ecd086cff3214afb Mon Sep 17 00:00:00 2001 From: Tzung-Han Juang Date: Tue, 30 Jul 2024 12:45:25 -0700 Subject: [PATCH 279/376] =?UTF-8?q?PR=20#15446:=20Add=20PassesIncGen=20to?= =?UTF-8?q?=20ChloPasses=20in=20xla/mlir=5Fhlo/mhlo/transforms/CMakeL?= =?UTF-8?q?=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Imported from GitHub PR https://github.com/openxla/xla/pull/15446 Hello! We're updating jax version (v0.4.28) in our [MLIR based compiler](https://github.com/PennyLaneAI/catalyst/). However, our CI/CD process failed when building `mlir-hlo-opt` (please see [this Github Action](https://github.com/PennyLaneAI/catalyst/actions/runs/10113602669/job/27970370343?pr=931)). We’ve tested the build both locally and on an AWS instance without encountering the error. The main difference is that local and aws environments build `mlir-hlo-opt` in parallel, while the CI/CD runner only uses a single core. We believe this issue is similar to a previous PR ([#61071](https://github.com/openxla/xla/pull/3857)), where building with a single core fails to resolve a missing dependency. Thank you very much. **Description of the Change:** This patch adds `PassesIncGen` as a dependency to `ChloPasses`. Otherwise, single-core build would fail because it cannot find `stablehlo/transforms/Passes.h.inc`. **Related GitHub Issues:** Similar to https://github.com/tensorflow/mlir-hlo/issues/68. Copybara import of the project: -- 9542a7494b2a2abb39240f68e64afa9b8b1b5573 by Tzung-Han Juang : Add PassesIncGen to ChloPasses in xla/mlir_hlo/mhlo/transforms/CMakeLists.txt Merging this change closes #15446 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/15446 from tzunghanjuang:add-PassesIncGen-to-ChloPasses 9542a7494b2a2abb39240f68e64afa9b8b1b5573 PiperOrigin-RevId: 657695913 --- xla/mlir_hlo/mhlo/transforms/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt b/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt index 8131c0caab9571..60afe10e64759a 100644 --- a/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt +++ b/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt @@ -175,6 +175,7 @@ add_mlir_library(ChloPasses MLIRhlo_opsIncGen MLIRChloLegalizeToHloIncGen MLIRMhloPassIncGen + PassesIncGen LINK_COMPONENTS Core From 80dc9734392712a998ac50c31561c6f29f76acad Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Tue, 30 Jul 2024 12:51:51 -0700 Subject: [PATCH 280/376] PR #15470: Fix separator detection in name uniquer. Imported from GitHub PR https://github.com/openxla/xla/pull/15470 Name uniquer [guarantees](https://github.com/openxla/xla/blob/4228782cc8ecaa8411e988bae08aa9251507e8e9/xla/service/name_uniquer.h#L30) to return distinct names but currently does not. For the sequence of inputs like the one added in the test (a__1, a, a) it will currently return a__1, a, a__1, the last output being a duplicate of the first one. This breaks GPU kernel binary caching. [This](https://github.com/openxla/xla/blob/4228782cc8ecaa8411e988bae08aa9251507e8e9/xla/service/name_uniquer.cc#L80-L96) code tries to detect the separator, which by default is "__", double underscore, in the input string and split the input string into root and suffix, for a__1 that should be a and 1. But because the code assumes the separator length to be 1, and even the default separator has length 2, this probably almost never works. As a result in a__1 further [atoi](https://github.com/openxla/xla/blob/4228782cc8ecaa8411e988bae08aa9251507e8e9/xla/service/name_uniquer.cc#L88) tries to parse "_1" instead of "1" and fails. This leads to incorrect [registration of names](https://github.com/openxla/xla/blob/4228782cc8ecaa8411e988bae08aa9251507e8e9/xla/service/name_uniquer.cc#L98-L99) in the end. The fix is to simply take the actual separator string length instead of 1. Copybara import of the project: -- 351692b0d55d653c1ffedf10882f3e491002cbee by Ilia Sergachev : Fix separator detection in name uniquer. Merging this change closes #15470 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/15470 from openxla:fix_name_uniquer 351692b0d55d653c1ffedf10882f3e491002cbee PiperOrigin-RevId: 657698114 --- xla/service/name_uniquer.cc | 4 ++-- xla/service/name_uniquer_test.cc | 15 +++++++++------ 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/xla/service/name_uniquer.cc b/xla/service/name_uniquer.cc index 6fb7351251b57a..124cd6f427e119 100644 --- a/xla/service/name_uniquer.cc +++ b/xla/service/name_uniquer.cc @@ -83,8 +83,8 @@ std::string NameUniquer::GetUniqueName(absl::string_view prefix) { int64_t numeric_suffix = 0; size_t separator_index = root.rfind(separator_); if (separator_index != std::string::npos && (separator_index > 0) && - (separator_index < root.size() - 1)) { - std::string after_suffix = root.substr(separator_index + 1); + (separator_index < root.size() - separator_.size())) { + std::string after_suffix = root.substr(separator_index + separator_.size()); if (absl::SimpleAtoi(after_suffix, &numeric_suffix)) { has_numeric_suffix = true; // Remove numeric suffix from root. diff --git a/xla/service/name_uniquer_test.cc b/xla/service/name_uniquer_test.cc index 6ebdfffedb73d0..64e02229d1a871 100644 --- a/xla/service/name_uniquer_test.cc +++ b/xla/service/name_uniquer_test.cc @@ -14,17 +14,12 @@ limitations under the License. ==============================================================================*/ #include "xla/service/name_uniquer.h" - -#include -#include -#include - #include "tsl/platform/test.h" namespace xla { namespace { -class NameUniquerTest : public ::testing::Test {}; +using NameUniquerTest = ::testing::Test; TEST_F(NameUniquerTest, SimpleUniquer) { NameUniquer uniquer; @@ -126,5 +121,13 @@ TEST_F(NameUniquerTest, AvoidKeywords) { EXPECT_EQ("Pred", uniquer.GetUniqueName("Pred")); } +TEST_F(NameUniquerTest, DetectSeparator) { + NameUniquer uniquer; + + EXPECT_EQ(uniquer.GetUniqueName("a__1"), "a__1"); + EXPECT_EQ(uniquer.GetUniqueName("a"), "a"); + EXPECT_EQ(uniquer.GetUniqueName("a"), "a__2"); +} + } // namespace } // namespace xla From 7480189f4f8ffd04516f3c85f71c284f0b20e579 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 30 Jul 2024 13:05:07 -0700 Subject: [PATCH 281/376] [xla:cpu] Execute while loops with known trip counts as for loops MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit name old cpu/op new cpu/op delta BM_SelectAndScatterF32/128/process_time 490µs ± 2% 422µs ± 1% -13.73% BM_SelectAndScatterF32/256/process_time 2.00ms ± 1% 1.73ms ± 1% -13.39% BM_SelectAndScatterF32/512/process_time 8.89ms ± 3% 7.82ms ± 5% -11.98% PiperOrigin-RevId: 657702822 --- xla/service/cpu/BUILD | 1 + xla/service/cpu/cpu_compiler.cc | 5 + xla/service/cpu/runtime/BUILD | 2 + xla/service/cpu/runtime/while_thunk.cc | 222 ++++++++++++++------ xla/service/cpu/runtime/while_thunk.h | 26 ++- xla/service/cpu/runtime/while_thunk_test.cc | 46 ++++ xla/service/cpu/thunk_emitter.cc | 12 +- 7 files changed, 248 insertions(+), 66 deletions(-) diff --git a/xla/service/cpu/BUILD b/xla/service/cpu/BUILD index 5b0b4e2d2cd7a5..a6f4e82e433305 100644 --- a/xla/service/cpu/BUILD +++ b/xla/service/cpu/BUILD @@ -327,6 +327,7 @@ cc_library( "//xla/service:while_loop_constant_sinking", "//xla/service:while_loop_invariant_code_motion", "//xla/service:while_loop_simplifier", + "//xla/service:while_loop_trip_count_annotator", "//xla/service:zero_sized_hlo_elimination", "//xla/service/cpu/runtime:thunk", "//xla/service/llvm_ir:llvm_command_line_options", diff --git a/xla/service/cpu/cpu_compiler.cc b/xla/service/cpu/cpu_compiler.cc index 024a85edfb9632..ebac5206086138 100644 --- a/xla/service/cpu/cpu_compiler.cc +++ b/xla/service/cpu/cpu_compiler.cc @@ -180,6 +180,7 @@ limitations under the License. #include "xla/service/while_loop_constant_sinking.h" #include "xla/service/while_loop_invariant_code_motion.h" #include "xla/service/while_loop_simplifier.h" +#include "xla/service/while_loop_trip_count_annotator.h" #include "xla/service/zero_sized_hlo_elimination.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -685,6 +686,10 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(); pipeline.AddPass(); + // Annotate while loops with statically known trip counts, so that at run time + // we can avoid running the loop condition computations. + pipeline.AddPass(); + // Layout assignment uses alias analysis, which requires the call graph to be // flattened. pipeline.AddPass(); diff --git a/xla/service/cpu/runtime/BUILD b/xla/service/cpu/runtime/BUILD index 7065dc4a74bdc6..5ffa84491a91a2 100644 --- a/xla/service/cpu/runtime/BUILD +++ b/xla/service/cpu/runtime/BUILD @@ -934,6 +934,8 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:statusor", diff --git a/xla/service/cpu/runtime/while_thunk.cc b/xla/service/cpu/runtime/while_thunk.cc index a5aa14a419ac4c..486a0b93e72f58 100644 --- a/xla/service/cpu/runtime/while_thunk.cc +++ b/xla/service/cpu/runtime/while_thunk.cc @@ -15,13 +15,17 @@ limitations under the License. #include "xla/service/cpu/runtime/while_thunk.h" +#include #include #include +#include #include #include "absl/base/optimization.h" #include "absl/memory/memory.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" #include "xla/service/cpu/runtime/buffer_allocations.h" @@ -37,81 +41,38 @@ namespace xla::cpu { absl::StatusOr> WhileThunk::Create( Info info, BufferAllocation::Slice cond_buffer, ThunkSequence cond_sequence, - ThunkSequence body_sequence) { + ThunkSequence body_sequence, std::optional trip_count) { TF_ASSIGN_OR_RETURN(ThunkExecutor cond_executor, ThunkExecutor::Create(std::move(cond_sequence))); TF_ASSIGN_OR_RETURN(ThunkExecutor body_executor, ThunkExecutor::Create(std::move(body_sequence))); return absl::WrapUnique(new WhileThunk(std::move(info), cond_buffer, std::move(cond_executor), - std::move(body_executor))); + std::move(body_executor), trip_count)); } WhileThunk::WhileThunk(Info info, BufferAllocation::Slice cond_buffer, - ThunkExecutor cond_executor, ThunkExecutor body_executor) + ThunkExecutor cond_executor, ThunkExecutor body_executor, + std::optional trip_count) : Thunk(Kind::kWhile, std::move(info)), cond_buffer_(cond_buffer), cond_executor_(std::move(cond_executor)), - body_executor_(std::move(body_executor)) {} - -tsl::AsyncValueRef WhileThunk::ExecuteAsync( - const ExecuteParams& params, tsl::AsyncValueRef dependency, - bool* condition) { - auto event = tsl::MakeConstructedAsyncValueRef(); - - // Allocate while loop iteration function on heap so we can detach its life - // time from the caller stack. - auto loop_fn = std::make_shared>(); - *loop_fn = [this, condition, ¶ms, event, - loop = loop_fn.get()](absl::Status status) { - // Dependency completed with an error. Forward it to the result event. - if (ABSL_PREDICT_FALSE(!status.ok())) { - event.SetError(std::move(status)); - return; - } - - while (*condition) { - auto body_event = body_executor_.Execute(params); - auto cond_event = body_event.FlatMap([this, ¶ms](ExecuteEvent) { - return cond_executor_.Execute(params); - }); - - // If we don't know yet wether we should execute the next iteration or - // not, attach `AndThen` continuation to the `cond_event`. - if (!cond_event.IsAvailable()) { - cond_event.AndThen( - [loop](absl::Status status) { (*loop)(std::move(status)); }); - return; - } - - // Immediately forward error to the caller. - if (ABSL_PREDICT_FALSE(cond_event.IsError())) { - event.SetError(cond_event.GetError()); - return; - } - - // At this point `*condition` should have been updated and we may continue - // executing the while loop in the current thread. - DCHECK(cond_event.IsConcrete()); - } - - // Successfully completed while loop iterations. - event.SetStateConcrete(); - }; - - // Kick-off loop execution once dependency event is available. - dependency.AndThen(*loop_fn); - - // Keep `loop_fn` alive until the end of the while loop execution. - event.AndThen([loop_fn = std::move(loop_fn)]() {}); - - return event; -} + body_executor_(std::move(body_executor)), + trip_count_(trip_count) {} tsl::AsyncValueRef WhileThunk::Execute( const ExecuteParams& params) { tsl::profiler::TraceMe trace([&] { return TraceMeEncode(); }); + VLOG(3) << absl::StreamFormat( + "While: #trip_count=%s", + trip_count_.has_value() ? absl::StrCat(*trip_count_) : "unknown"); + + // Most of the while loops in XLA have statically known trip count. + if (ABSL_PREDICT_TRUE(trip_count_.has_value())) { + return ExecuteForLoop(params, *trip_count_); + } + const BufferAllocations* allocations = params.buffer_allocations; se::DeviceMemoryBase cond_data; @@ -122,14 +83,42 @@ tsl::AsyncValueRef WhileThunk::Execute( } bool* condition = reinterpret_cast(cond_data.opaque()); + return ExecuteWhileLoop(params, condition); +} + +tsl::AsyncValueRef WhileThunk::ExecuteForLoop( + const ExecuteParams& params, int64_t trip_count) { + for (int64_t loop_counter = 0; loop_counter < trip_count; ++loop_counter) { + auto body_event = body_executor_.Execute(params); + // If loop iteration has not completed yet, switch to async execution mode + // using `body_event` as a dependency and continue the loop iteration + // starting from `loop_counter + 1`. + if (ABSL_PREDICT_FALSE(!body_event.IsAvailable())) { + return ExecuteAsyncForLoop(params, std::move(body_event), + loop_counter + 1, trip_count); + } + + if (ABSL_PREDICT_FALSE(body_event.IsError())) { + return body_event.GetError(); + } + + DCHECK(body_event.IsConcrete()); + } + + // Successfully completed `trip_count` while loop iterations. + return OkExecuteEvent(); +} + +tsl::AsyncValueRef WhileThunk::ExecuteWhileLoop( + const ExecuteParams& params, bool* condition) { // Execute `cond` thunk sequence to initialize the loop condition. auto init_event = cond_executor_.Execute(params); // If we don't know if we should continue or not, switch to async execution // mode using `init_event` as a dependency. if (ABSL_PREDICT_FALSE(!init_event.IsAvailable())) { - return ExecuteAsync(params, std::move(init_event), condition); + return ExecuteAsyncWhileLoop(params, std::move(init_event), condition); } // Immediately forward error to the caller. @@ -145,10 +134,11 @@ tsl::AsyncValueRef WhileThunk::Execute( return cond_executor_.Execute(params); }); - // If we don't know if we should continue or not, switch to async execution - // mode using `cond_event` as a dependency. + // If loop iteration has not completed yet, switch to async execution mode + // using `cond_event` as a dependency and maybe continue the loop + // iteration (if `condition` is still true). if (ABSL_PREDICT_FALSE(!cond_event.IsAvailable())) { - return ExecuteAsync(params, std::move(cond_event), condition); + return ExecuteAsyncWhileLoop(params, std::move(cond_event), condition); } // Immediately forward error to the caller. @@ -165,6 +155,114 @@ tsl::AsyncValueRef WhileThunk::Execute( return OkExecuteEvent(); } +tsl::AsyncValueRef WhileThunk::ExecuteAsyncForLoop( + const ExecuteParams& params, tsl::AsyncValueRef dependency, + int64_t loop_counter, int64_t trip_count) { + auto event = tsl::MakeConstructedAsyncValueRef(); + + // Allocate while loop iteration function on heap so we can detach its life + // time from the caller stack. + auto loop_fn = std::make_shared>(); + *loop_fn = [this, trip_count, ¶ms, event, loop = loop_fn.get()]( + int64_t loop_counter, absl::Status status) { + // Dependency completed with an error. Forward it to the result event. + if (ABSL_PREDICT_FALSE(!status.ok())) { + event.SetError(std::move(status)); + return; + } + + for (; loop_counter < trip_count; ++loop_counter) { + auto body_event = body_executor_.Execute(params); + + // If loop iteration has not completed yet, continue execution + // asynchronously starting from `loop_counter + 1`. + if (!body_event.IsAvailable()) { + body_event.AndThen([loop, loop_counter](absl::Status status) { + (*loop)(loop_counter + 1, std::move(status)); + }); + return; + } + + // Immediately forward error to the caller. + if (ABSL_PREDICT_FALSE(body_event.IsError())) { + event.SetError(body_event.GetError()); + return; + } + + DCHECK(body_event.IsConcrete()); + } + + // Successfully completed `trip_count` while loop iterations. + event.SetStateConcrete(); + }; + + // Kick-off loop execution once dependency event is available. + dependency.AndThen([loop_counter, loop = loop_fn.get()](absl::Status status) { + (*loop)(loop_counter, std::move(status)); + }); + + // Keep `loop_fn` alive until the end of the while loop execution. + event.AndThen([loop_fn = std::move(loop_fn)]() {}); + + return event; +} + +tsl::AsyncValueRef WhileThunk::ExecuteAsyncWhileLoop( + const ExecuteParams& params, tsl::AsyncValueRef dependency, + bool* condition) { + auto event = tsl::MakeConstructedAsyncValueRef(); + + // Allocate while loop iteration function on heap so we can detach its life + // time from the caller stack. + auto loop_fn = std::make_shared>(); + *loop_fn = [this, condition, ¶ms, event, + loop = loop_fn.get()](absl::Status status) { + // Dependency completed with an error. Forward it to the result event. + if (ABSL_PREDICT_FALSE(!status.ok())) { + event.SetError(std::move(status)); + return; + } + + while (*condition) { + auto body_event = body_executor_.Execute(params); + auto cond_event = body_event.FlatMap([this, ¶ms](ExecuteEvent) { + return cond_executor_.Execute(params); + }); + + // If loop iteration has not completed yet, continue execution + // asynchronously (if `condition` is still true when it becomes ready). + if (!cond_event.IsAvailable()) { + cond_event.AndThen( + [loop](absl::Status status) { (*loop)(std::move(status)); }); + return; + } + + // Immediately forward error to the caller. + if (ABSL_PREDICT_FALSE(cond_event.IsError())) { + event.SetError(cond_event.GetError()); + return; + } + + // At this point `*condition` should have been updated and we may continue + // executing the while loop in the current thread. + DCHECK(cond_event.IsConcrete()); + } + + // Successfully completed while loop iterations. + event.SetStateConcrete(); + }; + + // Kick-off loop execution once dependency event is available. + dependency.AndThen([loop = loop_fn.get()](absl::Status status) { + (*loop)(std::move(status)); + }); + + // Keep `loop_fn` alive until the end of the while loop execution. + event.AndThen([loop_fn = std::move(loop_fn)]() {}); + + return event; +} + WhileThunk::BufferUses WhileThunk::buffer_uses() const { BufferUses buffer_uses = {{cond_buffer_, BufferUse::kWrite}}; diff --git a/xla/service/cpu/runtime/while_thunk.h b/xla/service/cpu/runtime/while_thunk.h index 9c5a7af272468c..e631e54842a52a 100644 --- a/xla/service/cpu/runtime/while_thunk.h +++ b/xla/service/cpu/runtime/while_thunk.h @@ -16,7 +16,9 @@ limitations under the License. #ifndef XLA_SERVICE_CPU_RUNTIME_WHILE_THUNK_H_ #define XLA_SERVICE_CPU_RUNTIME_WHILE_THUNK_H_ +#include #include +#include #include "absl/status/statusor.h" #include "xla/service/buffer_assignment.h" @@ -37,7 +39,8 @@ class WhileThunk final : public Thunk { public: static absl::StatusOr> Create( Info info, BufferAllocation::Slice cond_buffer, - ThunkSequence cond_sequence, ThunkSequence body_sequence); + ThunkSequence cond_sequence, ThunkSequence body_sequence, + std::optional trip_count = std::nullopt); tsl::AsyncValueRef Execute(const ExecuteParams& params) final; @@ -46,19 +49,36 @@ class WhileThunk final : public Thunk { private: WhileThunk(Info info, BufferAllocation::Slice cond_buffer, - ThunkExecutor cond_executor, ThunkExecutor body_executor); + ThunkExecutor cond_executor, ThunkExecutor body_executor, + std::optional trip_count); + + tsl::AsyncValueRef ExecuteForLoop(const ExecuteParams& params, + int64_t trip_count); + + tsl::AsyncValueRef ExecuteWhileLoop(const ExecuteParams& params, + bool* condition); // If `cond` or `body` thunk sequence return unavailable async values, then // we execute the while loop asynchronously by chaining `Execute` calls via // `AndThen` callbacks. This execution mode adds significant overheads, so we // try to avoid it when possible and run everything in the caller thread. - tsl::AsyncValueRef ExecuteAsync( + + tsl::AsyncValueRef ExecuteAsyncForLoop( + const ExecuteParams& params, tsl::AsyncValueRef dependency, + int64_t loop_counter, int64_t trip_count); + + tsl::AsyncValueRef ExecuteAsyncWhileLoop( const ExecuteParams& params, tsl::AsyncValueRef dependency, bool* condition); BufferAllocation::Slice cond_buffer_; ThunkExecutor cond_executor_; ThunkExecutor body_executor_; + + // Statically known trip count. If available, WhileThunk::Execute will not + // execute `cond_executor_` and simply call `body_executor_` `trip_count` + // times (effectively converting while loop into a for loop). + std::optional trip_count_; }; } // namespace xla::cpu diff --git a/xla/service/cpu/runtime/while_thunk_test.cc b/xla/service/cpu/runtime/while_thunk_test.cc index 5da7202f7d9b7f..fc6a32c8bd715e 100644 --- a/xla/service/cpu/runtime/while_thunk_test.cc +++ b/xla/service/cpu/runtime/while_thunk_test.cc @@ -203,5 +203,51 @@ TEST(WhileThunkTest, NonBlockingExecute) { EXPECT_EQ(counter[0], kNumIterations); } +TEST(WhileThunkTest, NonBlockingExecuteWithTripCount) { + static constexpr size_t kNumIterations = 100; + + BufferAllocation pred_alloc(0, sizeof(char), 0); + BufferAllocation cnt_alloc(1, sizeof(int32_t), 0); + + BufferAllocation::Slice pred_slice(&pred_alloc, 0, sizeof(char)); + BufferAllocation::Slice cnt_slice(&cnt_alloc, 0, sizeof(int32_t)); + + std::vector buffers; + std::vector predicate = {false}; + std::vector counter = {0}; + + buffers.emplace_back(se::DeviceMemoryBase(predicate.data(), sizeof(char))); + buffers.emplace_back(se::DeviceMemoryBase(counter.data(), sizeof(int32_t))); + + BufferAllocations allocations(buffers); + + // We pass empty cond sequence, because we know the trip count, and check that + // predicate value is ignored (it is initialized to false) and body executed + // `kNumIterations` times. + ThunkSequence cond_sequence; + + ThunkSequence body_sequence; + body_sequence.push_back(std::make_unique(cnt_slice)); + + TF_ASSERT_OK_AND_ASSIGN( + auto thunk, WhileThunk::Create( + {"while"}, pred_slice, std::move(cond_sequence), + std::move(body_sequence), /*trip_count=*/kNumIterations)); + + tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "while-test", 8); + Eigen::ThreadPoolDevice device(thread_pool.AsEigenThreadPool(), + thread_pool.NumThreads()); + + Thunk::ExecuteParams params; + params.buffer_allocations = &allocations; + params.intra_op_threadpool = &device; + + auto execute_event = thunk->Execute(params); + tsl::BlockUntilReady(execute_event); + ASSERT_FALSE(execute_event.IsError()); + + EXPECT_EQ(counter[0], kNumIterations); +} + } // namespace } // namespace xla::cpu diff --git a/xla/service/cpu/thunk_emitter.cc b/xla/service/cpu/thunk_emitter.cc index 7d3c9c558021d4..49ed9d3d0b343c 100644 --- a/xla/service/cpu/thunk_emitter.cc +++ b/xla/service/cpu/thunk_emitter.cc @@ -755,9 +755,19 @@ absl::StatusOr ThunkEmitter::EmitWhileThunk( TF_ASSIGN_OR_RETURN(ThunkSequence body_thunk, EmitHloComputation(instruction->while_body())); + // Check if while loop has a statically known trip count. + TF_ASSIGN_OR_RETURN( + auto loop_config, + instruction->backend_config()); + + std::optional trip_count; + if (loop_config.has_known_trip_count()) { + trip_count = loop_config.known_trip_count().n(); + } + return ThunkSequence::Of(ThunkInfo(instruction), cond_buffer, std::move(cond_thunk), - std::move(body_thunk)); + std::move(body_thunk), trip_count); } absl::StatusOr ThunkEmitter::EmitDotThunk( From 4835a026950876338ea89dcdfed3389b665aa0c5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 30 Jul 2024 13:13:45 -0700 Subject: [PATCH 282/376] When decomposing mesh shapes into partial mesh shapes for iterative solving, take into account the collective performance cost (in the form of the alpha and beta values) corresponding to mesh axes, in addition to their sizes. PiperOrigin-RevId: 657705732 --- .../auto_sharding/auto_sharding.cc | 4 +++- .../auto_sharding/auto_sharding_util.cc | 24 ++++++++++++------- .../auto_sharding/auto_sharding_util.h | 9 ++++--- 3 files changed, 24 insertions(+), 13 deletions(-) diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/xla/hlo/experimental/auto_sharding/auto_sharding.cc index b0f17e9ffbddbf..273bded2d4d7c5 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -3982,7 +3982,9 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( std::vector> partial_mesh_shapes; if (option_.solve_nd_sharding_iteratively) { // Generate partial mesh shapes to optimize iteratively. - partial_mesh_shapes = spmd::DecomposeMeshShapes(option_.device_mesh_shape); + partial_mesh_shapes = spmd::DecomposeMeshShapes(option_.device_mesh_shape, + option_.device_mesh_alpha, + option_.device_mesh_beta); } else { partial_mesh_shapes = {option_.device_mesh_shape}; } diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc index 4b86f967ab7da2..cb8ec2c31f0abc 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc @@ -2079,26 +2079,32 @@ absl::StatusOr AdjustShardingsWithPartialMeshShape( } std::vector> DecomposeMeshShapes( - std::vector mesh_shape) { + const std::vector& mesh_shape, + const std::vector& mesh_alpha, + const std::vector& mesh_beta) { // Get the ranking order based on the size of each value. std::vector ranking_order; std::vector> partial_mesh_shapes; - std::vector> pairs(mesh_shape.size()); + std::vector> tuples( + mesh_shape.size()); for (size_t i = 0; i < mesh_shape.size(); i++) { - pairs[i] = {mesh_shape[i], i}; + // Here we prioritize the throughput term (beta) over the latency term + // (alpha), assuming that collectives are more often throughput-bound. This + // is currently somewhat of an arbitrary choice and can be changed. + tuples[i] = {mesh_beta[i], mesh_alpha[i], mesh_shape[i], i}; } // For vector of size 3, the sorted indices happen to be the same as their // rankings. mesh_shapes over 3 elements are not supported by AutoSharding. - std::sort(pairs.begin(), pairs.end(), - std::greater>()); + std::sort(tuples.begin(), tuples.end(), + std::greater>()); std::vector partial_mesh_shape(mesh_shape.size(), 1); // Starts from the largest dimension of mesh_shape. - for (size_t i = 0; i < pairs.size(); i++) { - if (pairs[i].first == 1) { - break; + for (size_t i = 0; i < tuples.size(); i++) { + if (std::get<2>(tuples[i]) == 1) { + continue; } - partial_mesh_shape[pairs[i].second] = pairs[i].first; + partial_mesh_shape[std::get<3>(tuples[i])] = std::get<2>(tuples[i]); // Needs to copy partial_mesh_shape. partial_mesh_shapes.push_back(partial_mesh_shape); } diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_util.h b/xla/hlo/experimental/auto_sharding/auto_sharding_util.h index a4ea23c922fc06..64749c677c89af 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_util.h +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_util.h @@ -629,10 +629,13 @@ inline bool AdjustShardingsWithPartialMeshShape( // Decompose mesh shapes into partial mesh shapes so that we can solve the auto // sharding problem iteratively. Returns partial mesh shapes with larger -// dimensions first. For example, input [1, 4, 2] returns [1, 4, 1] and [1, 4, -// 2]; input [4, 8, 2] returns [1, 8, 1], [4, 8, 1] and [ 4, 8, 2]. +// dimensions and more expensive collective costs first. For example, if all +// mesh axes all have collective costs, input [1, 4, 2] returns [1, 4, 1] and +// [1, 4, 2]; input [4, 8, 2] returns [1, 8, 1], [4, 8, 1] and [ 4, 8, 2]. std::vector> DecomposeMeshShapes( - std::vector mesh_shape); + const std::vector& mesh_shape, + const std::vector& mesh_alpha, + const std::vector& mesh_beta); bool OutputInputSameShapes(const HloInstruction* ins); From e3bcd73cba58284504849d26ac21bd795ca49cd6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 30 Jul 2024 13:24:47 -0700 Subject: [PATCH 283/376] Add support for input_output_alias in hlo_to_mhlo and mhlo_to_hlo. PiperOrigin-RevId: 657709636 --- xla/translate/hlo_to_mhlo/BUILD | 1 - .../hlo_to_mhlo/hlo_function_importer.cc | 55 +++++++++++ .../hlo_to_mhlo/hlo_function_importer.h | 7 ++ .../hlo_to_mhlo/hlo_module_importer.cc | 4 + .../hlo_to_mhlo/tests/module_attributes.hlo | 13 +++ xla/translate/mhlo_to_hlo/BUILD | 1 + .../mhlo_to_hlo/attribute_exporter.cc | 95 +++++++++++++++++++ .../mhlo_to_hlo/attribute_exporter.h | 4 + xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc | 7 ++ .../mhlo_to_hlo/tests/module_attributes.mlir | 42 ++++++++ 10 files changed, 228 insertions(+), 1 deletion(-) diff --git a/xla/translate/hlo_to_mhlo/BUILD b/xla/translate/hlo_to_mhlo/BUILD index 937094cb563f9c..d9c0ddc3cfc6c9 100644 --- a/xla/translate/hlo_to_mhlo/BUILD +++ b/xla/translate/hlo_to_mhlo/BUILD @@ -83,7 +83,6 @@ cc_library( "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/mlir_hlo", - "//xla/service:hlo_module_config", "//xla/service:hlo_proto_cc", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", diff --git a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc index e719db6c4bbeed..95e5bac0cbb374 100644 --- a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc +++ b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc @@ -56,6 +56,7 @@ limitations under the License. #include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" @@ -2470,6 +2471,60 @@ absl::Status HloFunctionImporter::ConvertShapeToMlirLayout( return Internal("Couldn't convert layout."); } +// std::string FrontendAttributesToString( +// const FrontendAttributes& frontend_attributes) { +// std::vector> sorted_attributes( +// frontend_attributes.map().begin(), frontend_attributes.map().end()); +// absl::c_sort(sorted_attributes); +// const auto formatter = [](std::string* out, +// const std::pair& item) +// { +// if (LexesAsJsonDict(item.second)) { +// absl::StrAppend(out, item.first, "=", item.second); +// } else { +// absl::StrAppend(out, item.first, "=\"", item.second, "\""); +// } +// }; +// return absl::StrFormat("{%s}", +// absl::StrJoin(sorted_attributes, ",", formatter)); +// } + +mlir::Attribute ConvertInputOutputAlias(const HloInputOutputAliasConfig& alias, + mlir::Builder* builder) { + llvm::SmallVector element_attrs; + alias.ForEachAlias([&](const ShapeIndex& output_index, + const HloInputOutputAliasConfig::Alias& alias) { + std::string kindToString; + switch (alias.kind) { + case HloInputOutputAliasConfig::AliasKind::kMayAlias: + kindToString = "may_alias"; + break; + case HloInputOutputAliasConfig::AliasKind::kMustAlias: + kindToString = "must_alias"; + break; + default: + kindToString = "undefined_alias"; + } + mlir::NamedAttribute alias_named_attributes[3] = { + builder->getNamedAttr( + "parameter_index", + builder->getDenseI64ArrayAttr(ArrayRef( + alias.parameter_index.begin(), alias.parameter_index.end()))), + builder->getNamedAttr("parameter_number", builder->getI64IntegerAttr( + alias.parameter_number)), + builder->getNamedAttr("kind", builder->getStringAttr(kindToString))}; + + mlir::NamedAttribute named_attributes[2] = { + builder->getNamedAttr("output_index", + builder->getDenseI64ArrayAttr(ArrayRef( + output_index.begin(), output_index.end()))), + builder->getNamedAttr( + "alias", builder->getDictionaryAttr(alias_named_attributes))}; + element_attrs.push_back(builder->getDictionaryAttr(named_attributes)); + }); + return builder->getArrayAttr(element_attrs); +} + mlir::Attribute ConvertSharding(const HloSharding& sharding, mlir::Builder* builder) { return builder->getStringAttr(sharding.ToString(/*include_metadata=*/true)); diff --git a/xla/translate/hlo_to_mhlo/hlo_function_importer.h b/xla/translate/hlo_to_mhlo/hlo_function_importer.h index cb3953990f4030..5c5a4e309bfbf6 100644 --- a/xla/translate/hlo_to_mhlo/hlo_function_importer.h +++ b/xla/translate/hlo_to_mhlo/hlo_function_importer.h @@ -33,6 +33,7 @@ limitations under the License. #include "mlir/IR/Operation.h" #include "mlir/IR/ValueRange.h" #include "xla/comparison_util.h" +#include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/hlo.pb.h" @@ -297,6 +298,12 @@ class HloFunctionImporter { bool flatten_computation_args_result_; }; +// Returns a StringAttr that carries a prettyprinted representation of the +// given HLO C++ input_output_alias_config. +// Always succeeds and returns a non-empty attribute. +mlir::Attribute ConvertInputOutputAlias(const HloInputOutputAliasConfig& alias, + mlir::Builder* builder); + // Returns a StringAttr that carries a prettyprinted representation of the // given HLO C++ sharding. // Always succeeds and returns a non-empty attribute. diff --git a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc index 1f2ea997c81e8a..76037442d52099 100644 --- a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc +++ b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc @@ -122,6 +122,10 @@ absl::Status HloModuleImporter::Import(const HloModule& hlo_module) { ConvertSharding(hlo_module.spmd_output_sharding(), &builder_)); } + module->setAttr("mhlo.input_output_alias", + ConvertInputOutputAlias( + hlo_module.input_output_alias_config(), &builder_)); + if (hlo_module.has_spmd_parameters_shardings()) { llvm::SmallVector parameter_shardings; parameter_shardings.reserve(hlo_module.spmd_parameters_shardings().size()); diff --git a/xla/translate/hlo_to_mhlo/tests/module_attributes.hlo b/xla/translate/hlo_to_mhlo/tests/module_attributes.hlo index 74eaaea5a0e8fe..d3433dce372cbf 100644 --- a/xla/translate/hlo_to_mhlo/tests/module_attributes.hlo +++ b/xla/translate/hlo_to_mhlo/tests/module_attributes.hlo @@ -5,6 +5,18 @@ # FLATTEN-CHECK-LABEL: module @main attributes { hlo_module { name: "main" + input_output_alias { + entries { + output_shape_index: 0 + parameter_number: 0 + kind: MAY_ALIAS + } + entries { + output_shape_index: 1 + parameter_number: 1 + kind: MAY_ALIAS + } + } entry_computation_name: "main.5" computations { name: "main.5" @@ -217,6 +229,7 @@ hlo_module { value: "attr_value" } } +# CHECK-SAME: mhlo.input_output_alias = [{alias = {kind = "may_alias", parameter_index = array, parameter_number = 0 : i64}, output_index = array}, {alias = {kind = "may_alias", parameter_index = array, parameter_number = 1 : i64}, output_index = array}] # CHECK-SAME: mhlo.is_dynamic = true is_dynamic: true # CHECK-SAME: mhlo.use_auto_spmd_partitioning = true diff --git a/xla/translate/mhlo_to_hlo/BUILD b/xla/translate/mhlo_to_hlo/BUILD index 92b7265298f6e7..3de8007804af4b 100644 --- a/xla/translate/mhlo_to_hlo/BUILD +++ b/xla/translate/mhlo_to_hlo/BUILD @@ -23,6 +23,7 @@ cc_library( "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", "//xla/mlir_hlo", "//xla/service:hlo_parser", "//xla/service:hlo_proto_cc", diff --git a/xla/translate/mhlo_to_hlo/attribute_exporter.cc b/xla/translate/mhlo_to_hlo/attribute_exporter.cc index a492861b28d831..8d54de6a5b9322 100644 --- a/xla/translate/mhlo_to_hlo/attribute_exporter.cc +++ b/xla/translate/mhlo_to_hlo/attribute_exporter.cc @@ -185,4 +185,99 @@ std::optional ConvertSharding(llvm::StringRef sharding) { return std::nullopt; } +std::optional ConvertInputOutputAlias( + llvm::ArrayRef aliasing) { + if (aliasing.empty()) return std::nullopt; + + xla::HloInputOutputAliasProto input_output_alias_proto; + for (auto attr : aliasing) { + auto entry_attr = mlir::cast(attr); + auto alias_attr = mlir::cast(entry_attr.get("alias")); + mlir::ArrayRef output_index = + mlir::cast(entry_attr.get("output_index")) + .asArrayRef(); + mlir::ArrayRef parameter_index = + mlir::cast(alias_attr.get("parameter_index")) + .asArrayRef(); + HloInputOutputAliasProto::AliasEntryProto entry; + entry.mutable_output_shape_index()->Add(output_index.begin(), + output_index.end()); + entry.set_parameter_number( + mlir::cast(alias_attr.get("parameter_number")) + .getInt()); + entry.mutable_parameter_shape_index()->Add(parameter_index.begin(), + parameter_index.end()); + mlir::StringRef kind = + mlir::cast(alias_attr.get("kind")).getValue(); + if (kind == "may_alias") + entry.set_kind(xla::Kind::MAY_ALIAS); + else if (kind == "must_alias") + entry.set_kind(xla::Kind::MUST_ALIAS); + else + entry.set_kind(xla::Kind::UNDEFINED_ALIAS); + input_output_alias_proto.add_entries()->Swap(&entry); + } + return input_output_alias_proto; +} + +DotDimensionNumbers ConvertDotDimensionNumbers( + mlir::mhlo::DotDimensionNumbersAttr input) { + DotDimensionNumbers output; + + for (auto v : input.getLhsBatchingDimensions()) { + output.add_lhs_batch_dimensions(v); + } + + for (auto v : input.getRhsBatchingDimensions()) { + output.add_rhs_batch_dimensions(v); + } + + for (auto v : input.getLhsContractingDimensions()) { + output.add_lhs_contracting_dimensions(v); + } + + for (auto v : input.getRhsContractingDimensions()) { + output.add_rhs_contracting_dimensions(v); + } + + return output; +} + +DotDimensionNumbers ConvertDotDimensionNumbers( + absl::Span lhs_batch, absl::Span lhs_contract, + absl::Span rhs_batch, + absl::Span rhs_contract) { + DotDimensionNumbers output; + for (auto v : lhs_batch) { + output.add_lhs_batch_dimensions(v); + } + + for (auto v : rhs_batch) { + output.add_rhs_batch_dimensions(v); + } + + for (auto v : lhs_contract) { + output.add_lhs_contracting_dimensions(v); + } + + for (auto v : rhs_contract) { + output.add_rhs_contracting_dimensions(v); + } + + return output; +} + +absl::StatusOr> ConvertMlirArrayAttrToInt64Array( + const mlir::ArrayAttr& array) { + int rank = array.size(); + std::vector converted_array(rank); + for (int i = 0; i < rank; i++) { + mlir::IntegerAttr attr = mlir::dyn_cast(array[i]); + if (!attr) { + return Internal("Type Error: Expected layout integer attribute"); + } + converted_array[i] = attr.getInt(); + } + return converted_array; +} } // namespace xla diff --git a/xla/translate/mhlo_to_hlo/attribute_exporter.h b/xla/translate/mhlo_to_hlo/attribute_exporter.h index e0e0dc9821d21e..49daefe6935650 100644 --- a/xla/translate/mhlo_to_hlo/attribute_exporter.h +++ b/xla/translate/mhlo_to_hlo/attribute_exporter.h @@ -20,6 +20,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "mlir/IR/Attributes.h" +#include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/hlo.pb.h" #include "xla/shape_util.h" @@ -59,5 +60,8 @@ ConvertOutputOperandAliasing(mlir::ArrayAttr aliasArrayAttr); // Will fail if both attempts at parsing failed. std::optional ConvertSharding(mlir::StringRef sharding); +std::optional ConvertInputOutputAlias( + llvm::ArrayRef aliasing); + } // namespace xla #endif // XLA_TRANSLATE_MHLO_TO_HLO_ATTRIBUTE_EXPORTER_H_ diff --git a/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc index 623080e11fd60d..90eb1a902127bc 100644 --- a/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc +++ b/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc @@ -3736,6 +3736,13 @@ absl::Status ConvertMlirHloToHlo(mlir::ModuleOp module, *hlo_module.mutable_spmd_output_sharding() = *xla::ConvertSharding(spmd_output_sharding.getValue()); } + if (auto input_output_alias = + module->getAttrOfType("mhlo.input_output_alias")) { + if (std::optional input_output_alias_proto = + xla::ConvertInputOutputAlias(input_output_alias.getValue())) { + *hlo_module.mutable_input_output_alias() = *input_output_alias_proto; + } + } if (auto spmd_parameters_sharding = module->getAttrOfType( "mhlo.spmd_parameters_shardings")) { for (const auto& sharding : spmd_parameters_sharding.getValue()) { diff --git a/xla/translate/mhlo_to_hlo/tests/module_attributes.mlir b/xla/translate/mhlo_to_hlo/tests/module_attributes.mlir index 049456bb09e6f7..6ad08374e5d2e6 100644 --- a/xla/translate/mhlo_to_hlo/tests/module_attributes.mlir +++ b/xla/translate/mhlo_to_hlo/tests/module_attributes.mlir @@ -100,3 +100,45 @@ module @ModuleWithFrontendAttributes attributes { func.return %arg0 : tensor<1xf32> } } + + + +// ----- + +module attributes { +// CHECK: input_output_alias { +// CHECK-NEXT: entries { +// CHECK-NEXT: output_shape_index: 0 +// CHECK-NEXT: kind: MAY_ALIAS +// CHECK-NEXT: } +// CHECK-NEXT: entries { +// CHECK-NEXT: output_shape_index: 1 +// CHECK-NEXT: parameter_number: 1 +// CHECK-NEXT: kind: MAY_ALIAS +// CHECK-NEXT: } +// CHECK-NEXT: } + mhlo.input_output_alias = [ + { + alias = + { + kind = "may_alias", + parameter_index = array, + parameter_number = 0 : i64 + }, + output_index = array + }, + { + alias = + { + kind = "may_alias", + parameter_index = array, + parameter_number = 1 : i64 + }, + output_index = array + } +] +} { + func.func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32> ) -> (tensor<1xf32>, tensor<1xf32>) { + func.return %arg0, %arg1: tensor<1xf32>, tensor<1xf32> + } +} \ No newline at end of file From 6bb25b3868f8bcc82c41b34c2b1e87694e41801f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 30 Jul 2024 13:26:54 -0700 Subject: [PATCH 284/376] TupleSimplifier needs to update schedule if there is a schedule. Fix the wrong module passed in hlo_runner_pjrt.cc. PiperOrigin-RevId: 657710296 --- xla/service/BUILD | 1 + xla/service/hlo_runner_pjrt.cc | 4 +++- xla/service/tuple_simplifier.cc | 6 ++++++ 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/xla/service/BUILD b/xla/service/BUILD index 21f3e28807fd74..32dae31309906a 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -4022,6 +4022,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@tsl//tsl/platform:errors", "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/service/hlo_runner_pjrt.cc b/xla/service/hlo_runner_pjrt.cc index ccb239aee5f351..3965bf61870f3a 100644 --- a/xla/service/hlo_runner_pjrt.cc +++ b/xla/service/hlo_runner_pjrt.cc @@ -369,7 +369,9 @@ absl::StatusOr> HloRunnerPjRt::CreateExecutable( CreateExecutable(module.get(), compile_options)); auto executable = std::make_unique( - std::shared_ptr(std::move(module)), pjrt_executable.release()); + std::shared_ptr( + std::move(pjrt_executable->GetHloModules().value()[0])), + pjrt_executable.release()); std::unique_ptr exec = static_cast>(executable.release()); diff --git a/xla/service/tuple_simplifier.cc b/xla/service/tuple_simplifier.cc index ae033b79ba917a..3557b076df0ef6 100644 --- a/xla/service/tuple_simplifier.cc +++ b/xla/service/tuple_simplifier.cc @@ -24,6 +24,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/shape_util.h" +#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" namespace xla { @@ -116,6 +117,11 @@ absl::StatusOr TupleSimplifier::Run( } } } + + if (module->has_schedule()) { + TF_RETURN_IF_ERROR(module->schedule().Update()); + } + return changed; } From d35129593a704103e885a5f3a2fc9f0b7076bbdd Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Tue, 30 Jul 2024 14:20:38 -0700 Subject: [PATCH 285/376] Internal copybara config change PiperOrigin-RevId: 657728755 --- third_party/tsl/tsl/profiler/rpc/client/profiler_client.cc | 2 +- third_party/tsl/tsl/profiler/rpc/profiler_server.cc | 2 +- third_party/tsl/tsl/profiler/rpc/profiler_service_impl.cc | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/third_party/tsl/tsl/profiler/rpc/client/profiler_client.cc b/third_party/tsl/tsl/profiler/rpc/client/profiler_client.cc index 8bc9a1986effb7..47d8638005931c 100644 --- a/third_party/tsl/tsl/profiler/rpc/client/profiler_client.cc +++ b/third_party/tsl/tsl/profiler/rpc/client/profiler_client.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include -#include "grpcpp/grpcpp.h" #include "absl/memory/memory.h" #include "absl/time/clock.h" #include "absl/time/time.h" +#include "grpcpp/grpcpp.h" // IWYU pragma: keep #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" diff --git a/third_party/tsl/tsl/profiler/rpc/profiler_server.cc b/third_party/tsl/tsl/profiler/rpc/profiler_server.cc index 8b598fa450cdc6..f619c1346a0af4 100644 --- a/third_party/tsl/tsl/profiler/rpc/profiler_server.cc +++ b/third_party/tsl/tsl/profiler/rpc/profiler_server.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include -#include "grpcpp/grpcpp.h" #include "absl/strings/str_cat.h" +#include "grpcpp/grpcpp.h" // IWYU pragma: keep #include "tsl/platform/logging.h" #include "tsl/platform/types.h" #include "tsl/profiler/protobuf/profiler_service.grpc.pb.h" diff --git a/third_party/tsl/tsl/profiler/rpc/profiler_service_impl.cc b/third_party/tsl/tsl/profiler/rpc/profiler_service_impl.cc index 8deee9782aa9fe..efb544ebdf2278 100644 --- a/third_party/tsl/tsl/profiler/rpc/profiler_service_impl.cc +++ b/third_party/tsl/tsl/profiler/rpc/profiler_service_impl.cc @@ -17,9 +17,9 @@ limitations under the License. #include -#include "grpcpp/support/status.h" #include "absl/container/flat_hash_map.h" #include "absl/strings/str_replace.h" +#include "grpcpp/support/status.h" #include "tsl/platform/env.h" #include "tsl/platform/env_time.h" #include "tsl/platform/errors.h" From 78418c6a4a60cff6b3ad538de43602d23dc2aaf5 Mon Sep 17 00:00:00 2001 From: Carlos Guia Date: Tue, 30 Jul 2024 14:23:11 -0700 Subject: [PATCH 286/376] [XLA] Change attribute name skip-simplify-while-loops/trip-count-one to skip-simplify-while-loops_trip-count-one because / is not valid for attribute names. PiperOrigin-RevId: 657729693 --- xla/service/while_loop_simplifier.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xla/service/while_loop_simplifier.cc b/xla/service/while_loop_simplifier.cc index 2ca8d1884a80bb..4cc642c3994e18 100644 --- a/xla/service/while_loop_simplifier.cc +++ b/xla/service/while_loop_simplifier.cc @@ -954,8 +954,8 @@ static absl::StatusOr TryRemoveWhileLoop(HloInstruction* while_op) { // inline the call. const auto& attrs = while_op->frontend_attributes().map(); bool skip_trip_count_one_simplification = - attrs.contains("skip-simplify-while-loops/trip-count-one") && - (attrs.at("skip-simplify-while-loops/trip-count-one") == "true"); + attrs.contains("skip-simplify-while-loops_trip-count-one") && + (attrs.at("skip-simplify-while-loops_trip-count-one") == "true"); if (trip_count && *trip_count == 1 && !skip_trip_count_one_simplification) { // Do not simplify the loop away when there is a side-effectful op, // otherwise the infeed op may not inherit the data dependency from From 11fc7d928a21c9a6365036142862fd93795bc12b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 30 Jul 2024 15:32:32 -0700 Subject: [PATCH 287/376] Export entry parameter layout tiles PiperOrigin-RevId: 657752363 --- .../hlo_to_mhlo/hlo_function_importer.cc | 72 +++++++++++++------ .../tests/entry_computation_layout.hlo | 12 +++- 2 files changed, 62 insertions(+), 22 deletions(-) diff --git a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc index 95e5bac0cbb374..c641270aed10c8 100644 --- a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc +++ b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc @@ -337,26 +337,42 @@ static bool HasCustomLayout(const Shape& shape) { shape.layout() != LayoutUtil::GetDefaultLayoutForShape(shape); } -static mlir::Attribute GetLayoutAttribute(mlir::Builder& b, - const Shape& shape) { +static std::pair GetLayoutAttribute( + mlir::Builder& b, const Shape& shape, + std::optional maybe_layout = std::nullopt) { if (shape.IsTuple()) { llvm::SmallVector element_attrs; + llvm::SmallVector tile_attrs; for (const auto& tuple_shape : shape.tuple_shapes()) { - element_attrs.push_back(GetLayoutAttribute(b, tuple_shape)); - } - return b.getArrayAttr(element_attrs); + // TODO here we do not disect the layout of a tuple into sublayouts. + // Presently ShapeLayout cannot represent an explicit layout for a tuple + // type so this should never occur. However, if this function were to + // be used in another context where this assumption were to be lifted. + // users should be aware of this limitation which will use the default + // layout for tuple subshapes. + std::pair inner = + GetLayoutAttribute(b, tuple_shape); + element_attrs.push_back(inner.first); + tile_attrs.push_back(inner.second); + } + return std::make_pair((mlir::Attribute)b.getArrayAttr(element_attrs), + b.getArrayAttr(tile_attrs)); } - llvm::SmallVector layout; - if (shape.has_layout()) { - layout = {shape.layout().minor_to_major().begin(), - shape.layout().minor_to_major().end()}; - } else { - Layout layout_for_shape = LayoutUtil::GetDefaultLayoutForShape(shape); - layout = {layout_for_shape.minor_to_major().begin(), - layout_for_shape.minor_to_major().end()}; + Layout layout = maybe_layout.value_or( + shape.has_layout() ? shape.layout() + : LayoutUtil::GetDefaultLayoutForShape(shape)); + + llvm::SmallVector vec_of_tiles; + for (const Tile& tile : layout.tiles()) { + llvm::SmallVector tile_vec = {tile.dimensions().begin(), + tile.dimensions().end()}; + vec_of_tiles.push_back(b.getIndexTensorAttr(tile_vec)); } - return b.getIndexTensorAttr(layout); + llvm::SmallVector layout_vec = {layout.minor_to_major().begin(), + layout.minor_to_major().end()}; + return std::make_pair(b.getIndexTensorAttr(layout_vec), + b.getArrayAttr(vec_of_tiles)); } mlir::Attribute GetFrontendAttributes(mlir::Builder& b, @@ -598,24 +614,38 @@ absl::StatusOr HloFunctionImporter::ImportAsFunc( if (computation.IsEntryComputation()) { const auto& computation_layout = computation.parent()->entry_computation_layout(); - if (computation_layout.LayoutIsSet()) { + if (computation_layout.LayoutIsSet() && + !computation_layout.result_layout().shape().IsTuple()) { if (HasCustomLayout(computation_layout.result_layout().shape())) { - function->setAttr( - "xla_entry_computation_result_layout", + std::pair layout_attrs = GetLayoutAttribute(*builder_, - computation_layout.result_layout().shape())); + computation_layout.result_layout().shape(), + computation_layout.result_layout().layout()); + function->setAttr("xla_entry_computation_result_layout", + layout_attrs.first); + function->setAttr("xla_entry_computation_result_tiles", + layout_attrs.second); } if (llvm::any_of(computation_layout.parameter_layouts(), [](const ShapeLayout& shape) { return HasCustomLayout(shape.shape()); })) { llvm::SmallVector parameter_layouts; + llvm::SmallVector parameter_tiles; for (auto& layout : computation_layout.parameter_layouts()) { - parameter_layouts.push_back( - GetLayoutAttribute(*builder_, layout.shape())); + std::pair layout_attrs = + GetLayoutAttribute( + *builder_, layout.shape(), + (layout.LayoutIsSet() && !layout.shape().IsTuple()) + ? std::optional(layout.layout()) + : std::nullopt); + parameter_layouts.push_back(layout_attrs.first); + parameter_tiles.push_back(layout_attrs.second); } function->setAttr("xla_entry_computation_parameter_layouts", builder_->getArrayAttr(parameter_layouts)); + function->setAttr("xla_entry_computation_parameter_tiles", + builder_->getArrayAttr(parameter_tiles)); } } } @@ -2441,7 +2471,7 @@ void HloFunctionImporter::SetLayoutForMlir(mlir::Operation* op, const Shape& shape, llvm::StringRef attr_name) { mlir::Builder b(op->getContext()); - op->setAttr(attr_name, GetLayoutAttribute(b, shape)); + op->setAttr(attr_name, GetLayoutAttribute(b, shape).first); } absl::Status HloFunctionImporter::ConvertShapeToMlirLayout( diff --git a/xla/translate/hlo_to_mhlo/tests/entry_computation_layout.hlo b/xla/translate/hlo_to_mhlo/tests/entry_computation_layout.hlo index fa99b77174cb53..253639908966b5 100644 --- a/xla/translate/hlo_to_mhlo/tests/entry_computation_layout.hlo +++ b/xla/translate/hlo_to_mhlo/tests/entry_computation_layout.hlo @@ -1,11 +1,12 @@ // RUN: xla-translate -hlo-text-to-mlir-hlo -hlo-import-all-computations %s -o - | FileCheck %s -HloModule entry, entry_computation_layout={(f32[2,3,4]{0,1,2}, f32[2,3,4]{1,2,0}, (f32[1,2]{1,0}, f32[1,2]{0,1}))->f32[2,3,4]{2,0,1}} +HloModule entry, entry_computation_layout={(f32[2,3,4]{0,1,2}, f32[2,3,4]{1,2,0}, (f32[1,2]{1,0}, f32[1,2]{0,1}), s32[]{:T(128)})->f32[2,3,4]{2,0,1}} ENTRY entry { p0 = f32[2,3,4]{2,1,0} parameter(0) p1 = f32[2,3,4]{2,1,0} parameter(1) p2 = (f32[1,2]{1,0}, f32[1,2]{0,1}) parameter(2) + p3 = s32[]{:T(128)} parameter(3) ROOT add = f32[2,3,4]{2,1,0} add(p0, p1) } @@ -15,4 +16,13 @@ ENTRY entry { // CHECK-SAME: dense<[1, 2, 0]> // CHECK-SAME: [dense<[1, 0]> // CHECK-SAME: , dense<[0, 1]> +// CHECK-SAME: xla_entry_computation_parameter_tiles = [ +// CHECK-SAME: [] +// CHECK-SAME: [], +// CHECK-SAME: [ +// CHECK-SAME: [], +// CHECK-SAME: [] +// CHECK-SAME: ], +// CHECK-SAME: [dense<128> : tensor<1xindex>] +// CHECK-SAME: ] // CHECK-SAME: xla_entry_computation_result_layout = dense<[2, 0, 1]> From 744fe3eeef70107e35eec8f90d31ab80653aa57c Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 30 Jul 2024 15:58:22 -0700 Subject: [PATCH 288/376] [xla:cpu] Micro-optimizations for BufferAllocations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit name old cpu/op new cpu/op delta BM_SelectAndScatterF32/128/process_time 420µs ± 1% 401µs ± 1% -4.67% BM_SelectAndScatterF32/256/process_time 1.73ms ± 2% 1.65ms ± 3% -4.48% BM_SelectAndScatterF32/512/process_time 7.73ms ± 1% 7.41ms ± 2% -4.14% name old time/op new time/op delta BM_SelectAndScatterF32/128/process_time 421µs ± 1% 401µs ± 1% -4.69% BM_SelectAndScatterF32/256/process_time 1.73ms ± 2% 1.65ms ± 3% -4.57% BM_SelectAndScatterF32/512/process_time 7.34ms ± 1% 7.02ms ± 2% -4.46% name old INSTRUCTIONS/op new INSTRUCTIONS/op delta BM_SelectAndScatterF32/128/process_time 4.55M ± 0% 4.20M ± 0% -7.51% BM_SelectAndScatterF32/256/process_time 18.4M ± 0% 17.0M ± 0% -7.54% BM_SelectAndScatterF32/512/process_time 74.9M ± 0% 69.3M ± 0% -7.48% PiperOrigin-RevId: 657760541 --- xla/service/cpu/runtime/buffer_allocations.h | 21 +++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/xla/service/cpu/runtime/buffer_allocations.h b/xla/service/cpu/runtime/buffer_allocations.h index fe26d441359b76..4d757261c5a39e 100644 --- a/xla/service/cpu/runtime/buffer_allocations.h +++ b/xla/service/cpu/runtime/buffer_allocations.h @@ -45,7 +45,7 @@ class BufferAllocations { // Same as above, but also adjusts the returned address for the offset and // size contained in the given slice. absl::StatusOr GetDeviceAddress( - const BufferAllocation::Slice& slice) const; + BufferAllocation::Slice slice) const; // Unchecked version of `GetDeviceAddress` that does not check the buffer // index and assumes it is valid. @@ -55,16 +55,19 @@ class BufferAllocations { // Unchecked version of `GetDeviceAddress` that does not check the slice // buffer index, offset and size and assumes they all are valid. se::DeviceMemoryBase GetDeviceAddressUnchecked( - const BufferAllocation::Slice& slice) const; + BufferAllocation::Slice slice) const; private: std::vector buffers_; + se::DeviceMemoryBase* buffers_data_; // buffers_.data() size_t num_buffers_; }; inline BufferAllocations::BufferAllocations( absl::Span buffers) - : buffers_(buffers.size()), num_buffers_(buffers_.size()) { + : buffers_(buffers.size()), + buffers_data_(buffers_.data()), + num_buffers_(buffers_.size()) { for (size_t i = 0; i < buffers.size(); ++i) { buffers_[i] = buffers[i].AsDeviceMemoryBase(); } @@ -82,8 +85,7 @@ BufferAllocations::GetDeviceAddress(BufferAllocation::Index index) const { } inline ABSL_ATTRIBUTE_ALWAYS_INLINE absl::StatusOr -BufferAllocations::GetDeviceAddress( - const BufferAllocation::Slice& slice) const { +BufferAllocations::GetDeviceAddress(BufferAllocation::Slice slice) const { // Handle empty slices explicitly and return a null pointer device memory to // guarantee that we do not accidentally write through the empty slice which // would hide a real bug in the code. @@ -97,7 +99,7 @@ BufferAllocations::GetDeviceAddress( "Invalid buffer index %d. It must be in the range [0, %d)", index, num_buffers_); } - const se::DeviceMemoryBase& base = buffers_[index]; + const se::DeviceMemoryBase& base = buffers_data_[index]; int64_t offset = slice.offset(); int64_t extent = offset + slice.size(); @@ -125,15 +127,16 @@ BufferAllocations::GetDeviceAddress( inline ABSL_ATTRIBUTE_ALWAYS_INLINE se::DeviceMemoryBase BufferAllocations::GetDeviceAddressUnchecked( BufferAllocation::Index buffer_index) const { - return buffers_[buffer_index]; + return buffers_data_[buffer_index]; } // Unchecked version of `GetDeviceAddress` that does not check the slice // buffer index, offset and size and assumes they are valid. inline ABSL_ATTRIBUTE_ALWAYS_INLINE se::DeviceMemoryBase BufferAllocations::GetDeviceAddressUnchecked( - const BufferAllocation::Slice& slice) const { - return buffers_[slice.index()].GetByteSlice(slice.offset(), slice.size()); + BufferAllocation::Slice slice) const { + return buffers_data_[slice.index()].GetByteSlice(slice.offset(), + slice.size()); } } // namespace xla::cpu From 48c7b643f5c97034541c1489a5bf5d3ea2ba018e Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Tue, 30 Jul 2024 16:03:57 -0700 Subject: [PATCH 289/376] Update users of `status_test_util` to use the new location in `xla/tsl` PiperOrigin-RevId: 657762370 --- xla/BUILD | 6 +- xla/backends/profiler/cpu/BUILD | 2 +- xla/backends/profiler/cpu/host_tracer_test.cc | 2 +- xla/client/lib/BUILD | 2 +- xla/client/lib/math_test.cc | 2 +- xla/examples/axpy/BUILD | 2 +- xla/examples/axpy/stablehlo_compile_test.cc | 2 +- xla/ffi/BUILD | 10 +- xla/ffi/api/BUILD | 2 +- xla/ffi/api/ffi_test.cc | 2 +- xla/ffi/call_frame_test.cc | 2 +- xla/ffi/execution_context_test.cc | 2 +- xla/ffi/execution_state_test.cc | 2 +- xla/ffi/ffi_test.cc | 2 +- xla/hlo/experimental/auto_sharding/BUILD | 2 +- .../auto_sharding/auto_sharding_test.cc | 2 +- xla/hlo/ir/BUILD | 2 +- xla/hlo/ir/backend_config_test.cc | 2 +- xla/hlo/transforms/BUILD | 2 +- .../transforms/hlo_constant_splitter_test.cc | 2 +- xla/hlo/utils/BUILD | 2 +- xla/hlo/utils/hlo_live_range_test.cc | 2 +- xla/literal_comparison_test.cc | 2 +- xla/literal_test.cc | 2 +- xla/mlir/utils/BUILD | 2 +- xla/mlir/utils/error_util_test.cc | 2 +- xla/pjrt/BUILD | 8 +- xla/pjrt/c/BUILD | 2 +- xla/pjrt/c/pjrt_c_api_helpers_test.cc | 2 +- xla/pjrt/cpu/BUILD | 4 +- xla/pjrt/cpu/cpu_client_test.cc | 2 +- xla/pjrt/cpu/gloo_collectives_test.cc | 2 +- xla/pjrt/distributed/BUILD | 4 +- xla/pjrt/distributed/client_server_test.cc | 2 +- xla/pjrt/distributed/topology_util_test.cc | 2 +- xla/pjrt/gpu/BUILD | 2 +- xla/pjrt/gpu/se_gpu_pjrt_client_test.cc | 2 +- xla/pjrt/host_callback_test.cc | 2 +- xla/pjrt/pjrt_api_test.cc | 2 +- xla/pjrt/pjrt_c_api_client_test.cc | 2 +- xla/pjrt/pjrt_stream_executor_client_test.cc | 2 +- xla/python/ifrt/BUILD | 14 +- xla/python/ifrt/array_impl_test_lib.cc | 2 +- .../ifrt/custom_call_program_serdes_test.cc | 2 +- xla/python/ifrt/future_test.cc | 2 +- xla/python/ifrt/ir/tests/BUILD | 2 +- .../ifrt/ir/tests/executable_impl_test_lib.cc | 2 +- xla/python/ifrt/plugin_program_serdes_test.cc | 2 +- xla/python/ifrt/remap_impl_test_lib.cc | 2 +- xla/python/ifrt/support/BUILD | 2 +- .../ifrt/support/sharding_conversions_test.cc | 2 +- xla/python/ifrt/test_util.h | 2 +- xla/python/ifrt/tuple_impl_test_lib.cc | 2 +- xla/python/ifrt_proxy/server/BUILD | 2 +- .../ifrt_proxy/server/ifrt_backend_test.cc | 2 +- xla/python/pjrt_ifrt/BUILD | 4 +- .../pjrt_ifrt/basic_string_array_test.cc | 2 +- .../pjrt_ifrt/xla_executable_impl_test_lib.cc | 2 +- xla/service/BUILD | 142 +++++++++--------- xla/service/algebraic_simplifier_test.cc | 2 +- xla/service/all_reduce_simplifier_test.cc | 2 +- xla/service/all_reduce_splitter_test.cc | 2 +- xla/service/async_collective_creator_test.cc | 2 +- xla/service/bitcast_dtypes_expander_test.cc | 2 +- xla/service/buffer_assignment_test.cc | 2 +- xla/service/call_graph_test.cc | 2 +- xla/service/call_inliner_test.cc | 2 +- xla/service/collective_ops_utils_test.cc | 2 +- xla/service/collective_quantizer_test.cc | 2 +- ...ollective_transformation_reorderer_test.cc | 2 +- xla/service/compilation_environments_test.cc | 2 +- xla/service/compiler_test.cc | 2 +- xla/service/conditional_canonicalizer_test.cc | 2 +- xla/service/conditional_code_motion_test.cc | 2 +- xla/service/conditional_simplifier_test.cc | 2 +- .../convert_async_collectives_to_sync_test.cc | 2 +- xla/service/cpu/BUILD | 4 +- xla/service/cpu/runtime/BUILD | 20 +-- xla/service/cpu/tests/BUILD | 2 +- .../cpu/tests/cpu_spmd_compile_test.cc | 2 +- xla/service/cpu/xfeed_manager_test.cc | 2 +- xla/service/cpu_gpu_shape_verifier_test.cc | 2 +- .../dfs_hlo_visitor_with_default_test.cc | 2 +- xla/service/dot_merger_test.cc | 2 +- xla/service/dump_test.cc | 2 +- .../dynamic_dimension_inference_test.cc | 2 +- .../dynamic_dimension_simplifier_test.cc | 2 +- xla/service/dynamic_padder_test.cc | 2 +- xla/service/dynamic_parameter_binding_test.cc | 2 +- xla/service/flatten_call_graph_test.cc | 2 +- xla/service/generic_transfer_manager_test.cc | 2 +- xla/service/gpu/BUILD | 42 +++--- .../gpu/alias_passthrough_params_test.cc | 2 +- xla/service/gpu/autotuner_util_test.cc | 2 +- .../gpu/command_buffer_scheduling_test.cc | 2 +- xla/service/gpu/conv_algorithm_picker_test.cc | 2 +- .../gpu/cudnn_fused_conv_rewriter_test.cc | 2 +- .../gpu/cudnn_fused_mha_rewriter_test.cc | 2 +- .../gpu/cudnn_simplify_padding_test.cc | 2 +- xla/service/gpu/custom_call_test.cc | 2 +- xla/service/gpu/fusions/BUILD | 14 +- .../gpu/fusions/concatenate_mlir_test.cc | 2 +- xla/service/gpu/fusions/cudnn_test.cc | 2 +- ...in_place_dynamic_update_slice_mlir_test.cc | 2 +- xla/service/gpu/fusions/loop_mlir_test.cc | 2 +- xla/service/gpu/fusions/mlir/BUILD | 2 +- .../mlir/elemental_hlo_to_mlir_test.cc | 2 +- .../gpu/fusions/reduction_mlir_test.cc | 2 +- xla/service/gpu/fusions/scatter_mlir_test.cc | 2 +- .../gpu/fusions/transpose_mlir_test.cc | 2 +- xla/service/gpu/fusions/triton/BUILD | 8 +- ...riton_fusion_emitter_device_legacy_test.cc | 2 +- .../triton_fusion_emitter_device_test.cc | 2 +- .../triton_fusion_emitter_mem_utils_test.cc | 2 +- .../triton/triton_support_legacy_test.cc | 2 +- xla/service/gpu/gemm_algorithm_picker_test.cc | 2 +- xla/service/gpu/gemm_fusion_autotuner_test.cc | 2 +- xla/service/gpu/gpu_compiler_test.cc | 2 +- ..._convert_async_collectives_to_sync_test.cc | 2 +- .../gpu/gpu_latency_hiding_scheduler_test.cc | 2 +- xla/service/gpu/gpu_offloading_test.cc | 2 +- xla/service/gpu/hlo_fusion_stats_test.cc | 2 +- .../gpu/horizontal_loop_fusion_test.cc | 2 +- xla/service/gpu/kernel_reuse_cache_test.cc | 2 +- xla/service/gpu/kernels/BUILD | 4 +- .../cutlass_gemm_custom_kernel_test.cc | 2 +- .../gpu/kernels/topk_custom_kernel_test.cc | 2 +- xla/service/gpu/model/BUILD | 2 +- .../gpu/model/symbolic_tile_analysis_test.cc | 2 +- xla/service/gpu/nvptx_compiler_test.cc | 2 +- xla/service/gpu/runtime/BUILD | 6 +- .../gpu/runtime/command_buffer_cmd_test.cc | 2 +- .../gpu/runtime/command_buffer_thunk_test.cc | 2 +- .../gpu/runtime/dynamic_slice_thunk_test.cc | 2 +- xla/service/gpu/split_k_gemm_rewriter_test.cc | 2 +- xla/service/gpu/tests/BUILD | 8 +- xla/service/gpu/tests/gemm_rewrite_test.cc | 2 +- xla/service/gpu/tests/gpu_sparse_dot_test.cc | 2 +- .../gpu/tests/gpu_spmd_e2e_compile_test.cc | 2 +- .../gpu/tests/simple_optimization_test.cc | 2 +- .../triton_fusion_numerics_verifier_test.cc | 2 +- .../gpu_compilation_environment_test.cc | 2 +- xla/service/heap_simulator/BUILD | 2 +- .../heap_simulator/heap_simulator_test.cc | 2 +- xla/service/hlo_alias_analysis_test.cc | 2 +- xla/service/hlo_dataflow_analysis_test.cc | 2 +- xla/service/hlo_dce_test.cc | 2 +- xla/service/hlo_domain_test.cc | 2 +- .../hlo_input_output_alias_config_test.cc | 2 +- xla/service/hlo_instruction_test.cc | 2 +- xla/service/hlo_memory_scheduler_test.cc | 2 +- xla/service/hlo_module_dce_test.cc | 2 +- xla/service/hlo_module_group_test.cc | 2 +- xla/service/hlo_module_test.cc | 2 +- xla/service/hlo_ordering_test.cc | 2 +- xla/service/hlo_parser_test.cc | 2 +- xla/service/hlo_pass_pipeline_test.cc | 2 +- xla/service/hlo_rematerialization_test.cc | 2 +- xla/service/hlo_replication_analysis_test.cc | 2 +- xla/service/hlo_schedule_test.cc | 2 +- xla/service/hlo_verifier_test.cc | 2 +- .../host_memory_transfer_asyncifier_test.cc | 2 +- xla/service/host_offload_legalize_test.cc | 2 +- xla/service/host_offloader_test.cc | 2 +- xla/service/host_offloading_prepare_test.cc | 2 +- xla/service/layout_assignment_test.cc | 2 +- .../mapped_ptr_container_sorter_test.cc | 2 +- xla/service/memory_space_assignment/BUILD | 8 +- .../cost_analysis_test.cc | 2 +- .../memory_bound_loop_optimizer_test.cc | 2 +- .../memory_space_assignment_test.cc | 2 +- .../memory_space_assignment/simulator_test.cc | 2 +- xla/service/memory_space_propagation_test.cc | 2 +- .../profile_guided_latency_estimator_test.cc | 2 +- xla/service/real_imag_expander_test.cc | 2 +- xla/service/reshape_mover_test.cc | 2 +- xla/service/scatter_expander_test.cc | 2 +- xla/service/slice_sinker_test.cc | 2 +- xla/service/sort_simplifier_test.cc | 2 +- xla/service/spmd/BUILD | 2 +- xla/service/spmd/spmd_partitioner_test.cc | 2 +- xla/service/stable_sort_expander_test.cc | 2 +- xla/service/topk_rewriter_test.cc | 2 +- xla/service/triangular_solve_expander_test.cc | 2 +- xla/service/tuple_simplifier_test.cc | 2 +- .../while_loop_all_reduce_code_motion_test.cc | 2 +- .../while_loop_concat_code_motion_test.cc | 2 +- .../while_loop_invariant_code_motion_test.cc | 2 +- xla/service/while_loop_simplifier_test.cc | 2 +- xla/service/xla_aot_compile_cpu_test.cc | 2 +- xla/service/xla_aot_compile_gpu_test.cc | 2 +- .../xla_aot_compile_stablehlo_cpu_test.cc | 2 +- xla/stream_executor/gpu/BUILD | 10 +- .../gpu/gpu_command_buffer_test.cc | 2 +- .../gpu/gpu_device_info_test.cc | 2 +- xla/stream_executor/gpu/gpu_kernel_test.cc | 2 +- xla/stream_executor/gpu/memcpy_test.cc | 2 +- .../gpu/redzone_allocator_test.cc | 2 +- xla/stream_executor/host/BUILD | 4 +- xla/stream_executor/host/host_kernel_test.cc | 2 +- xla/stream_executor/host/host_stream_test.cc | 2 +- xla/tests/BUILD | 44 +++--- xla/tests/buffer_donation_test.cc | 2 +- xla/tests/cholesky_test.cc | 2 +- xla/tests/collective_ops_test.cc | 2 +- xla/tests/compute_constant_test.cc | 2 +- xla/tests/constants_test.cc | 2 +- xla/tests/hlo_test_base.cc | 2 +- xla/tests/llvm_irgen_test_base.cc | 2 +- xla/tests/multioutput_fusion_test.cc | 2 +- xla/tests/multiple_devices_on_host_test.cc | 2 +- xla/tests/multithreaded_compilation_test.cc | 2 +- .../outfeed_in_nested_computation_test.cc | 2 +- xla/tests/pred_test.cc | 2 +- xla/tests/reduce_test.cc | 2 +- xla/tests/reduce_window_test.cc | 2 +- xla/tests/replicated_io_feed_test.cc | 2 +- xla/tests/test_utils_test.cc | 2 +- xla/tests/triangular_solve_test.cc | 2 +- xla/tests/tuple_test.cc | 2 +- xla/tests/value_inference_test.cc | 2 +- xla/tests/while_test.cc | 2 +- xla/tests/xla_hlo_profile_test.cc | 2 +- xla/text_literal_writer_test.cc | 2 +- xla/tools/BUILD | 10 +- xla/tools/hlo_control_flow_flattening_test.cc | 2 +- xla/tools/hlo_module_loader_test.cc | 2 +- xla/tools/multihost_hlo_runner/BUILD | 2 +- .../functional_hlo_runner_test.cc | 2 +- xla/tools/run_hlo_module_bin_test.cc | 2 +- xla/tools/run_hlo_module_test.cc | 2 +- xla/tools/xla_compile_lib_test.cc | 2 +- xla/translate/hlo_to_mhlo/BUILD | 2 +- xla/translate/hlo_to_mhlo/hlo_utils_test.cc | 2 +- .../distributed_runtime/coordination/BUILD | 6 +- .../coordination_service_agent_test.cc | 2 +- ...ordination_service_recoverable_job_test.cc | 2 +- .../coordination/coordination_service_test.cc | 2 +- xla/tsl/distributed_runtime/rpc/BUILD | 2 +- .../rpc/grpc_channel_test.cc | 2 +- xla/tsl/framework/BUILD | 2 +- xla/tsl/framework/device_id_utils_test.cc | 2 +- xla/tsl/util/BUILD | 2 +- xla/tsl/util/device_name_utils_test.cc | 2 +- 244 files changed, 412 insertions(+), 412 deletions(-) diff --git a/xla/BUILD b/xla/BUILD index 59fad44977cffc..3b58f29ad59b27 100644 --- a/xla/BUILD +++ b/xla/BUILD @@ -632,13 +632,13 @@ xla_cc_test( ":types", ":util", ":xla_data_proto_cc", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/base", "@com_google_absl//absl/hash", "@com_google_absl//absl/random", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:macros", @@ -692,8 +692,8 @@ xla_cc_test( ":literal_comparison", ":literal_util", ":test_helpers", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:test_main", ], @@ -960,7 +960,7 @@ xla_cc_test( ":test_helpers", ":text_literal_writer", ":types", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:test_main", ], diff --git a/xla/backends/profiler/cpu/BUILD b/xla/backends/profiler/cpu/BUILD index 45254705f39984..1f6de1be4932ee 100644 --- a/xla/backends/profiler/cpu/BUILD +++ b/xla/backends/profiler/cpu/BUILD @@ -128,9 +128,9 @@ xla_cc_test( srcs = ["host_tracer_test.cc"], deps = [ ":host_tracer_impl", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/types:optional", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:blocking_counter", "@tsl//tsl/platform:env", "@tsl//tsl/platform:test", diff --git a/xla/backends/profiler/cpu/host_tracer_test.cc b/xla/backends/profiler/cpu/host_tracer_test.cc index 0db9f800b958ee..2fca882f9910d8 100644 --- a/xla/backends/profiler/cpu/host_tracer_test.cc +++ b/xla/backends/profiler/cpu/host_tracer_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include #include "absl/types/optional.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/blocking_counter.h" #include "tsl/platform/env.h" #include "tsl/platform/test.h" diff --git a/xla/client/lib/BUILD b/xla/client/lib/BUILD index 2f26bba6a20923..c8447d3bffa3fe 100644 --- a/xla/client/lib/BUILD +++ b/xla/client/lib/BUILD @@ -248,8 +248,8 @@ xla_test( "//xla/tests:client_library_test_base", "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", ], ) diff --git a/xla/client/lib/math_test.cc b/xla/client/lib/math_test.cc index 559302f6bb5977..0c5776f4bea333 100644 --- a/xla/client/lib/math_test.cc +++ b/xla/client/lib/math_test.cc @@ -38,9 +38,9 @@ limitations under the License. #include "xla/test.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/xla/examples/axpy/BUILD b/xla/examples/axpy/BUILD index d3d20f7fc56dd0..f598d0a5242723 100644 --- a/xla/examples/axpy/BUILD +++ b/xla/examples/axpy/BUILD @@ -22,12 +22,12 @@ xla_cc_test( "//xla/service/cpu:cpu_compiler", "//xla/stream_executor:platform", "//xla/tests:literal_test_util", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest_main", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@stablehlo//:register", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:path", "@tsl//tsl/platform:statusor", diff --git a/xla/examples/axpy/stablehlo_compile_test.cc b/xla/examples/axpy/stablehlo_compile_test.cc index 0bf61a17caf280..897a1e953d20f8 100644 --- a/xla/examples/axpy/stablehlo_compile_test.cc +++ b/xla/examples/axpy/stablehlo_compile_test.cc @@ -41,7 +41,7 @@ limitations under the License. #include "xla/service/stream_pool.h" #include "xla/stream_executor/platform.h" #include "xla/tests/literal_test_util.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/path.h" #include "tsl/platform/statusor.h" diff --git a/xla/ffi/BUILD b/xla/ffi/BUILD index b676e56fc67fdc..8ce222c9d1d0c5 100644 --- a/xla/ffi/BUILD +++ b/xla/ffi/BUILD @@ -45,9 +45,9 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla/ffi/api:c_api", "//xla/stream_executor:device_memory", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_benchmark", @@ -76,8 +76,8 @@ xla_cc_test( deps = [ ":execution_context", ":type_id_registry", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", @@ -103,8 +103,8 @@ xla_cc_test( srcs = ["execution_state_test.cc"], deps = [ ":execution_state", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", @@ -197,12 +197,12 @@ xla_cc_test( "//xla/ffi/api:c_api", "//xla/stream_executor", "//xla/stream_executor:device_memory", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:status_matchers", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", @@ -230,8 +230,8 @@ xla_cc_test( srcs = ["type_id_registry_test.cc"], deps = [ ":type_id_registry", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", diff --git a/xla/ffi/api/BUILD b/xla/ffi/api/BUILD index d15274c323c29a..d08bf9c377a09e 100644 --- a/xla/ffi/api/BUILD +++ b/xla/ffi/api/BUILD @@ -85,10 +85,10 @@ xla_cc_test( "//xla/ffi:type_id_registry", "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:status_matchers", "@tsl//tsl/platform:test", diff --git a/xla/ffi/api/ffi_test.cc b/xla/ffi/api/ffi_test.cc index a677c4e355ee0c..8db0d46fe27263 100644 --- a/xla/ffi/api/ffi_test.cc +++ b/xla/ffi/api/ffi_test.cc @@ -34,8 +34,8 @@ limitations under the License. #include "xla/primitive_util.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/test.h" #include "tsl/platform/test_benchmark.h" diff --git a/xla/ffi/call_frame_test.cc b/xla/ffi/call_frame_test.cc index 2937b53bb5d997..7b767bfb841af8 100644 --- a/xla/ffi/call_frame_test.cc +++ b/xla/ffi/call_frame_test.cc @@ -24,8 +24,8 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "xla/ffi/api/c_api.h" #include "xla/stream_executor/device_memory.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/test.h" #include "tsl/platform/test_benchmark.h" diff --git a/xla/ffi/execution_context_test.cc b/xla/ffi/execution_context_test.cc index 7a2a1b33992ede..6a5cdfa40b07b6 100644 --- a/xla/ffi/execution_context_test.cc +++ b/xla/ffi/execution_context_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "absl/status/status.h" #include "xla/ffi/type_id_registry.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/xla/ffi/execution_state_test.cc b/xla/ffi/execution_state_test.cc index d8929246ca0161..dd8244f00183ff 100644 --- a/xla/ffi/execution_state_test.cc +++ b/xla/ffi/execution_state_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/xla/ffi/ffi_test.cc b/xla/ffi/ffi_test.cc index 63f5dbf30e20d2..ab8d200d1a08a0 100644 --- a/xla/ffi/ffi_test.cc +++ b/xla/ffi/ffi_test.cc @@ -37,8 +37,8 @@ limitations under the License. #include "xla/ffi/ffi_api.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/xla/hlo/experimental/auto_sharding/BUILD b/xla/hlo/experimental/auto_sharding/BUILD index 4ad33554f98adf..96835b67a06d10 100644 --- a/xla/hlo/experimental/auto_sharding/BUILD +++ b/xla/hlo/experimental/auto_sharding/BUILD @@ -358,6 +358,7 @@ xla_cc_test( "//xla/service:hlo_value", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -365,7 +366,6 @@ xla_cc_test( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc index 4f9d533a4edcec..cb595afaf93569 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc @@ -46,7 +46,7 @@ limitations under the License. #include "xla/service/hlo_parser.h" #include "xla/service/hlo_value.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace op = xla::testing::opcode_matchers; diff --git a/xla/hlo/ir/BUILD b/xla/hlo/ir/BUILD index 203223870f6e79..d26ee54921027a 100644 --- a/xla/hlo/ir/BUILD +++ b/xla/hlo/ir/BUILD @@ -136,9 +136,9 @@ xla_cc_test( deps = [ ":backend_config", "//xla/service/gpu:backend_configs_cc", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/synchronization", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], diff --git a/xla/hlo/ir/backend_config_test.cc b/xla/hlo/ir/backend_config_test.cc index 5ffe3ae98b8d6c..09b56347e450ed 100644 --- a/xla/hlo/ir/backend_config_test.cc +++ b/xla/hlo/ir/backend_config_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/synchronization/notification.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/test.h" namespace xla { diff --git a/xla/hlo/transforms/BUILD b/xla/hlo/transforms/BUILD index 65d2368f0d0453..70761e09c74f6e 100644 --- a/xla/hlo/transforms/BUILD +++ b/xla/hlo/transforms/BUILD @@ -49,7 +49,7 @@ xla_cc_test( "//xla/service:hlo_parser", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/hlo/transforms/hlo_constant_splitter_test.cc b/xla/hlo/transforms/hlo_constant_splitter_test.cc index 58a25ef26d0aac..c7ebf8459502e8 100644 --- a/xla/hlo/transforms/hlo_constant_splitter_test.cc +++ b/xla/hlo/transforms/hlo_constant_splitter_test.cc @@ -23,8 +23,8 @@ limitations under the License. #include "xla/service/hlo_parser.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/xla/hlo/utils/BUILD b/xla/hlo/utils/BUILD index 94a0f473f8a0db..2f97f4c45dbd43 100644 --- a/xla/hlo/utils/BUILD +++ b/xla/hlo/utils/BUILD @@ -54,8 +54,8 @@ xla_cc_test( "//xla/service:hlo_value", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_map", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/hlo/utils/hlo_live_range_test.cc b/xla/hlo/utils/hlo_live_range_test.cc index b4155b103cfc11..64e4ab5ee37d62 100644 --- a/xla/hlo/utils/hlo_live_range_test.cc +++ b/xla/hlo/utils/hlo_live_range_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_value.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/xla/literal_comparison_test.cc b/xla/literal_comparison_test.cc index 241baf6e9eb84f..893820780276fe 100644 --- a/xla/literal_comparison_test.cc +++ b/xla/literal_comparison_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "xla/error_spec.h" #include "xla/literal_util.h" #include "xla/test_helpers.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/ml_dtypes.h" namespace xla { diff --git a/xla/literal_test.cc b/xla/literal_test.cc index cddd1212bfee20..36a3c263e27c36 100644 --- a/xla/literal_test.cc +++ b/xla/literal_test.cc @@ -46,10 +46,10 @@ limitations under the License. #include "xla/shape_tree.h" #include "xla/shape_util.h" #include "xla/test.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" // IWYU pragma: keep #include "tsl/platform/macros.h" diff --git a/xla/mlir/utils/BUILD b/xla/mlir/utils/BUILD index 2a1aaf5638bb5e..9f8eac47418fd5 100644 --- a/xla/mlir/utils/BUILD +++ b/xla/mlir/utils/BUILD @@ -30,11 +30,11 @@ cc_test( srcs = ["error_util_test.cc"], deps = [ ":error_util", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:status", "@tsl//tsl/platform:test_main", ], diff --git a/xla/mlir/utils/error_util_test.cc b/xla/mlir/utils/error_util_test.cc index f325cd070f7f52..23f214f9658b26 100644 --- a/xla/mlir/utils/error_util_test.cc +++ b/xla/mlir/utils/error_util_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "llvm/ADT/Twine.h" #include "mlir/IR/Builders.h" #include "mlir/IR/MLIRContext.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" namespace mlir { diff --git a/xla/pjrt/BUILD b/xla/pjrt/BUILD index f9cd50174e7bb3..34948f2e3cb888 100644 --- a/xla/pjrt/BUILD +++ b/xla/pjrt/BUILD @@ -185,8 +185,8 @@ xla_cc_test( ":pjrt_api", "//xla/pjrt/c:pjrt_c_api_hdrs", "//xla/pjrt/c:pjrt_c_api_wrapper_impl", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:status_matchers", "@tsl//tsl/protobuf:error_codes_proto_impl_cc", ], @@ -556,10 +556,10 @@ xla_cc_test( "//xla/service:cpu_plugin", "//xla/service:platform_util", "//xla/tsl/concurrency:async_value", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/synchronization", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", ], @@ -831,9 +831,9 @@ xla_cc_test( "//xla/client:xla_builder", "//xla/pjrt/c:pjrt_c_api_cpu_internal", "//xla/tests:literal_test_util", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", ], @@ -896,9 +896,9 @@ xla_cc_test( ":host_callback", ":pjrt_client", "//xla/tests:literal_test_util", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", ], ) diff --git a/xla/pjrt/c/BUILD b/xla/pjrt/c/BUILD index dad4ceb9887634..8738cb6ec5cbfa 100644 --- a/xla/pjrt/c/BUILD +++ b/xla/pjrt/c/BUILD @@ -408,13 +408,13 @@ xla_cc_test( "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_common", "//xla/pjrt/distributed:in_memory_key_value_store", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:status", "@tsl//tsl/platform:status_matchers", "@tsl//tsl/platform:statusor", diff --git a/xla/pjrt/c/pjrt_c_api_helpers_test.cc b/xla/pjrt/c/pjrt_c_api_helpers_test.cc index d6e240d8c5e96b..8d0a51a48bc840 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers_test.cc +++ b/xla/pjrt/c/pjrt_c_api_helpers_test.cc @@ -33,7 +33,7 @@ limitations under the License. #include "xla/pjrt/distributed/in_memory_key_value_store.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_common.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" diff --git a/xla/pjrt/cpu/BUILD b/xla/pjrt/cpu/BUILD index eff9c1f550122c..d173f40e48ba06 100644 --- a/xla/pjrt/cpu/BUILD +++ b/xla/pjrt/cpu/BUILD @@ -242,11 +242,11 @@ xla_cc_test( "//xla/service:hlo_proto_cc", "//xla/tests:literal_test_util", "//xla/tests:test_utils", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:status_matchers", @@ -317,11 +317,11 @@ xla_cc_test( "//xla/service:collective_ops_utils", "//xla/service:global_device_id", "//xla/service/cpu:collectives_interface", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@gloo//:transport_tcp", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:statusor", diff --git a/xla/pjrt/cpu/cpu_client_test.cc b/xla/pjrt/cpu/cpu_client_test.cc index 641222e91ce21a..52ca154759e81a 100644 --- a/xla/pjrt/cpu/cpu_client_test.cc +++ b/xla/pjrt/cpu/cpu_client_test.cc @@ -48,8 +48,8 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_utils.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/file_system.h" diff --git a/xla/pjrt/cpu/gloo_collectives_test.cc b/xla/pjrt/cpu/gloo_collectives_test.cc index 9301cf0a23d094..0b2fd8d3c66e82 100644 --- a/xla/pjrt/cpu/gloo_collectives_test.cc +++ b/xla/pjrt/cpu/gloo_collectives_test.cc @@ -34,8 +34,8 @@ limitations under the License. #include "xla/service/collective_ops_utils.h" #include "xla/service/cpu/collectives_interface.h" #include "xla/service/global_device_id.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/xla/pjrt/distributed/BUILD b/xla/pjrt/distributed/BUILD index 6faf2aadd35bac..afacb16f6680af 100644 --- a/xla/pjrt/distributed/BUILD +++ b/xla/pjrt/distributed/BUILD @@ -50,9 +50,9 @@ xla_cc_test( ":protocol_proto_cc", ":topology_util", "//xla:test_helpers", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", @@ -145,13 +145,13 @@ xla_cc_test( "//xla:protobuf_util", "//xla:status_macros", "//xla/tsl/distributed_runtime/coordination:coordination_service_agent", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:statusor", diff --git a/xla/pjrt/distributed/client_server_test.cc b/xla/pjrt/distributed/client_server_test.cc index 5ccbf232dd07a6..8c04e7608ec41d 100644 --- a/xla/pjrt/distributed/client_server_test.cc +++ b/xla/pjrt/distributed/client_server_test.cc @@ -45,7 +45,7 @@ limitations under the License. #include "xla/protobuf_util.h" #include "xla/status_macros.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/xla/pjrt/distributed/topology_util_test.cc b/xla/pjrt/distributed/topology_util_test.cc index aaf859c658e157..193dae87ca1a0b 100644 --- a/xla/pjrt/distributed/topology_util_test.cc +++ b/xla/pjrt/distributed/topology_util_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "xla/pjrt/distributed/in_memory_key_value_store.h" #include "xla/pjrt/distributed/protocol.pb.h" #include "xla/test_helpers.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/xla/pjrt/gpu/BUILD b/xla/pjrt/gpu/BUILD index b7e06b628fc49c..2560c88c10d855 100644 --- a/xla/pjrt/gpu/BUILD +++ b/xla/pjrt/gpu/BUILD @@ -168,13 +168,13 @@ xla_cc_test( "//xla/service:platform_util", "//xla/stream_executor", "//xla/tests:literal_test_util", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:status", diff --git a/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc b/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc index a664a11352a023..e034e83efd5893 100644 --- a/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc +++ b/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc @@ -57,8 +57,8 @@ limitations under the License. #include "xla/stream_executor/stream.h" #include "xla/test.h" #include "xla/tests/literal_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" diff --git a/xla/pjrt/host_callback_test.cc b/xla/pjrt/host_callback_test.cc index f443b9f8bbb524..ef9d5d9ec70c59 100644 --- a/xla/pjrt/host_callback_test.cc +++ b/xla/pjrt/host_callback_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "absl/status/status.h" #include "xla/pjrt/pjrt_client.h" #include "xla/tests/literal_test_util.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/xla/pjrt/pjrt_api_test.cc b/xla/pjrt/pjrt_api_test.cc index b6e13ca5d14e2c..8ee9e49451a99e 100644 --- a/xla/pjrt/pjrt_api_test.cc +++ b/xla/pjrt/pjrt_api_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" #include "tsl/protobuf/error_codes.pb.h" namespace { diff --git a/xla/pjrt/pjrt_c_api_client_test.cc b/xla/pjrt/pjrt_c_api_client_test.cc index 188e159419a2e3..4dbdd5d03af4cb 100644 --- a/xla/pjrt/pjrt_c_api_client_test.cc +++ b/xla/pjrt/pjrt_c_api_client_test.cc @@ -34,7 +34,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/literal_test_util.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/xla/pjrt/pjrt_stream_executor_client_test.cc b/xla/pjrt/pjrt_stream_executor_client_test.cc index d34d5c3c54740f..19f1c150ef232b 100644 --- a/xla/pjrt/pjrt_stream_executor_client_test.cc +++ b/xla/pjrt/pjrt_stream_executor_client_test.cc @@ -34,8 +34,8 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" diff --git a/xla/python/ifrt/BUILD b/xla/python/ifrt/BUILD index d140b795f9a944..0e548ee5371836 100644 --- a/xla/python/ifrt/BUILD +++ b/xla/python/ifrt/BUILD @@ -169,10 +169,10 @@ xla_cc_test( srcs = ["future_test.cc"], deps = [ ":ifrt", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:status_matchers", ], ) @@ -249,11 +249,11 @@ cc_library( deps = [ ":ifrt", "//xla/tsl/concurrency:ref_count", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", ], @@ -295,10 +295,10 @@ cc_library( ":ifrt", ":test_util", "//xla/tsl/concurrency:ref_count", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:status_matchers", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", @@ -350,8 +350,8 @@ cc_library( ":ifrt", ":test_util", "//xla/tsl/concurrency:ref_count", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/types:span", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", ], @@ -576,11 +576,11 @@ cc_library( ":test_util", "//xla:status_macros", "//xla/tsl/concurrency:ref_count", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", ], @@ -640,8 +640,8 @@ xla_cc_test( ":plugin_program_serdes", ":serdes", ":serdes_proto_cc", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/protobuf:error_codes_proto_impl_cc", "@tsl//tsl/protobuf:status_proto_cc", @@ -700,11 +700,11 @@ xla_cc_test( ":ifrt", ":program_serdes", ":serdes", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", "@llvm-project//llvm:Support", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:status_matchers", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", diff --git a/xla/python/ifrt/array_impl_test_lib.cc b/xla/python/ifrt/array_impl_test_lib.cc index 6d5c073c5e29ed..d5f83c8c070eb5 100644 --- a/xla/python/ifrt/array_impl_test_lib.cc +++ b/xla/python/ifrt/array_impl_test_lib.cc @@ -31,7 +31,7 @@ limitations under the License. #include "xla/python/ifrt/test_util.h" #include "xla/python/ifrt/value.h" #include "xla/tsl/concurrency/ref_count.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/xla/python/ifrt/custom_call_program_serdes_test.cc b/xla/python/ifrt/custom_call_program_serdes_test.cc index 31a259378695cc..332314a3b3d93d 100644 --- a/xla/python/ifrt/custom_call_program_serdes_test.cc +++ b/xla/python/ifrt/custom_call_program_serdes_test.cc @@ -32,7 +32,7 @@ limitations under the License. #include "xla/python/ifrt/serdes.h" #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/xla/python/ifrt/future_test.cc b/xla/python/ifrt/future_test.cc index 650f4849c0db1f..808d9a4981494a 100644 --- a/xla/python/ifrt/future_test.cc +++ b/xla/python/ifrt/future_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include "absl/status/status.h" #include "absl/types/span.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" namespace xla { diff --git a/xla/python/ifrt/ir/tests/BUILD b/xla/python/ifrt/ir/tests/BUILD index b872c2baa16c35..2f29fac5e8018d 100644 --- a/xla/python/ifrt/ir/tests/BUILD +++ b/xla/python/ifrt/ir/tests/BUILD @@ -97,10 +97,10 @@ cc_library( "//xla/python/pjrt_ifrt:xla_ifrt", "//xla/service:computation_placer_hdr", "//xla/tsl/concurrency:ref_count", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:status_matchers", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", diff --git a/xla/python/ifrt/ir/tests/executable_impl_test_lib.cc b/xla/python/ifrt/ir/tests/executable_impl_test_lib.cc index dc5330728d067c..95d04b081ee1be 100644 --- a/xla/python/ifrt/ir/tests/executable_impl_test_lib.cc +++ b/xla/python/ifrt/ir/tests/executable_impl_test_lib.cc @@ -37,7 +37,7 @@ limitations under the License. #include "xla/python/pjrt_ifrt/xla_compiler.h" #include "xla/service/computation_placer.h" #include "xla/tsl/concurrency/ref_count.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/xla/python/ifrt/plugin_program_serdes_test.cc b/xla/python/ifrt/plugin_program_serdes_test.cc index 31dca456bd0ea4..4edfae40571cae 100644 --- a/xla/python/ifrt/plugin_program_serdes_test.cc +++ b/xla/python/ifrt/plugin_program_serdes_test.cc @@ -18,7 +18,7 @@ #include "xla/python/ifrt/plugin_program.h" #include "xla/python/ifrt/serdes.h" #include "xla/python/ifrt/serdes.pb.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/protobuf/error_codes.pb.h" #include "tsl/protobuf/status.pb.h" diff --git a/xla/python/ifrt/remap_impl_test_lib.cc b/xla/python/ifrt/remap_impl_test_lib.cc index a55d97d13998e4..85822b51c24e45 100644 --- a/xla/python/ifrt/remap_impl_test_lib.cc +++ b/xla/python/ifrt/remap_impl_test_lib.cc @@ -36,7 +36,7 @@ limitations under the License. #include "xla/python/ifrt/test_util.h" #include "xla/status_macros.h" #include "xla/tsl/concurrency/ref_count.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/xla/python/ifrt/support/BUILD b/xla/python/ifrt/support/BUILD index 77e076d6b3ba28..e36358f8da1c31 100644 --- a/xla/python/ifrt/support/BUILD +++ b/xla/python/ifrt/support/BUILD @@ -37,13 +37,13 @@ xla_cc_test( "//xla/python/ifrt:mock", "//xla/python/ifrt:test_util", "//xla/python/ifrt/ir:sharding_param", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:status_matchers", "@tsl//tsl/platform:statusor", ], diff --git a/xla/python/ifrt/support/sharding_conversions_test.cc b/xla/python/ifrt/support/sharding_conversions_test.cc index da1b26c6bf9555..7973ec0f4abe6a 100644 --- a/xla/python/ifrt/support/sharding_conversions_test.cc +++ b/xla/python/ifrt/support/sharding_conversions_test.cc @@ -38,8 +38,8 @@ limitations under the License. #include "xla/python/ifrt/sharding.h" #include "xla/python/ifrt/test_util.h" #include "xla/shape.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" diff --git a/xla/python/ifrt/test_util.h b/xla/python/ifrt/test_util.h index 45e1258e8ec0e2..cd7ffc73824806 100644 --- a/xla/python/ifrt/test_util.h +++ b/xla/python/ifrt/test_util.h @@ -1,4 +1,4 @@ -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" /* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/xla/python/ifrt/tuple_impl_test_lib.cc b/xla/python/ifrt/tuple_impl_test_lib.cc index 643421076f3a5f..5a29e6e7587f4c 100644 --- a/xla/python/ifrt/tuple_impl_test_lib.cc +++ b/xla/python/ifrt/tuple_impl_test_lib.cc @@ -24,7 +24,7 @@ limitations under the License. #include "xla/python/ifrt/test_util.h" #include "xla/python/ifrt/tuple.h" #include "xla/tsl/concurrency/ref_count.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/xla/python/ifrt_proxy/server/BUILD b/xla/python/ifrt_proxy/server/BUILD index d74970ed909e41..ae8f52238db2d6 100644 --- a/xla/python/ifrt_proxy/server/BUILD +++ b/xla/python/ifrt_proxy/server/BUILD @@ -183,6 +183,7 @@ ifrt_proxy_cc_test( "//xla/python/pjrt_ifrt:xla_ifrt", "//xla/service:computation_placer_hdr", "//xla/tsl/concurrency:ref_count", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:check", @@ -195,7 +196,6 @@ ifrt_proxy_cc_test( "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@llvm-project//llvm:Support", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:protobuf", diff --git a/xla/python/ifrt_proxy/server/ifrt_backend_test.cc b/xla/python/ifrt_proxy/server/ifrt_backend_test.cc index 79da57ce2c3bf3..df3e72d53da438 100644 --- a/xla/python/ifrt_proxy/server/ifrt_backend_test.cc +++ b/xla/python/ifrt_proxy/server/ifrt_backend_test.cc @@ -67,8 +67,8 @@ #include "xla/status_macros.h" #include "xla/test.h" #include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/protobuf.h" // IWYU pragma: keep diff --git a/xla/python/pjrt_ifrt/BUILD b/xla/python/pjrt_ifrt/BUILD index 62c118aa67dee2..89eca2fbf8b6e3 100644 --- a/xla/python/pjrt_ifrt/BUILD +++ b/xla/python/pjrt_ifrt/BUILD @@ -140,11 +140,11 @@ cc_library( "//xla/python/ifrt", "//xla/python/ifrt:test_util", "//xla/python/ifrt/hlo:hlo_program", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm-project//mlir:IR", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", ], @@ -337,6 +337,7 @@ xla_cc_test( "//xla/python/ifrt", "//xla/python/ifrt:test_util", "//xla/tsl/concurrency:ref_count", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", @@ -345,7 +346,6 @@ xla_cc_test( "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@llvm-project//llvm:Support", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", diff --git a/xla/python/pjrt_ifrt/basic_string_array_test.cc b/xla/python/pjrt_ifrt/basic_string_array_test.cc index a0d21a4cf11307..108e5a9f982760 100644 --- a/xla/python/pjrt_ifrt/basic_string_array_test.cc +++ b/xla/python/pjrt_ifrt/basic_string_array_test.cc @@ -44,7 +44,7 @@ limitations under the License. #include "xla/python/ifrt/sharding.h" #include "xla/python/ifrt/test_util.h" #include "xla/tsl/concurrency/ref_count.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/xla/python/pjrt_ifrt/xla_executable_impl_test_lib.cc b/xla/python/pjrt_ifrt/xla_executable_impl_test_lib.cc index 04da5007591f4f..4a3bc4197c766c 100644 --- a/xla/python/pjrt_ifrt/xla_executable_impl_test_lib.cc +++ b/xla/python/pjrt_ifrt/xla_executable_impl_test_lib.cc @@ -40,7 +40,7 @@ limitations under the License. #include "xla/python/ifrt/sharding.h" #include "xla/python/ifrt/test_util.h" #include "xla/python/pjrt_ifrt/xla_compiler.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/xla/service/BUILD b/xla/service/BUILD index 32dae31309906a..8fa7f829b73b42 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -142,7 +142,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -292,12 +292,12 @@ xla_cc_test( "//xla/service/gpu:gpu_reduce_scatter_creator", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:status_matchers", "@tsl//tsl/platform:statusor", ], @@ -579,7 +579,7 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -668,10 +668,10 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", ], ) @@ -722,8 +722,8 @@ xla_cc_test( ":hlo_parser", "//xla:xla_proto_cc", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:path", "@tsl//tsl/platform:statusor", @@ -920,8 +920,8 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", ], ) @@ -956,7 +956,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -1101,8 +1101,8 @@ xla_cc_test( "//xla/service/gpu:backend_configs_cc", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_map", - "@tsl//tsl/lib/core:status_test_util", ], ) @@ -1158,11 +1158,11 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:statusor", ], @@ -1227,7 +1227,7 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -1267,8 +1267,8 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", ], ) @@ -1289,7 +1289,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -1580,11 +1580,11 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:protobuf", "@tsl//tsl/platform:statusor", "@tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc", @@ -1799,8 +1799,8 @@ xla_test( "//xla/stream_executor/gpu:gpu_init", "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", ], ) @@ -2003,9 +2003,9 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", ], ) @@ -2049,9 +2049,9 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", ], ) @@ -2069,7 +2069,7 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -2146,10 +2146,10 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", ], ) @@ -2167,9 +2167,9 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", ], ) @@ -2225,11 +2225,11 @@ xla_cc_test( "//xla/service/heap_simulator", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:statusor", @@ -2505,7 +2505,7 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -2547,7 +2547,7 @@ xla_cc_test( "//xla:types", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -2625,7 +2625,7 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -2822,13 +2822,13 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", ], @@ -3029,7 +3029,7 @@ xla_cc_test( "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", ], ) @@ -3227,7 +3227,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -3365,7 +3365,7 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@tsl//tsl/platform:status", ], ) @@ -3414,8 +3414,8 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:status", ], ) @@ -3741,10 +3741,10 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", ], ) @@ -3932,8 +3932,8 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings:string_view", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", ], ) @@ -4038,8 +4038,8 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", ], ) @@ -4199,8 +4199,8 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings", - "@tsl//tsl/lib/core:status_test_util", ], ) @@ -4275,13 +4275,13 @@ xla_test( "//xla/tests:llvm_irgen_test_base", "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", @@ -4307,7 +4307,7 @@ xla_cc_test( "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test_benchmark", ], @@ -4326,7 +4326,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -4456,9 +4456,9 @@ xla_cc_test( "//xla/stream_executor/host:host_platform_id", "//xla/tests:literal_test_util", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", ], @@ -4581,10 +4581,10 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/lib/strings:proto_serialization", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:statusor", @@ -4719,11 +4719,11 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", ], @@ -4822,7 +4822,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -4932,7 +4932,7 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:test", ], @@ -5188,7 +5188,7 @@ xla_cc_test( ":memory_space_propagation", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -5282,6 +5282,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/base:log_severity", "@com_google_absl//absl/log:scoped_mock_log", "@com_google_absl//absl/status", @@ -5289,7 +5290,6 @@ xla_cc_test( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", @@ -5319,8 +5319,8 @@ xla_cc_test( ":hlo_verifier", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", ], ) @@ -5393,10 +5393,10 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", ], @@ -5418,9 +5418,9 @@ xla_cc_test( "//xla/tests:literal_test_util", "//xla/tests:test_utils", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", ], ) @@ -5436,7 +5436,7 @@ xla_cc_test( "//xla/tests:literal_test_util", "//xla/tests:test_utils", "//xla/tests:xla_internal_test_main", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -5462,11 +5462,11 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", @@ -5531,7 +5531,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -5691,7 +5691,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -5748,7 +5748,7 @@ xla_cc_test( "//xla/tests:literal_test_util", "//xla/tests:test_utils", "//xla/tests:xla_internal_test_main", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -6211,7 +6211,7 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -6241,7 +6241,7 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -6376,11 +6376,11 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", ], ) @@ -6425,11 +6425,11 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", ], ) @@ -6482,12 +6482,12 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", ], ) @@ -6521,10 +6521,10 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", ], ) @@ -6606,9 +6606,9 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", - "@tsl//tsl/lib/core:status_test_util", ], ) @@ -6649,9 +6649,9 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", - "@tsl//tsl/lib/core:status_test_util", ], ) @@ -6697,8 +6697,8 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", ], ) @@ -6798,7 +6798,7 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -6959,9 +6959,9 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:status_matchers", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", @@ -7314,7 +7314,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -7417,10 +7417,10 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", ], ) @@ -7437,10 +7437,10 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", ], @@ -7479,7 +7479,7 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:literal_test_util", "//xla/tests:xla_internal_test_main", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@tsl//tsl/platform:status_matchers", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", @@ -7675,9 +7675,9 @@ xla_cc_test( ":mapped_ptr_container_sorter", "//xla:test", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/log", - "@tsl//tsl/lib/core:status_test_util", ], ) @@ -7792,8 +7792,8 @@ xla_cc_test( "//xla:test", "//xla:xla_proto_cc", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:casts", "@tsl//tsl/platform:protobuf", ], @@ -8205,9 +8205,9 @@ xla_cc_test( "//xla/client:executable_build_options", "//xla/client:local_client", "//xla/service/cpu:cpu_compiler", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:path", "@tsl//tsl/platform:statusor", @@ -8233,8 +8233,8 @@ xla_cc_test( "//xla/client:local_client", "//xla/service/cpu:cpu_compiler", "//xla/tests:literal_test_util", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:path", "@tsl//tsl/platform:statusor", @@ -8274,7 +8274,7 @@ xla_cc_test( "//xla/client:executable_build_options", "//xla:literal", "//xla:shape_util", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:path", "@tsl//tsl/platform:statusor", @@ -8327,8 +8327,8 @@ xla_cc_test( ":gpu_compilation_environment", "//xla:parse_flags_from_env", "//xla:xla_proto_cc", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:status_matchers", diff --git a/xla/service/algebraic_simplifier_test.cc b/xla/service/algebraic_simplifier_test.cc index 5b8e4db491c13b..d9e425281182db 100644 --- a/xla/service/algebraic_simplifier_test.cc +++ b/xla/service/algebraic_simplifier_test.cc @@ -58,10 +58,10 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "xla/window_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/xla/service/all_reduce_simplifier_test.cc b/xla/service/all_reduce_simplifier_test.cc index e78881a0c19292..35f5955076ad7e 100644 --- a/xla/service/all_reduce_simplifier_test.cc +++ b/xla/service/all_reduce_simplifier_test.cc @@ -28,9 +28,9 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" #include "xla/window_util.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/xla/service/all_reduce_splitter_test.cc b/xla/service/all_reduce_splitter_test.cc index 6725a50bc35c6f..3902a97c439724 100644 --- a/xla/service/all_reduce_splitter_test.cc +++ b/xla/service/all_reduce_splitter_test.cc @@ -34,7 +34,7 @@ limitations under the License. #include "xla/service/hlo_pass_pipeline.h" #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" diff --git a/xla/service/async_collective_creator_test.cc b/xla/service/async_collective_creator_test.cc index 8c9b574003da9a..75556260cf2e14 100644 --- a/xla/service/async_collective_creator_test.cc +++ b/xla/service/async_collective_creator_test.cc @@ -23,8 +23,8 @@ limitations under the License. #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/xla/service/bitcast_dtypes_expander_test.cc b/xla/service/bitcast_dtypes_expander_test.cc index b145e8ceb7b5fe..a5dc3b882446cc 100644 --- a/xla/service/bitcast_dtypes_expander_test.cc +++ b/xla/service/bitcast_dtypes_expander_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "xla/hlo/utils/hlo_matchers.h" #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/xla/service/buffer_assignment_test.cc b/xla/service/buffer_assignment_test.cc index a11b86ca357043..04238c4fd39f5a 100644 --- a/xla/service/buffer_assignment_test.cc +++ b/xla/service/buffer_assignment_test.cc @@ -46,9 +46,9 @@ limitations under the License. #include "xla/test.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/xla/service/call_graph_test.cc b/xla/service/call_graph_test.cc index dfa7d28f06ab1d..a619cd5ffe6e28 100644 --- a/xla/service/call_graph_test.cc +++ b/xla/service/call_graph_test.cc @@ -36,9 +36,9 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/xla/service/call_inliner_test.cc b/xla/service/call_inliner_test.cc index 4248c012444803..da73fe645fa058 100644 --- a/xla/service/call_inliner_test.cc +++ b/xla/service/call_inliner_test.cc @@ -31,9 +31,9 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" namespace op = xla::testing::opcode_matchers; diff --git a/xla/service/collective_ops_utils_test.cc b/xla/service/collective_ops_utils_test.cc index 64ec33866d2b32..f1a7ab1f4561f0 100644 --- a/xla/service/collective_ops_utils_test.cc +++ b/xla/service/collective_ops_utils_test.cc @@ -32,8 +32,8 @@ limitations under the License. #include "xla/service/global_device_id.h" #include "xla/service/hlo_parser.h" #include "xla/shape_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { diff --git a/xla/service/collective_quantizer_test.cc b/xla/service/collective_quantizer_test.cc index a095e3ef4e19a1..fff673e4707b7a 100644 --- a/xla/service/collective_quantizer_test.cc +++ b/xla/service/collective_quantizer_test.cc @@ -26,7 +26,7 @@ limitations under the License. #include "xla/hlo/utils/hlo_matchers.h" #include "xla/service/hlo_verifier.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/xla/service/collective_transformation_reorderer_test.cc b/xla/service/collective_transformation_reorderer_test.cc index 3721406e64901a..73f185e1caf73f 100644 --- a/xla/service/collective_transformation_reorderer_test.cc +++ b/xla/service/collective_transformation_reorderer_test.cc @@ -26,7 +26,7 @@ limitations under the License. #include "xla/hlo/utils/hlo_matchers.h" #include "xla/service/hlo_verifier.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/xla/service/compilation_environments_test.cc b/xla/service/compilation_environments_test.cc index b3cd2946cf06f4..35058aefd45994 100644 --- a/xla/service/compilation_environments_test.cc +++ b/xla/service/compilation_environments_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/service/test_compilation_environment.pb.h" #include "xla/test.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/casts.h" #include "tsl/platform/protobuf.h" diff --git a/xla/service/compiler_test.cc b/xla/service/compiler_test.cc index c2743c15aff889..951330e94d375e 100644 --- a/xla/service/compiler_test.cc +++ b/xla/service/compiler_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_init.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tests/test_macros.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/xla/service/conditional_canonicalizer_test.cc b/xla/service/conditional_canonicalizer_test.cc index 3d5e1e976da0d1..beba61a5a67832 100644 --- a/xla/service/conditional_canonicalizer_test.cc +++ b/xla/service/conditional_canonicalizer_test.cc @@ -25,9 +25,9 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_utils.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" #include "xla/util.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/xla/service/conditional_code_motion_test.cc b/xla/service/conditional_code_motion_test.cc index fcfe91d7a21dfa..0a3d74327dd522 100644 --- a/xla/service/conditional_code_motion_test.cc +++ b/xla/service/conditional_code_motion_test.cc @@ -29,9 +29,9 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" namespace xla { diff --git a/xla/service/conditional_simplifier_test.cc b/xla/service/conditional_simplifier_test.cc index 083ef03453d67f..24a7c0a68045b0 100644 --- a/xla/service/conditional_simplifier_test.cc +++ b/xla/service/conditional_simplifier_test.cc @@ -26,9 +26,9 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" namespace xla { diff --git a/xla/service/convert_async_collectives_to_sync_test.cc b/xla/service/convert_async_collectives_to_sync_test.cc index c155c2ff21397f..a404f03e5301cf 100644 --- a/xla/service/convert_async_collectives_to_sync_test.cc +++ b/xla/service/convert_async_collectives_to_sync_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { diff --git a/xla/service/cpu/BUILD b/xla/service/cpu/BUILD index a6f4e82e433305..34d8ad907ac131 100644 --- a/xla/service/cpu/BUILD +++ b/xla/service/cpu/BUILD @@ -1400,7 +1400,7 @@ xla_cc_test( ":cpu_runtime", "//xla:shape_util", "//xla/tests:xla_internal_test_main", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:test", @@ -1586,8 +1586,8 @@ xla_cc_test( "//xla/service/cpu:target_machine_features", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/service/cpu/runtime/BUILD b/xla/service/cpu/runtime/BUILD index 5ffa84491a91a2..b926b8539f7aaa 100644 --- a/xla/service/cpu/runtime/BUILD +++ b/xla/service/cpu/runtime/BUILD @@ -53,8 +53,8 @@ xla_cc_test( "//xla/service:buffer_assignment", "//xla/service:maybe_owning_device_memory", "//xla/stream_executor:device_memory", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", @@ -112,8 +112,8 @@ xla_cc_test( "//xla:executable_run_options", "//xla/service/cpu:collectives_interface", "//xla/service/cpu:cpu_executable_run_options", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", @@ -159,11 +159,11 @@ xla_cc_test( "//xla/service:maybe_owning_device_memory", "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", @@ -229,8 +229,8 @@ xla_cc_test( "//xla/service:maybe_owning_device_memory", "//xla/stream_executor", "//xla/tsl/concurrency:async_value", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", @@ -538,7 +538,7 @@ xla_cc_test( "//xla/service:maybe_owning_device_memory", "//xla/stream_executor", "//xla/tsl/concurrency:async_value", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", @@ -663,8 +663,8 @@ xla_cc_test( "//xla/service:maybe_owning_device_memory", "//xla/stream_executor", "//xla/tsl/concurrency:async_value", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", @@ -754,8 +754,8 @@ xla_cc_test( "//xla/service:maybe_owning_device_memory", "//xla/stream_executor", "//xla/tsl/concurrency:async_value", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", @@ -806,9 +806,9 @@ xla_cc_test( "//xla/stream_executor", "//xla/stream_executor/host:host_kernel_c_api", "//xla/tsl/concurrency:async_value", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", @@ -908,9 +908,9 @@ xla_cc_test( "//xla/stream_executor", "//xla/stream_executor/host:host_kernel_c_api", "//xla/tsl/concurrency:async_value", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", @@ -958,9 +958,9 @@ xla_cc_test( "//xla/service:maybe_owning_device_memory", "//xla/stream_executor", "//xla/tsl/concurrency:async_value", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@eigen_archive//:eigen3", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", diff --git a/xla/service/cpu/tests/BUILD b/xla/service/cpu/tests/BUILD index a0921e4344ea03..7f8076e52608cc 100644 --- a/xla/service/cpu/tests/BUILD +++ b/xla/service/cpu/tests/BUILD @@ -293,8 +293,8 @@ xla_cc_test( "//xla/service:hlo_module_config", "//xla/service/cpu:cpu_compiler", "//xla/service/cpu:test_header_helper", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", ], diff --git a/xla/service/cpu/tests/cpu_spmd_compile_test.cc b/xla/service/cpu/tests/cpu_spmd_compile_test.cc index 6dc8cb9f7bb089..077e7eef9cc4f7 100644 --- a/xla/service/cpu/tests/cpu_spmd_compile_test.cc +++ b/xla/service/cpu/tests/cpu_spmd_compile_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "xla/service/cpu/tests/cpu_codegen_test.h" #include "xla/service/executable.h" #include "xla/service/hlo_module_config.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/test.h" namespace xla { diff --git a/xla/service/cpu/xfeed_manager_test.cc b/xla/service/cpu/xfeed_manager_test.cc index 5c6be64e6a7dd1..5b682e207386d8 100644 --- a/xla/service/cpu/xfeed_manager_test.cc +++ b/xla/service/cpu/xfeed_manager_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "xla/service/cpu/cpu_runtime.h" #include "xla/shape_util.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/logging.h" #include "tsl/platform/test.h" diff --git a/xla/service/cpu_gpu_shape_verifier_test.cc b/xla/service/cpu_gpu_shape_verifier_test.cc index 277143e89b7189..7bb40d701d9667 100644 --- a/xla/service/cpu_gpu_shape_verifier_test.cc +++ b/xla/service/cpu_gpu_shape_verifier_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "xla/service/hlo_parser.h" #include "xla/service/hlo_verifier.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/xla/service/dfs_hlo_visitor_with_default_test.cc b/xla/service/dfs_hlo_visitor_with_default_test.cc index df05bef5a4397f..2fe22688ee2018 100644 --- a/xla/service/dfs_hlo_visitor_with_default_test.cc +++ b/xla/service/dfs_hlo_visitor_with_default_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "xla/test.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/xla/service/dot_merger_test.cc b/xla/service/dot_merger_test.cc index 97b9da0d0c279d..786970e7904f96 100644 --- a/xla/service/dot_merger_test.cc +++ b/xla/service/dot_merger_test.cc @@ -29,7 +29,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/xla/service/dump_test.cc b/xla/service/dump_test.cc index 78adadf3ea8d1b..6df547d96fdfce 100644 --- a/xla/service/dump_test.cc +++ b/xla/service/dump_test.cc @@ -22,8 +22,8 @@ limitations under the License. #include "absl/strings/match.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_parser.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/path.h" #include "tsl/platform/statusor.h" diff --git a/xla/service/dynamic_dimension_inference_test.cc b/xla/service/dynamic_dimension_inference_test.cc index 502f0079000948..9dc9de161aa4bd 100644 --- a/xla/service/dynamic_dimension_inference_test.cc +++ b/xla/service/dynamic_dimension_inference_test.cc @@ -30,8 +30,8 @@ limitations under the License. #include "xla/test_helpers.h" #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test_benchmark.h" diff --git a/xla/service/dynamic_dimension_simplifier_test.cc b/xla/service/dynamic_dimension_simplifier_test.cc index 94e48eca1104e3..2131c6c002a3e5 100644 --- a/xla/service/dynamic_dimension_simplifier_test.cc +++ b/xla/service/dynamic_dimension_simplifier_test.cc @@ -37,10 +37,10 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" #include "xla/window_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/xla/service/dynamic_padder_test.cc b/xla/service/dynamic_padder_test.cc index 3e3efa1a1832c6..972bc38ae8c40b 100644 --- a/xla/service/dynamic_padder_test.cc +++ b/xla/service/dynamic_padder_test.cc @@ -53,9 +53,9 @@ limitations under the License. #include "xla/tests/literal_test_util.h" #include "xla/tests/llvm_irgen_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" diff --git a/xla/service/dynamic_parameter_binding_test.cc b/xla/service/dynamic_parameter_binding_test.cc index 11dfbcdbec9617..94eaf4e5166bce 100644 --- a/xla/service/dynamic_parameter_binding_test.cc +++ b/xla/service/dynamic_parameter_binding_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/xla/service/flatten_call_graph_test.cc b/xla/service/flatten_call_graph_test.cc index 57498209c756c1..0a8be831355a5e 100644 --- a/xla/service/flatten_call_graph_test.cc +++ b/xla/service/flatten_call_graph_test.cc @@ -23,9 +23,9 @@ limitations under the License. #include "xla/test.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/xla/service/generic_transfer_manager_test.cc b/xla/service/generic_transfer_manager_test.cc index 41ea92d46a0385..eb8cb7afa85004 100644 --- a/xla/service/generic_transfer_manager_test.cc +++ b/xla/service/generic_transfer_manager_test.cc @@ -36,8 +36,8 @@ limitations under the License. #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tests/literal_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 58b83a2ca7922e..5afa5ef0e9a5f1 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -170,6 +170,7 @@ xla_test( "//xla/stream_executor/gpu:gpu_types_header", "//xla/tests:client_library_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -177,7 +178,6 @@ xla_test( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", ] + if_cuda_is_configured([ @@ -597,12 +597,12 @@ xla_test( "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", # fixdeps: keep "//xla/tools:hlo_decomposer_lib", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:path", @@ -1174,10 +1174,10 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:status_matchers", "@tsl//tsl/platform:statusor", @@ -1385,8 +1385,8 @@ xla_test( "//xla/stream_executor:device_description", "//xla/stream_executor:platform", "//xla/tests:hlo_test_base", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings:string_view", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", @@ -1597,10 +1597,10 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", ], ) @@ -1689,8 +1689,8 @@ xla_test( "//xla/stream_executor:device_description", "//xla/stream_executor:platform", "//xla/tests:hlo_test_base", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings:string_view", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", @@ -2561,11 +2561,11 @@ xla_cc_test( "//xla/stream_executor:dnn", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # build_cleaner: keep + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:statusor", @@ -2896,8 +2896,8 @@ xla_cc_test( "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", ], @@ -3417,12 +3417,12 @@ xla_test( "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:casts", "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", @@ -3452,10 +3452,10 @@ xla_test( "//xla/service/gpu:stream_attribute_annotator", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", ], ) @@ -3636,10 +3636,10 @@ xla_test( "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:statusor", ], @@ -4544,12 +4544,12 @@ xla_test( "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test_main", ] + if_cuda_is_configured([ @@ -4608,8 +4608,8 @@ xla_test( "//xla/service/gpu/tests:gpu_codegen_test", "//xla/stream_executor:device_description", "//xla/tests:filecheck", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", ] + if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", "@local_config_cuda//cuda:cudnn_header", @@ -4714,10 +4714,10 @@ xla_test( "//xla/stream_executor:device_description", "//xla/stream_executor:dnn", "//xla/tests:hlo_test_base", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test_main", ] + if_cuda_is_configured([ @@ -4887,7 +4887,7 @@ xla_cc_test( ":alias_passthrough_params", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@tsl//tsl/platform:test", ], ) @@ -4941,9 +4941,9 @@ xla_test( "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", - "@tsl//tsl/lib/core:status_test_util", ], ) @@ -5317,9 +5317,9 @@ xla_cc_test( "//xla/service:hlo_parser", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", ], ) @@ -5559,9 +5559,9 @@ xla_cc_test( ":executable_proto_cc", ":kernel_reuse_cache", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log:check", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:test", ], @@ -5682,6 +5682,7 @@ xla_cc_test( "//xla/stream_executor/host:host_platform", "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/base:log_severity", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", @@ -5689,7 +5690,6 @@ xla_cc_test( "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", @@ -6020,11 +6020,11 @@ xla_test( "//xla/service:platform_util", "//xla/stream_executor:platform", "//xla/tests:hlo_test_base", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", ], ) @@ -6129,11 +6129,11 @@ xla_cc_test( "//xla/service:profile_guided_latency_estimator", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:statusor", ], diff --git a/xla/service/gpu/alias_passthrough_params_test.cc b/xla/service/gpu/alias_passthrough_params_test.cc index d8141232ebbd3f..2c09daff14a326 100644 --- a/xla/service/gpu/alias_passthrough_params_test.cc +++ b/xla/service/gpu/alias_passthrough_params_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "xla/service/gpu/alias_passthrough_params.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/test.h" namespace xla { diff --git a/xla/service/gpu/autotuner_util_test.cc b/xla/service/gpu/autotuner_util_test.cc index 37fb56ed67fb83..69c139549690a6 100644 --- a/xla/service/gpu/autotuner_util_test.cc +++ b/xla/service/gpu/autotuner_util_test.cc @@ -36,8 +36,8 @@ limitations under the License. #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" // IWYU pragma: keep diff --git a/xla/service/gpu/command_buffer_scheduling_test.cc b/xla/service/gpu/command_buffer_scheduling_test.cc index bda31a05980b19..3a46193983304c 100644 --- a/xla/service/gpu/command_buffer_scheduling_test.cc +++ b/xla/service/gpu/command_buffer_scheduling_test.cc @@ -29,7 +29,7 @@ limitations under the License. #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" diff --git a/xla/service/gpu/conv_algorithm_picker_test.cc b/xla/service/gpu/conv_algorithm_picker_test.cc index d9a3a691da0565..aa7c5e2f0e97b6 100644 --- a/xla/service/gpu/conv_algorithm_picker_test.cc +++ b/xla/service/gpu/conv_algorithm_picker_test.cc @@ -32,7 +32,7 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/platform.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc b/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc index ac03122baebd03..0a58ecf223c620 100644 --- a/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc +++ b/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc @@ -62,7 +62,7 @@ limitations under the License. #include "xla/service/reshape_mover.h" #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace gpu { diff --git a/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc b/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc index 2cf88b01a8fe8b..897136b480b85d 100644 --- a/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc +++ b/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc @@ -44,9 +44,9 @@ limitations under the License. #include "xla/stream_executor/dnn.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #if GOOGLE_CUDA diff --git a/xla/service/gpu/cudnn_simplify_padding_test.cc b/xla/service/gpu/cudnn_simplify_padding_test.cc index 4cd9b72ef8ea65..a0e527a3eeb881 100644 --- a/xla/service/gpu/cudnn_simplify_padding_test.cc +++ b/xla/service/gpu/cudnn_simplify_padding_test.cc @@ -38,8 +38,8 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/dnn.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" diff --git a/xla/service/gpu/custom_call_test.cc b/xla/service/gpu/custom_call_test.cc index 87050db0129d56..4331d6efba429c 100644 --- a/xla/service/gpu/custom_call_test.cc +++ b/xla/service/gpu/custom_call_test.cc @@ -57,7 +57,7 @@ limitations under the License. #include "xla/stream_executor/stream.h" #include "xla/test_helpers.h" #include "xla/tests/client_library_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #if GOOGLE_CUDA diff --git a/xla/service/gpu/fusions/BUILD b/xla/service/gpu/fusions/BUILD index f9edabbad5df6f..36616c3c9cd96b 100644 --- a/xla/service/gpu/fusions/BUILD +++ b/xla/service/gpu/fusions/BUILD @@ -87,8 +87,8 @@ xla_test( "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu/model:indexing_test_utils", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", ], ) @@ -390,8 +390,8 @@ xla_test( "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu/model:indexing_test_utils", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", ], ) @@ -439,8 +439,8 @@ xla_test( "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu/model:indexing_test_utils", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", ], ) @@ -491,8 +491,8 @@ xla_test( "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu/model:indexing_test_utils", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", ], ) @@ -707,10 +707,10 @@ xla_test( "//xla/stream_executor:stream_executor_memory_allocator", "//xla/tests:filecheck", "//xla/tests:verified_hlo_module", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:path", "@tsl//tsl/platform:statusor", @@ -898,11 +898,11 @@ xla_test( "//xla/service/gpu/model:indexing_analysis", "//xla/service/gpu/model:indexing_test_utils", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", ], ) @@ -988,8 +988,8 @@ xla_test( "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu/model:indexing_test_utils", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/service/gpu/fusions/concatenate_mlir_test.cc b/xla/service/gpu/fusions/concatenate_mlir_test.cc index c0637cbe12dc74..92aff949ace0b6 100644 --- a/xla/service/gpu/fusions/concatenate_mlir_test.cc +++ b/xla/service/gpu/fusions/concatenate_mlir_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "xla/service/gpu/fusions/mlir_emitter_test_base.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/model/indexing_test_utils.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/xla/service/gpu/fusions/cudnn_test.cc b/xla/service/gpu/fusions/cudnn_test.cc index 9e9e1ce7560500..caa24b1c1b3d68 100644 --- a/xla/service/gpu/fusions/cudnn_test.cc +++ b/xla/service/gpu/fusions/cudnn_test.cc @@ -43,9 +43,9 @@ limitations under the License. #include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tests/filecheck.h" #include "xla/tests/verified_hlo_module.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/path.h" #include "tsl/platform/statusor.h" diff --git a/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc b/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc index b68a95e9516bfd..b0da3ef5c04532 100644 --- a/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc +++ b/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "xla/service/gpu/fusions/mlir_emitter_test_base.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/model/indexing_test_utils.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/xla/service/gpu/fusions/loop_mlir_test.cc b/xla/service/gpu/fusions/loop_mlir_test.cc index 08dcb4df490e54..357ef652985b43 100644 --- a/xla/service/gpu/fusions/loop_mlir_test.cc +++ b/xla/service/gpu/fusions/loop_mlir_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "xla/service/gpu/fusions/mlir_emitter_test_base.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/model/indexing_test_utils.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/xla/service/gpu/fusions/mlir/BUILD b/xla/service/gpu/fusions/mlir/BUILD index 9b603576488db4..a74ca8ba960610 100644 --- a/xla/service/gpu/fusions/mlir/BUILD +++ b/xla/service/gpu/fusions/mlir/BUILD @@ -126,6 +126,7 @@ xla_cc_test( "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_googletest//:gtest", "@llvm-project//llvm:Support", @@ -140,7 +141,6 @@ xla_cc_test( "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:Transforms", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:statusor", ], diff --git a/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc b/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc index 6a27e548ca932f..d7bbbb0bd34c3c 100644 --- a/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc +++ b/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc @@ -47,7 +47,7 @@ limitations under the License. #include "xla/stream_executor/launch_dim.h" #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/xla/service/gpu/fusions/reduction_mlir_test.cc b/xla/service/gpu/fusions/reduction_mlir_test.cc index 6ba7431530309e..761ecb4f31fe59 100644 --- a/xla/service/gpu/fusions/reduction_mlir_test.cc +++ b/xla/service/gpu/fusions/reduction_mlir_test.cc @@ -29,7 +29,7 @@ limitations under the License. #include "xla/service/gpu/fusions/mlir_emitter_test_base.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/service/gpu/model/indexing_test_utils.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace gpu { diff --git a/xla/service/gpu/fusions/scatter_mlir_test.cc b/xla/service/gpu/fusions/scatter_mlir_test.cc index 869d2335001825..6b8d013a81f735 100644 --- a/xla/service/gpu/fusions/scatter_mlir_test.cc +++ b/xla/service/gpu/fusions/scatter_mlir_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "xla/service/gpu/fusions/mlir_emitter_test_base.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/model/indexing_test_utils.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/xla/service/gpu/fusions/transpose_mlir_test.cc b/xla/service/gpu/fusions/transpose_mlir_test.cc index 1861672a82279d..eb71bb7f110db9 100644 --- a/xla/service/gpu/fusions/transpose_mlir_test.cc +++ b/xla/service/gpu/fusions/transpose_mlir_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "xla/service/gpu/fusions/mlir_emitter_test_base.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/model/indexing_test_utils.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/xla/service/gpu/fusions/triton/BUILD b/xla/service/gpu/fusions/triton/BUILD index 4fd95ea732c9ca..305e9f0f081485 100644 --- a/xla/service/gpu/fusions/triton/BUILD +++ b/xla/service/gpu/fusions/triton/BUILD @@ -212,6 +212,7 @@ xla_test( "//xla/tests:filecheck", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -219,7 +220,6 @@ xla_test( "@llvm-project//llvm:ir_headers", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:path", @@ -257,13 +257,13 @@ xla_test( "//xla/stream_executor/cuda:cublas_plugin", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", "@llvm-project//llvm:ir_headers", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:path", "@tsl//tsl/platform:status_matchers", "@tsl//tsl/platform:statusor", @@ -322,6 +322,7 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -332,7 +333,6 @@ xla_cc_test( "@llvm-project//mlir:NVVMDialect", "@llvm-project//mlir:Support", "@triton//:TritonDialects", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:logging", ], ) @@ -464,11 +464,11 @@ xla_test( "//xla/service/gpu:triton_fusion_analysis", "//xla/service/gpu/model:tiled_hlo_computation", "//xla/stream_executor:device_description", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:status_matchers", "@tsl//tsl/platform:statusor", ], diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc index 9a2d2d6dbb7520..c162c34f18449e 100644 --- a/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc +++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc @@ -51,8 +51,8 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/tests/filecheck.h" #include "xla/tests/verified_hlo_module.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/path.h" diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc index 9ca1b90100e0a9..8f07bbace4b3ad 100644 --- a/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc +++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc @@ -38,8 +38,8 @@ limitations under the License. #include "xla/service/gpu/tests/gpu_codegen_test.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/verified_hlo_module.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter_mem_utils_test.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter_mem_utils_test.cc index 44611bda590dfa..a327c0c1c74c88 100644 --- a/xla/service/gpu/fusions/triton/triton_fusion_emitter_mem_utils_test.cc +++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter_mem_utils_test.cc @@ -51,7 +51,7 @@ limitations under the License. #include "xla/service/llvm_ir/llvm_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/logging.h" // IWYU pragma: keep #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" diff --git a/xla/service/gpu/fusions/triton/triton_support_legacy_test.cc b/xla/service/gpu/fusions/triton/triton_support_legacy_test.cc index 89b1e8d1bc297b..3eefd362564f71 100644 --- a/xla/service/gpu/fusions/triton/triton_support_legacy_test.cc +++ b/xla/service/gpu/fusions/triton/triton_support_legacy_test.cc @@ -40,9 +40,9 @@ limitations under the License. #include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/service/gpu/triton_fusion_analysis.h" #include "xla/stream_executor/device_description.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" diff --git a/xla/service/gpu/gemm_algorithm_picker_test.cc b/xla/service/gpu/gemm_algorithm_picker_test.cc index e387aad44ef341..8049aa1278792c 100644 --- a/xla/service/gpu/gemm_algorithm_picker_test.cc +++ b/xla/service/gpu/gemm_algorithm_picker_test.cc @@ -31,7 +31,7 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/platform.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" #include "tsl/protobuf/dnn.pb.h" diff --git a/xla/service/gpu/gemm_fusion_autotuner_test.cc b/xla/service/gpu/gemm_fusion_autotuner_test.cc index 8cb7e8dc87e229..32c2c96d595a44 100644 --- a/xla/service/gpu/gemm_fusion_autotuner_test.cc +++ b/xla/service/gpu/gemm_fusion_autotuner_test.cc @@ -56,9 +56,9 @@ limitations under the License. #include "xla/tests/test_utils.h" #include "xla/tests/verified_hlo_module.h" #include "xla/tools/hlo_decomposer.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/cpu_info.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" diff --git a/xla/service/gpu/gpu_compiler_test.cc b/xla/service/gpu/gpu_compiler_test.cc index 7e9e3a419890d1..b74d77c9da31aa 100644 --- a/xla/service/gpu/gpu_compiler_test.cc +++ b/xla/service/gpu/gpu_compiler_test.cc @@ -46,8 +46,8 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/casts.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" diff --git a/xla/service/gpu/gpu_convert_async_collectives_to_sync_test.cc b/xla/service/gpu/gpu_convert_async_collectives_to_sync_test.cc index 4daeb62905f8a2..03f18bd3c5eb6d 100644 --- a/xla/service/gpu/gpu_convert_async_collectives_to_sync_test.cc +++ b/xla/service/gpu/gpu_convert_async_collectives_to_sync_test.cc @@ -25,8 +25,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc b/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc index 590adffffbd077..24622c8d685265 100644 --- a/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc +++ b/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "xla/service/hlo_module_config.h" #include "xla/service/profile_guided_latency_estimator.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/xla/service/gpu/gpu_offloading_test.cc b/xla/service/gpu/gpu_offloading_test.cc index 3099b957575e6f..928011cb4b76a1 100644 --- a/xla/service/gpu/gpu_offloading_test.cc +++ b/xla/service/gpu/gpu_offloading_test.cc @@ -38,8 +38,8 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/xla/service/gpu/hlo_fusion_stats_test.cc b/xla/service/gpu/hlo_fusion_stats_test.cc index 0a19b213922b42..c2d33da7fcf408 100644 --- a/xla/service/gpu/hlo_fusion_stats_test.cc +++ b/xla/service/gpu/hlo_fusion_stats_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "absl/strings/match.h" #include "xla/service/hlo_parser.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace gpu { diff --git a/xla/service/gpu/horizontal_loop_fusion_test.cc b/xla/service/gpu/horizontal_loop_fusion_test.cc index 935c21c6e23fed..4045183dcf0867 100644 --- a/xla/service/gpu/horizontal_loop_fusion_test.cc +++ b/xla/service/gpu/horizontal_loop_fusion_test.cc @@ -39,7 +39,7 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace gpu { diff --git a/xla/service/gpu/kernel_reuse_cache_test.cc b/xla/service/gpu/kernel_reuse_cache_test.cc index 1d4ea628f5b832..3f32225a72759c 100644 --- a/xla/service/gpu/kernel_reuse_cache_test.cc +++ b/xla/service/gpu/kernel_reuse_cache_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include "absl/log/check.h" #include "xla/service/gpu/executable.pb.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" namespace xla { diff --git a/xla/service/gpu/kernels/BUILD b/xla/service/gpu/kernels/BUILD index 532c908de0791c..d4299916ba1b95 100644 --- a/xla/service/gpu/kernels/BUILD +++ b/xla/service/gpu/kernels/BUILD @@ -234,10 +234,10 @@ xla_test( "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", "//xla/stream_executor/cuda:cuda_platform", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/random", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:path", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", @@ -283,7 +283,7 @@ xla_test( "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", "//xla/stream_executor/cuda:cuda_platform", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@tsl//tsl/platform:path", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", diff --git a/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc b/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc index e53a1166a1f5db..4c0a5869cb2b2e 100644 --- a/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc +++ b/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc @@ -26,8 +26,8 @@ limitations under the License. #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/path.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/xla/service/gpu/kernels/topk_custom_kernel_test.cc b/xla/service/gpu/kernels/topk_custom_kernel_test.cc index 974cc975ea5393..4f6f62605996a6 100644 --- a/xla/service/gpu/kernels/topk_custom_kernel_test.cc +++ b/xla/service/gpu/kernels/topk_custom_kernel_test.cc @@ -32,9 +32,9 @@ limitations under the License. #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/xla/service/gpu/model/BUILD b/xla/service/gpu/model/BUILD index ca44e81d66fee6..579cde4fd9f5e7 100644 --- a/xla/service/gpu/model/BUILD +++ b/xla/service/gpu/model/BUILD @@ -742,6 +742,7 @@ xla_cc_test( "//xla/service/gpu:hlo_traversal", "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -749,7 +750,6 @@ xla_cc_test( "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@llvm-project//mlir:IR", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/service/gpu/model/symbolic_tile_analysis_test.cc b/xla/service/gpu/model/symbolic_tile_analysis_test.cc index a9680f0f5fdb07..b0287b572c1f51 100644 --- a/xla/service/gpu/model/symbolic_tile_analysis_test.cc +++ b/xla/service/gpu/model/symbolic_tile_analysis_test.cc @@ -40,8 +40,8 @@ limitations under the License. #include "xla/service/instruction_fusion.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/xla/service/gpu/nvptx_compiler_test.cc b/xla/service/gpu/nvptx_compiler_test.cc index 642a0cc9eca438..f43066672b2bef 100644 --- a/xla/service/gpu/nvptx_compiler_test.cc +++ b/xla/service/gpu/nvptx_compiler_test.cc @@ -34,9 +34,9 @@ limitations under the License. #include "xla/service/logical_buffer.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "xla/xla.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/xla/service/gpu/runtime/BUILD b/xla/service/gpu/runtime/BUILD index af8e91f94752d4..28a37b562b98af 100644 --- a/xla/service/gpu/runtime/BUILD +++ b/xla/service/gpu/runtime/BUILD @@ -164,10 +164,10 @@ xla_test( "//xla/stream_executor:platform_manager", "//xla/stream_executor:stream_executor_memory_allocator", "//xla/stream_executor/gpu:gpu_test_kernels", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", @@ -382,10 +382,10 @@ xla_test( "//xla/stream_executor:stream_executor_memory_allocator", "//xla/stream_executor/gpu:gpu_test_kernels", "//xla/stream_executor/gpu:gpu_types_header", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", @@ -486,9 +486,9 @@ xla_test( "//xla/stream_executor:stream_executor_memory_allocator", "//xla/stream_executor/gpu:gpu_test_kernels", "//xla/stream_executor/gpu:gpu_types_header", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", diff --git a/xla/service/gpu/runtime/command_buffer_cmd_test.cc b/xla/service/gpu/runtime/command_buffer_cmd_test.cc index 8bdaf6a159eb51..40ba5e35bd9d73 100644 --- a/xla/service/gpu/runtime/command_buffer_cmd_test.cc +++ b/xla/service/gpu/runtime/command_buffer_cmd_test.cc @@ -35,8 +35,8 @@ limitations under the License. #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" // IWYU pragma: keep -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/xla/service/gpu/runtime/command_buffer_thunk_test.cc b/xla/service/gpu/runtime/command_buffer_thunk_test.cc index 9146213d72fe89..f4fc9e22c62c4f 100644 --- a/xla/service/gpu/runtime/command_buffer_thunk_test.cc +++ b/xla/service/gpu/runtime/command_buffer_thunk_test.cc @@ -48,9 +48,9 @@ limitations under the License. #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" // IWYU pragma: keep #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" #include "tsl/profiler/lib/profiler_lock.h" diff --git a/xla/service/gpu/runtime/dynamic_slice_thunk_test.cc b/xla/service/gpu/runtime/dynamic_slice_thunk_test.cc index 9f42ad2efc69f7..75c700b75b4e34 100644 --- a/xla/service/gpu/runtime/dynamic_slice_thunk_test.cc +++ b/xla/service/gpu/runtime/dynamic_slice_thunk_test.cc @@ -47,8 +47,8 @@ limitations under the License. #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" // IWYU pragma: keep -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/xla/service/gpu/split_k_gemm_rewriter_test.cc b/xla/service/gpu/split_k_gemm_rewriter_test.cc index 51013b4411bd1f..8c17196090f3ca 100644 --- a/xla/service/gpu/split_k_gemm_rewriter_test.cc +++ b/xla/service/gpu/split_k_gemm_rewriter_test.cc @@ -37,9 +37,9 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" diff --git a/xla/service/gpu/tests/BUILD b/xla/service/gpu/tests/BUILD index 83daaf0eb540a6..be70f7d61cb616 100644 --- a/xla/service/gpu/tests/BUILD +++ b/xla/service/gpu/tests/BUILD @@ -212,9 +212,9 @@ xla_test( "//xla/hlo/utils:hlo_query", "//xla/service:executable", "//xla/service:hlo_module_config", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status:statusor", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test_main", ], @@ -246,11 +246,11 @@ xla_test( "//xla/stream_executor:stream_executor_memory_allocator", "//xla/tests:filecheck", "//xla/tests:verified_hlo_module", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test_main", ] + if_cuda_is_configured([ @@ -954,7 +954,7 @@ xla_test( "//xla:literal", "//xla:literal_util", "//xla/tests:xla_internal_test_main", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ["@tsl//tsl/platform:test_main"], # b/317293391 ), @@ -1039,8 +1039,8 @@ cc_library( deps = [ "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings", - "@tsl//tsl/lib/core:status_test_util", ], ) diff --git a/xla/service/gpu/tests/gemm_rewrite_test.cc b/xla/service/gpu/tests/gemm_rewrite_test.cc index 0b7bcb3f58227c..f412f1f3d5293d 100644 --- a/xla/service/gpu/tests/gemm_rewrite_test.cc +++ b/xla/service/gpu/tests/gemm_rewrite_test.cc @@ -48,8 +48,8 @@ limitations under the License. #include "xla/test.h" #include "xla/tests/filecheck.h" #include "xla/tests/verified_hlo_module.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #if GOOGLE_CUDA diff --git a/xla/service/gpu/tests/gpu_sparse_dot_test.cc b/xla/service/gpu/tests/gpu_sparse_dot_test.cc index 5ea691967fa6b8..6133a5b38f4bc5 100644 --- a/xla/service/gpu/tests/gpu_sparse_dot_test.cc +++ b/xla/service/gpu/tests/gpu_sparse_dot_test.cc @@ -27,7 +27,7 @@ limitations under the License. #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace gpu { diff --git a/xla/service/gpu/tests/gpu_spmd_e2e_compile_test.cc b/xla/service/gpu/tests/gpu_spmd_e2e_compile_test.cc index e247abe6872847..cc4c36507fec94 100644 --- a/xla/service/gpu/tests/gpu_spmd_e2e_compile_test.cc +++ b/xla/service/gpu/tests/gpu_spmd_e2e_compile_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "xla/service/executable.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" #include "xla/service/hlo_module_config.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/xla/service/gpu/tests/simple_optimization_test.cc b/xla/service/gpu/tests/simple_optimization_test.cc index a18d58d6df333c..2ece976d3e5e39 100644 --- a/xla/service/gpu/tests/simple_optimization_test.cc +++ b/xla/service/gpu/tests/simple_optimization_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace gpu { diff --git a/xla/service/gpu/triton_fusion_numerics_verifier_test.cc b/xla/service/gpu/triton_fusion_numerics_verifier_test.cc index 1d35d1927b2a58..8703effb6ef5bb 100644 --- a/xla/service/gpu/triton_fusion_numerics_verifier_test.cc +++ b/xla/service/gpu/triton_fusion_numerics_verifier_test.cc @@ -32,7 +32,7 @@ limitations under the License. #include "xla/stream_executor/platform.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla::gpu { namespace { diff --git a/xla/service/gpu_compilation_environment_test.cc b/xla/service/gpu_compilation_environment_test.cc index 072f66b147e287..e684f7c68e4995 100644 --- a/xla/service/gpu_compilation_environment_test.cc +++ b/xla/service/gpu_compilation_environment_test.cc @@ -22,8 +22,8 @@ limitations under the License. #include #include "xla/parse_flags_from_env.h" #include "xla/service/compilation_environments.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/status_matchers.h" diff --git a/xla/service/heap_simulator/BUILD b/xla/service/heap_simulator/BUILD index 3111a5016b4a55..e9873e948e0989 100644 --- a/xla/service/heap_simulator/BUILD +++ b/xla/service/heap_simulator/BUILD @@ -81,10 +81,10 @@ xla_cc_test( "//xla/service:tuple_points_to_analysis", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:test", ], diff --git a/xla/service/heap_simulator/heap_simulator_test.cc b/xla/service/heap_simulator/heap_simulator_test.cc index 480213f78e8f7b..27f0261d103b66 100644 --- a/xla/service/heap_simulator/heap_simulator_test.cc +++ b/xla/service/heap_simulator/heap_simulator_test.cc @@ -40,7 +40,7 @@ limitations under the License. #include "xla/service/tuple_points_to_analysis.h" #include "xla/status_macros.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/logging.h" #include "tsl/platform/test.h" diff --git a/xla/service/hlo_alias_analysis_test.cc b/xla/service/hlo_alias_analysis_test.cc index 36709bc0a8e79c..ea687b640a58ef 100644 --- a/xla/service/hlo_alias_analysis_test.cc +++ b/xla/service/hlo_alias_analysis_test.cc @@ -28,8 +28,8 @@ limitations under the License. #include "xla/test.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/logging.h" #include "tsl/platform/test.h" diff --git a/xla/service/hlo_dataflow_analysis_test.cc b/xla/service/hlo_dataflow_analysis_test.cc index 5158e068501d5e..967c2955db20c2 100644 --- a/xla/service/hlo_dataflow_analysis_test.cc +++ b/xla/service/hlo_dataflow_analysis_test.cc @@ -41,8 +41,8 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/xla/service/hlo_dce_test.cc b/xla/service/hlo_dce_test.cc index c80c286c17ea2d..38a170ae77160d 100644 --- a/xla/service/hlo_dce_test.cc +++ b/xla/service/hlo_dce_test.cc @@ -35,9 +35,9 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_utils.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/xla/service/hlo_domain_test.cc b/xla/service/hlo_domain_test.cc index 63bc4aba494d79..13f80fdf6b441b 100644 --- a/xla/service/hlo_domain_test.cc +++ b/xla/service/hlo_domain_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "xla/service/sharding_propagation.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/xla/service/hlo_input_output_alias_config_test.cc b/xla/service/hlo_input_output_alias_config_test.cc index 5f583c1f38f0bb..8b7a99b385db67 100644 --- a/xla/service/hlo_input_output_alias_config_test.cc +++ b/xla/service/hlo_input_output_alias_config_test.cc @@ -29,8 +29,8 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/xla/service/hlo_instruction_test.cc b/xla/service/hlo_instruction_test.cc index 6147be229323e0..981b967d997690 100644 --- a/xla/service/hlo_instruction_test.cc +++ b/xla/service/hlo_instruction_test.cc @@ -36,10 +36,10 @@ limitations under the License. #include "xla/test.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "xla/window_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/xla/service/hlo_memory_scheduler_test.cc b/xla/service/hlo_memory_scheduler_test.cc index fef4b71c55b7a9..62a13d14097887 100644 --- a/xla/service/hlo_memory_scheduler_test.cc +++ b/xla/service/hlo_memory_scheduler_test.cc @@ -41,8 +41,8 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/xla/service/hlo_module_dce_test.cc b/xla/service/hlo_module_dce_test.cc index c192429c2f30e0..4b1a7b7e2e4409 100644 --- a/xla/service/hlo_module_dce_test.cc +++ b/xla/service/hlo_module_dce_test.cc @@ -23,8 +23,8 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_utils.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/xla/service/hlo_module_group_test.cc b/xla/service/hlo_module_group_test.cc index b56b53b4952e05..007df88bdcc9d9 100644 --- a/xla/service/hlo_module_group_test.cc +++ b/xla/service/hlo_module_group_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "xla/service/hlo_module_group_metadata.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { diff --git a/xla/service/hlo_module_test.cc b/xla/service/hlo_module_test.cc index 8af2621174e7b0..291093213e6719 100644 --- a/xla/service/hlo_module_test.cc +++ b/xla/service/hlo_module_test.cc @@ -36,9 +36,9 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/lib/strings/proto_serialization.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/xla/service/hlo_ordering_test.cc b/xla/service/hlo_ordering_test.cc index dc003777ecef6e..743f9f24f20c5e 100644 --- a/xla/service/hlo_ordering_test.cc +++ b/xla/service/hlo_ordering_test.cc @@ -28,9 +28,9 @@ limitations under the License. #include "xla/service/hlo_value.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/xla/service/hlo_parser_test.cc b/xla/service/hlo_parser_test.cc index 54d24c0436a256..4c67c01c073380 100644 --- a/xla/service/hlo_parser_test.cc +++ b/xla/service/hlo_parser_test.cc @@ -37,9 +37,9 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/verified_hlo_module.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/window_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/xla/service/hlo_pass_pipeline_test.cc b/xla/service/hlo_pass_pipeline_test.cc index d5ad880f72c7a7..502406bb54d1fc 100644 --- a/xla/service/hlo_pass_pipeline_test.cc +++ b/xla/service/hlo_pass_pipeline_test.cc @@ -20,8 +20,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_parser.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/xla/service/hlo_rematerialization_test.cc b/xla/service/hlo_rematerialization_test.cc index b30cf8293e48e9..c3a945345b3101 100644 --- a/xla/service/hlo_rematerialization_test.cc +++ b/xla/service/hlo_rematerialization_test.cc @@ -35,8 +35,8 @@ limitations under the License. #include "xla/service/hlo_rematerialization_test_utils.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/xla/service/hlo_replication_analysis_test.cc b/xla/service/hlo_replication_analysis_test.cc index 4cb5b9b8c43792..e57e7112226072 100644 --- a/xla/service/hlo_replication_analysis_test.cc +++ b/xla/service/hlo_replication_analysis_test.cc @@ -23,8 +23,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/xla/service/hlo_schedule_test.cc b/xla/service/hlo_schedule_test.cc index 4ba1a982def9ef..4f96b30498b1c6 100644 --- a/xla/service/hlo_schedule_test.cc +++ b/xla/service/hlo_schedule_test.cc @@ -31,9 +31,9 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/xla/service/hlo_verifier_test.cc b/xla/service/hlo_verifier_test.cc index 874077fce3e6b2..82f712a64ed34d 100644 --- a/xla/service/hlo_verifier_test.cc +++ b/xla/service/hlo_verifier_test.cc @@ -43,9 +43,9 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/platform.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/xla/service/host_memory_transfer_asyncifier_test.cc b/xla/service/host_memory_transfer_asyncifier_test.cc index d054cd115c1370..fd85488a2239ec 100644 --- a/xla/service/host_memory_transfer_asyncifier_test.cc +++ b/xla/service/host_memory_transfer_asyncifier_test.cc @@ -29,8 +29,8 @@ limitations under the License. #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/xla/service/host_offload_legalize_test.cc b/xla/service/host_offload_legalize_test.cc index 096f9a10560b44..0322c80a7504cf 100644 --- a/xla/service/host_offload_legalize_test.cc +++ b/xla/service/host_offload_legalize_test.cc @@ -33,8 +33,8 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace m = ::xla::match; diff --git a/xla/service/host_offloader_test.cc b/xla/service/host_offloader_test.cc index 85cc7742b3ce45..0d2ee3d295df72 100644 --- a/xla/service/host_offloader_test.cc +++ b/xla/service/host_offloader_test.cc @@ -37,8 +37,8 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace m = ::xla::match; diff --git a/xla/service/host_offloading_prepare_test.cc b/xla/service/host_offloading_prepare_test.cc index 9210d9824231c8..92d5490cfb2d15 100644 --- a/xla/service/host_offloading_prepare_test.cc +++ b/xla/service/host_offloading_prepare_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/host_memory_offload_annotations.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/xla/service/layout_assignment_test.cc b/xla/service/layout_assignment_test.cc index 139124bd6c09bb..0b294c46ddef17 100644 --- a/xla/service/layout_assignment_test.cc +++ b/xla/service/layout_assignment_test.cc @@ -44,9 +44,9 @@ limitations under the License. #include "xla/test.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" diff --git a/xla/service/mapped_ptr_container_sorter_test.cc b/xla/service/mapped_ptr_container_sorter_test.cc index 22f4cd0b8c9205..ca738619aa8ab8 100644 --- a/xla/service/mapped_ptr_container_sorter_test.cc +++ b/xla/service/mapped_ptr_container_sorter_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "absl/functional/bind_front.h" #include "absl/log/log.h" #include "xla/test.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/xla/service/memory_space_assignment/BUILD b/xla/service/memory_space_assignment/BUILD index a8efb1d1050c62..68e91ee27f384b 100644 --- a/xla/service/memory_space_assignment/BUILD +++ b/xla/service/memory_space_assignment/BUILD @@ -110,6 +110,7 @@ xla_cc_test( "//xla/tests:test_utils", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -120,7 +121,6 @@ xla_cc_test( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:protobuf", "@tsl//tsl/platform:status", @@ -310,11 +310,11 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/service:hlo_cost_analysis", "//xla/tests:hlo_test_base", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:statusor", ], @@ -355,12 +355,12 @@ xla_cc_test( "//xla/service:hlo_value", "//xla/service/heap_simulator", "//xla/tests:hlo_test_base", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:statusor", ], @@ -465,6 +465,7 @@ xla_cc_test( "//xla/service:hlo_value", "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -474,7 +475,6 @@ xla_cc_test( "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", "@com_googlesource_code_re2//:re2", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", diff --git a/xla/service/memory_space_assignment/cost_analysis_test.cc b/xla/service/memory_space_assignment/cost_analysis_test.cc index e4d93dd8c8f61f..39d4dbbded7bd2 100644 --- a/xla/service/memory_space_assignment/cost_analysis_test.cc +++ b/xla/service/memory_space_assignment/cost_analysis_test.cc @@ -27,7 +27,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc b/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc index b0487cab8fdbb9..1dff5221026f82 100644 --- a/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc +++ b/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc @@ -55,9 +55,9 @@ limitations under the License. #include "xla/status_macros.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/xla/service/memory_space_assignment/memory_space_assignment_test.cc b/xla/service/memory_space_assignment/memory_space_assignment_test.cc index 7547901d1aa4a8..3fb558b923acc7 100644 --- a/xla/service/memory_space_assignment/memory_space_assignment_test.cc +++ b/xla/service/memory_space_assignment/memory_space_assignment_test.cc @@ -78,9 +78,9 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_utils.h" #include "xla/tests/verified_hlo_module.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/protobuf.h" // IWYU pragma: keep #include "tsl/platform/status.h" diff --git a/xla/service/memory_space_assignment/simulator_test.cc b/xla/service/memory_space_assignment/simulator_test.cc index f3433ce7b569de..3b61a70f9309f5 100644 --- a/xla/service/memory_space_assignment/simulator_test.cc +++ b/xla/service/memory_space_assignment/simulator_test.cc @@ -39,7 +39,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/xla/service/memory_space_propagation_test.cc b/xla/service/memory_space_propagation_test.cc index 940a4ebbcc400e..98ae47c8b164f2 100644 --- a/xla/service/memory_space_propagation_test.cc +++ b/xla/service/memory_space_propagation_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include "xla/service/hlo_parser.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/xla/service/profile_guided_latency_estimator_test.cc b/xla/service/profile_guided_latency_estimator_test.cc index ff2d766b8b07c2..e795c475792c76 100644 --- a/xla/service/profile_guided_latency_estimator_test.cc +++ b/xla/service/profile_guided_latency_estimator_test.cc @@ -31,7 +31,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_schedule.h" #include "xla/service/latency_hiding_scheduler.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/protobuf.h" #include "tsl/platform/statusor.h" #include "tsl/profiler/protobuf/profiled_instructions.pb.h" diff --git a/xla/service/real_imag_expander_test.cc b/xla/service/real_imag_expander_test.cc index a7349a64011d62..429042745427f0 100644 --- a/xla/service/real_imag_expander_test.cc +++ b/xla/service/real_imag_expander_test.cc @@ -28,8 +28,8 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/xla/service/reshape_mover_test.cc b/xla/service/reshape_mover_test.cc index 8c1bce4d0103f7..5ad138e1a94302 100644 --- a/xla/service/reshape_mover_test.cc +++ b/xla/service/reshape_mover_test.cc @@ -27,7 +27,7 @@ limitations under the License. #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/xla/service/scatter_expander_test.cc b/xla/service/scatter_expander_test.cc index a74eabf4080ef7..4d135d3bb26dad 100644 --- a/xla/service/scatter_expander_test.cc +++ b/xla/service/scatter_expander_test.cc @@ -26,8 +26,8 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/xla/service/slice_sinker_test.cc b/xla/service/slice_sinker_test.cc index cbbdafc877cda2..413710bd6a225b 100644 --- a/xla/service/slice_sinker_test.cc +++ b/xla/service/slice_sinker_test.cc @@ -30,8 +30,8 @@ limitations under the License. #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/xla/service/sort_simplifier_test.cc b/xla/service/sort_simplifier_test.cc index ea8f208271a571..678ce7c37eb905 100644 --- a/xla/service/sort_simplifier_test.cc +++ b/xla/service/sort_simplifier_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "xla/service/pattern_matcher_gmock.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/xla/service/spmd/BUILD b/xla/service/spmd/BUILD index d4b8efac5323d3..32713e7c6ddf43 100644 --- a/xla/service/spmd/BUILD +++ b/xla/service/spmd/BUILD @@ -108,6 +108,7 @@ xla_cc_test( "//xla/service:sharding_format_picker", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -115,7 +116,6 @@ xla_cc_test( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:statusor", ], diff --git a/xla/service/spmd/spmd_partitioner_test.cc b/xla/service/spmd/spmd_partitioner_test.cc index 99a1d1f92dc951..da872c2f334b84 100644 --- a/xla/service/spmd/spmd_partitioner_test.cc +++ b/xla/service/spmd/spmd_partitioner_test.cc @@ -47,9 +47,9 @@ limitations under the License. #include "xla/service/spmd/spmd_prepare.h" #include "xla/shape.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/xla/service/stable_sort_expander_test.cc b/xla/service/stable_sort_expander_test.cc index f2b5c41eee4f17..83ba193ede5aef 100644 --- a/xla/service/stable_sort_expander_test.cc +++ b/xla/service/stable_sort_expander_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "xla/service/pattern_matcher_gmock.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/xla/service/topk_rewriter_test.cc b/xla/service/topk_rewriter_test.cc index a1dbb7a5f59b09..c678bef94e373f 100644 --- a/xla/service/topk_rewriter_test.cc +++ b/xla/service/topk_rewriter_test.cc @@ -32,7 +32,7 @@ limitations under the License. #include "xla/service/tuple_simplifier.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/xla/service/triangular_solve_expander_test.cc b/xla/service/triangular_solve_expander_test.cc index 777f1258eb1ce1..fa382b24d0d9db 100644 --- a/xla/service/triangular_solve_expander_test.cc +++ b/xla/service/triangular_solve_expander_test.cc @@ -22,8 +22,8 @@ limitations under the License. #include "xla/reference_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/xla/service/tuple_simplifier_test.cc b/xla/service/tuple_simplifier_test.cc index f83a88f50709a1..33305afd7e0f71 100644 --- a/xla/service/tuple_simplifier_test.cc +++ b/xla/service/tuple_simplifier_test.cc @@ -26,7 +26,7 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/xla/service/while_loop_all_reduce_code_motion_test.cc b/xla/service/while_loop_all_reduce_code_motion_test.cc index 22928f05f1cac2..271b04cd4e2643 100644 --- a/xla/service/while_loop_all_reduce_code_motion_test.cc +++ b/xla/service/while_loop_all_reduce_code_motion_test.cc @@ -31,7 +31,7 @@ limitations under the License. #include "xla/hlo/utils/hlo_matchers.h" #include "xla/service/hlo_verifier.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/xla/service/while_loop_concat_code_motion_test.cc b/xla/service/while_loop_concat_code_motion_test.cc index a4baa5bbe4c1e6..83a43f54f7dd05 100644 --- a/xla/service/while_loop_concat_code_motion_test.cc +++ b/xla/service/while_loop_concat_code_motion_test.cc @@ -27,8 +27,8 @@ limitations under the License. #include "xla/hlo/utils/hlo_matchers.h" #include "xla/service/hlo_verifier.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/xla/service/while_loop_invariant_code_motion_test.cc b/xla/service/while_loop_invariant_code_motion_test.cc index 57f2768b458c0e..7d311df3546e65 100644 --- a/xla/service/while_loop_invariant_code_motion_test.cc +++ b/xla/service/while_loop_invariant_code_motion_test.cc @@ -28,8 +28,8 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/xla/service/while_loop_simplifier_test.cc b/xla/service/while_loop_simplifier_test.cc index c82e29c06728eb..494271c2023ceb 100644 --- a/xla/service/while_loop_simplifier_test.cc +++ b/xla/service/while_loop_simplifier_test.cc @@ -33,8 +33,8 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/xla/service/xla_aot_compile_cpu_test.cc b/xla/service/xla_aot_compile_cpu_test.cc index 16ae69b5bf0e18..85b9b73098466b 100644 --- a/xla/service/xla_aot_compile_cpu_test.cc +++ b/xla/service/xla_aot_compile_cpu_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/service/platform_util.h" #include "xla/service/shaped_buffer.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/path.h" #include "tsl/platform/statusor.h" diff --git a/xla/service/xla_aot_compile_gpu_test.cc b/xla/service/xla_aot_compile_gpu_test.cc index a3720b3b39f15f..88c36c9f4daf47 100644 --- a/xla/service/xla_aot_compile_gpu_test.cc +++ b/xla/service/xla_aot_compile_gpu_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/service/platform_util.h" #include "xla/service/shaped_buffer.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/path.h" #include "tsl/platform/statusor.h" diff --git a/xla/service/xla_aot_compile_stablehlo_cpu_test.cc b/xla/service/xla_aot_compile_stablehlo_cpu_test.cc index 7526cd401c71ce..a85de68e143d08 100644 --- a/xla/service/xla_aot_compile_stablehlo_cpu_test.cc +++ b/xla/service/xla_aot_compile_stablehlo_cpu_test.cc @@ -26,7 +26,7 @@ limitations under the License. #include "xla/service/platform_util.h" #include "xla/service/shaped_buffer.h" #include "xla/tests/literal_test_util.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/path.h" #include "tsl/platform/statusor.h" diff --git a/xla/stream_executor/gpu/BUILD b/xla/stream_executor/gpu/BUILD index 1eddab2a7426fe..24def7c5e53eaf 100644 --- a/xla/stream_executor/gpu/BUILD +++ b/xla/stream_executor/gpu/BUILD @@ -565,10 +565,10 @@ xla_test( "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:stream_executor_memory_allocator", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", @@ -683,9 +683,9 @@ xla_test( "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", "//xla/stream_executor:typed_kernel_factory", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", @@ -714,11 +714,11 @@ xla_test( "//xla/stream_executor:platform_manager", "//xla/stream_executor:trace_command_buffer_factory", "//xla/stream_executor:typed_kernel_factory", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@local_config_cuda//cuda:cuda_headers", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", @@ -745,7 +745,7 @@ xla_test( "//xla/stream_executor", "//xla/stream_executor:device_memory", "//xla/stream_executor:platform_manager", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", @@ -828,9 +828,9 @@ xla_test( "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", "//xla/stream_executor:stream_executor_h", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:path", "@tsl//tsl/platform:protobuf", diff --git a/xla/stream_executor/gpu/gpu_command_buffer_test.cc b/xla/stream_executor/gpu/gpu_command_buffer_test.cc index 306556c2d736e8..cd329353b901f5 100644 --- a/xla/stream_executor/gpu/gpu_command_buffer_test.cc +++ b/xla/stream_executor/gpu/gpu_command_buffer_test.cc @@ -37,7 +37,7 @@ limitations under the License. #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/trace_command_buffer_factory.h" #include "xla/stream_executor/typed_kernel_factory.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" diff --git a/xla/stream_executor/gpu/gpu_device_info_test.cc b/xla/stream_executor/gpu/gpu_device_info_test.cc index 01681a4746701a..9ecfe692fae457 100644 --- a/xla/stream_executor/gpu/gpu_device_info_test.cc +++ b/xla/stream_executor/gpu/gpu_device_info_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/path.h" #include "tsl/platform/protobuf.h" diff --git a/xla/stream_executor/gpu/gpu_kernel_test.cc b/xla/stream_executor/gpu/gpu_kernel_test.cc index 507fbfa477520f..9d93c1264d9128 100644 --- a/xla/stream_executor/gpu/gpu_kernel_test.cc +++ b/xla/stream_executor/gpu/gpu_kernel_test.cc @@ -27,7 +27,7 @@ limitations under the License. #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/typed_kernel_factory.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/xla/stream_executor/gpu/memcpy_test.cc b/xla/stream_executor/gpu/memcpy_test.cc index 96b7700ce33538..1fbe79ce0ec4b2 100644 --- a/xla/stream_executor/gpu/memcpy_test.cc +++ b/xla/stream_executor/gpu/memcpy_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/xla/stream_executor/gpu/redzone_allocator_test.cc b/xla/stream_executor/gpu/redzone_allocator_test.cc index 1ab7dea3030050..abf94db2519ee4 100644 --- a/xla/stream_executor/gpu/redzone_allocator_test.cc +++ b/xla/stream_executor/gpu/redzone_allocator_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/xla/stream_executor/host/BUILD b/xla/stream_executor/host/BUILD index 64152a8058400d..6ab591963fea2c 100644 --- a/xla/stream_executor/host/BUILD +++ b/xla/stream_executor/host/BUILD @@ -140,11 +140,11 @@ xla_cc_test( "//xla/stream_executor:device_memory", "//xla/stream_executor:kernel_spec", "//xla/tsl/concurrency:async_value", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:platform_port", "@tsl//tsl/platform:statusor", @@ -196,9 +196,9 @@ xla_cc_test( "//xla/stream_executor", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/synchronization", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", diff --git a/xla/stream_executor/host/host_kernel_test.cc b/xla/stream_executor/host/host_kernel_test.cc index 7f0a229f1c92a9..a99c6752dd1e90 100644 --- a/xla/stream_executor/host/host_kernel_test.cc +++ b/xla/stream_executor/host/host_kernel_test.cc @@ -34,7 +34,7 @@ limitations under the License. #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/concurrency/async_value_ref.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/cpu_info.h" #include "tsl/platform/env.h" #include "tsl/platform/statusor.h" diff --git a/xla/stream_executor/host/host_stream_test.cc b/xla/stream_executor/host/host_stream_test.cc index 522d38781256fd..1f60709ceb4b2f 100644 --- a/xla/stream_executor/host/host_stream_test.cc +++ b/xla/stream_executor/host/host_stream_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/xla/tests/BUILD b/xla/tests/BUILD index 2596ac963e29fc..f725acf5791289 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -213,11 +213,11 @@ cc_library( "//xla/service:platform_util", "//xla/stream_executor", "//xla/stream_executor:stream_executor_memory_allocator", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:test", ], @@ -299,8 +299,8 @@ cc_library( ":filecheck", "//xla/service:llvm_compiler", "//xla/service/llvm_ir:llvm_util", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:test", ], ) @@ -413,7 +413,7 @@ xla_test( "//xla/service:backend", "//xla/service:executable", "//xla/stream_executor:stream_executor_memory_allocator", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -543,8 +543,8 @@ xla_test( "//xla/client/lib:arithmetic", "//xla/service:platform_util", "//xla/stream_executor:stream_executor_memory_allocator", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_benchmark", @@ -572,11 +572,11 @@ xla_test( "//xla/client:xla_computation", "//xla/service:platform_util", "//xla/service:stream_pool", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:regexp", "@tsl//tsl/platform:test", ], @@ -664,7 +664,7 @@ xla_test( "//xla/client:local_client", "//xla/client:xla_builder", "//xla/client/lib:arithmetic", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@tsl//tsl/platform:test", ], ) @@ -1162,7 +1162,7 @@ xla_test( "//xla/client:local_client", "//xla/client:xla_builder", "//xla/client/lib:constants", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:test", ], @@ -1613,8 +1613,8 @@ xla_test( "//xla/client:xla_builder", "//xla/client:xla_computation", "//xla/service:hlo_parser", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:test", ], ) @@ -1665,12 +1665,12 @@ xla_test( "//xla/client:xla_builder", "//xla/client:xla_computation", "//xla/client/lib:arithmetic", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:test", ], ) @@ -1717,9 +1717,9 @@ xla_test_library( "//xla/client:xla_builder", "//xla/client:xla_computation", "//xla/client/lib:arithmetic", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:status", "@tsl//tsl/platform:test", ], @@ -1912,13 +1912,13 @@ xla_test( "//xla/hlo/ir:hlo", "//xla/service:custom_call_status", "//xla/service:custom_call_target_registry", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:dynamic_annotations", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", ], @@ -2298,9 +2298,9 @@ xla_test( "//xla/service:computation_placer", "//xla/service:executable", "//xla/service:hlo_module_config", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:blocking_counter", "@tsl//tsl/platform:env", ], @@ -2430,7 +2430,7 @@ xla_test( "//xla:shape_util", "//xla:test", "//xla:test_helpers", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -2505,9 +2505,9 @@ xla_test( "//xla:test", "//xla:test_helpers", "//xla/service:hlo_proto_cc", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", @@ -2535,10 +2535,10 @@ xla_test( "//xla/client:xla_computation", "//xla/client/lib:arithmetic", "//xla/client/lib:prng", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", ], @@ -2562,9 +2562,9 @@ xla_test( "//xla/client:global_data", "//xla/client:xla_builder", "//xla/client:xla_computation", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@tsl//tsl/lib/core:status_test_util", ], ) @@ -2736,9 +2736,9 @@ xla_test( "//xla/client:local_client", "//xla/hlo/ir:hlo", "//xla/service:hlo_runner", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:protobuf", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_benchmark", @@ -2832,7 +2832,7 @@ xla_test( ":local_client_test_base", ":test_macros_header", ":xla_internal_test_main", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -3001,9 +3001,9 @@ xla_test( "//xla:shape_util", "//xla/client:xla_builder", "//xla/service:hlo_parser", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_set", - "@tsl//tsl/lib/core:status_test_util", ], ) @@ -3038,8 +3038,8 @@ xla_cc_test( "//xla/client:xla_builder", "//xla/service:cpu_plugin", "//xla/stream_executor:platform_manager", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/synchronization", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:test", ], @@ -3118,9 +3118,9 @@ xla_test( "//xla/client:xla_builder", "//xla/client/lib:math", "//xla/client/lib:matrix", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@tsl//tsl/lib/core:status_test_util", ], ) @@ -3144,8 +3144,8 @@ xla_test( "//xla/client:xla_builder", "//xla/client/lib:arithmetic", "//xla/client/lib:matrix", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", - "@tsl//tsl/lib/core:status_test_util", ], ) diff --git a/xla/tests/buffer_donation_test.cc b/xla/tests/buffer_donation_test.cc index 732d562871afa9..35d9c648846892 100644 --- a/xla/tests/buffer_donation_test.cc +++ b/xla/tests/buffer_donation_test.cc @@ -30,7 +30,7 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/verified_hlo_module.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/xla/tests/cholesky_test.cc b/xla/tests/cholesky_test.cc index a0c3d4227c7246..9215319bbf8e40 100644 --- a/xla/tests/cholesky_test.cc +++ b/xla/tests/cholesky_test.cc @@ -28,8 +28,8 @@ limitations under the License. #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/xla/tests/collective_ops_test.cc b/xla/tests/collective_ops_test.cc index 0d8c3062bf2238..460864c7513269 100644 --- a/xla/tests/collective_ops_test.cc +++ b/xla/tests/collective_ops_test.cc @@ -34,7 +34,7 @@ limitations under the License. #include "xla/tests/test_macros.h" #include "xla/tests/test_utils.h" #include "xla/tests/verified_hlo_module.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/blocking_counter.h" #include "tsl/platform/env.h" #include "tsl/platform/threadpool.h" diff --git a/xla/tests/compute_constant_test.cc b/xla/tests/compute_constant_test.cc index d991b580a26bbd..8742656f17ff7a 100644 --- a/xla/tests/compute_constant_test.cc +++ b/xla/tests/compute_constant_test.cc @@ -31,8 +31,8 @@ limitations under the License. #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/tests/test_utils.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/xla/tests/constants_test.cc b/xla/tests/constants_test.cc index a926d24819fd68..26407b42790526 100644 --- a/xla/tests/constants_test.cc +++ b/xla/tests/constants_test.cc @@ -31,7 +31,7 @@ limitations under the License. #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/tests/test_utils.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/ml_dtypes.h" #include "tsl/platform/test.h" diff --git a/xla/tests/hlo_test_base.cc b/xla/tests/hlo_test_base.cc index e7367e75a760b9..8c473626f719fe 100644 --- a/xla/tests/hlo_test_base.cc +++ b/xla/tests/hlo_test_base.cc @@ -42,8 +42,8 @@ limitations under the License. #include "xla/tests/pjrt_client_registry.h" #include "xla/tests/test_utils.h" #include "xla/tests/verified_hlo_module.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/logging.h" #include "tsl/platform/test.h" diff --git a/xla/tests/llvm_irgen_test_base.cc b/xla/tests/llvm_irgen_test_base.cc index fae82d29c84954..db3d06c69f62dd 100644 --- a/xla/tests/llvm_irgen_test_base.cc +++ b/xla/tests/llvm_irgen_test_base.cc @@ -21,7 +21,7 @@ limitations under the License. #include "absl/status/status.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/tests/filecheck.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/test.h" namespace xla { diff --git a/xla/tests/multioutput_fusion_test.cc b/xla/tests/multioutput_fusion_test.cc index f6ec09a1e35a2f..97ee8b70575426 100644 --- a/xla/tests/multioutput_fusion_test.cc +++ b/xla/tests/multioutput_fusion_test.cc @@ -36,8 +36,8 @@ limitations under the License. #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/tests/test_utils.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/protobuf.h" #include "tsl/platform/test.h" #include "tsl/platform/test_benchmark.h" diff --git a/xla/tests/multiple_devices_on_host_test.cc b/xla/tests/multiple_devices_on_host_test.cc index a24b5594f484bc..8aa1502a3a951d 100644 --- a/xla/tests/multiple_devices_on_host_test.cc +++ b/xla/tests/multiple_devices_on_host_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "xla/client/xla_builder.h" #include "xla/shape_util.h" #include "xla/stream_executor/platform_manager.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/test.h" diff --git a/xla/tests/multithreaded_compilation_test.cc b/xla/tests/multithreaded_compilation_test.cc index cbbfedab4f7e84..b9cb8d253cb511 100644 --- a/xla/tests/multithreaded_compilation_test.cc +++ b/xla/tests/multithreaded_compilation_test.cc @@ -29,7 +29,7 @@ limitations under the License. #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" #include "tsl/platform/threadpool.h" diff --git a/xla/tests/outfeed_in_nested_computation_test.cc b/xla/tests/outfeed_in_nested_computation_test.cc index b111c620e96f2e..44250f502e7a1b 100644 --- a/xla/tests/outfeed_in_nested_computation_test.cc +++ b/xla/tests/outfeed_in_nested_computation_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include "xla/tests/local_client_test_base.h" #include "xla/tests/test_macros.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/xla/tests/pred_test.cc b/xla/tests/pred_test.cc index 89ba59cd70b356..9f8af6a013d677 100644 --- a/xla/tests/pred_test.cc +++ b/xla/tests/pred_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "xla/client/local_client.h" #include "xla/client/xla_builder.h" #include "xla/tests/client_library_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/test.h" namespace xla { diff --git a/xla/tests/reduce_test.cc b/xla/tests/reduce_test.cc index 460b920d821eaa..f5db7397cad818 100644 --- a/xla/tests/reduce_test.cc +++ b/xla/tests/reduce_test.cc @@ -57,9 +57,9 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/test.h" namespace xla { diff --git a/xla/tests/reduce_window_test.cc b/xla/tests/reduce_window_test.cc index ccbb8f4f3cb8ba..c65cd9c9af1969 100644 --- a/xla/tests/reduce_window_test.cc +++ b/xla/tests/reduce_window_test.cc @@ -35,8 +35,8 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" #include "tsl/platform/test.h" diff --git a/xla/tests/replicated_io_feed_test.cc b/xla/tests/replicated_io_feed_test.cc index d3600e4602f135..9ee34a7a17da8d 100644 --- a/xla/tests/replicated_io_feed_test.cc +++ b/xla/tests/replicated_io_feed_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" // Tests replicated infeed/outfeed operations. diff --git a/xla/tests/test_utils_test.cc b/xla/tests/test_utils_test.cc index 82a95589e6b907..22212a02998239 100644 --- a/xla/tests/test_utils_test.cc +++ b/xla/tests/test_utils_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/tests/local_client_test_base.h" #include "xla/tests/test_macros.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/xla/tests/triangular_solve_test.cc b/xla/tests/triangular_solve_test.cc index a82720008f9a42..b04ac99d4110e4 100644 --- a/xla/tests/triangular_solve_test.cc +++ b/xla/tests/triangular_solve_test.cc @@ -29,9 +29,9 @@ limitations under the License. #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/xla/tests/tuple_test.cc b/xla/tests/tuple_test.cc index b0d765414d256c..8d6c1c641579e9 100644 --- a/xla/tests/tuple_test.cc +++ b/xla/tests/tuple_test.cc @@ -29,8 +29,8 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/test.h" namespace xla { diff --git a/xla/tests/value_inference_test.cc b/xla/tests/value_inference_test.cc index 4fbc3356b62717..50da08967a01eb 100644 --- a/xla/tests/value_inference_test.cc +++ b/xla/tests/value_inference_test.cc @@ -36,8 +36,8 @@ limitations under the License. #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/tests/test_utils.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" diff --git a/xla/tests/while_test.cc b/xla/tests/while_test.cc index 7e6c2af60f4183..473875960fdd16 100644 --- a/xla/tests/while_test.cc +++ b/xla/tests/while_test.cc @@ -31,8 +31,8 @@ limitations under the License. #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/logging.h" #include "tsl/platform/test.h" #include "tsl/platform/test_benchmark.h" diff --git a/xla/tests/xla_hlo_profile_test.cc b/xla/tests/xla_hlo_profile_test.cc index 2436635dea5ef0..72e4387da2beb8 100644 --- a/xla/tests/xla_hlo_profile_test.cc +++ b/xla/tests/xla_hlo_profile_test.cc @@ -34,7 +34,7 @@ limitations under the License. #include "xla/tests/client_library_test_base.h" #include "xla/tests/test_macros.h" #include "xla/tests/test_utils.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/regexp.h" #include "tsl/platform/test.h" diff --git a/xla/text_literal_writer_test.cc b/xla/text_literal_writer_test.cc index 3c9c2d6161eef1..e517279a4c447d 100644 --- a/xla/text_literal_writer_test.cc +++ b/xla/text_literal_writer_test.cc @@ -22,8 +22,8 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/test.h" #include "xla/test_helpers.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" namespace xla { diff --git a/xla/tools/BUILD b/xla/tools/BUILD index 0bbc059f110b0c..4ba990420787c5 100644 --- a/xla/tools/BUILD +++ b/xla/tools/BUILD @@ -485,7 +485,7 @@ xla_cc_test( ":hlo_module_loader", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@tsl//tsl/platform:test", ], ) @@ -587,7 +587,7 @@ xla_cc_test( "//xla:literal", "//xla:literal_util", "//xla:xla_data_proto_cc", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:test", "@tsl//tsl/platform:test_main", @@ -642,8 +642,8 @@ xla_cc_test( "//xla:literal_util", "//xla/hlo/ir:hlo", "//xla/service:hlo_parser", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:path", "@tsl//tsl/platform:statusor", @@ -692,8 +692,8 @@ xla_cc_test( "//xla/service/spmd:spmd_partitioner", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings", - "@tsl//tsl/lib/core:status_test_util", ], ) @@ -837,10 +837,10 @@ xla_test( "//xla/tests:hlo_test_base", "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:env_time", "@tsl//tsl/platform:path", diff --git a/xla/tools/hlo_control_flow_flattening_test.cc b/xla/tools/hlo_control_flow_flattening_test.cc index a391a59ebdad34..ceb51bee5cbbfe 100644 --- a/xla/tools/hlo_control_flow_flattening_test.cc +++ b/xla/tools/hlo_control_flow_flattening_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "xla/service/hlo_verifier.h" #include "xla/service/spmd/spmd_partitioner.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/xla/tools/hlo_module_loader_test.cc b/xla/tools/hlo_module_loader_test.cc index e3916a0ec98ac9..16fbe45e4ae451 100644 --- a/xla/tools/hlo_module_loader_test.cc +++ b/xla/tools/hlo_module_loader_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/test.h" namespace xla { diff --git a/xla/tools/multihost_hlo_runner/BUILD b/xla/tools/multihost_hlo_runner/BUILD index 81834f83b82116..be73ccc055cc54 100644 --- a/xla/tools/multihost_hlo_runner/BUILD +++ b/xla/tools/multihost_hlo_runner/BUILD @@ -198,6 +198,7 @@ xla_test( "//xla:xla_proto_cc", "//xla/pjrt:pjrt_client", "//xla/tests:filecheck", + "//xla/tsl/lib/core:status_test_util", "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -206,7 +207,6 @@ xla_test( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/time", "@com_google_googletest//:gtest", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:path", diff --git a/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc b/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc index 0d043ad757c1d0..b55c1be6b7a5b8 100644 --- a/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc +++ b/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc @@ -30,9 +30,9 @@ limitations under the License. #include "xla/pjrt/pjrt_client.h" #include "xla/tests/filecheck.h" #include "xla/tools/multihost_hlo_runner/create_client.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/util/command_line_flags.h" #include "xla/xla.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/file_system.h" diff --git a/xla/tools/run_hlo_module_bin_test.cc b/xla/tools/run_hlo_module_bin_test.cc index bb4d6b32cdcf34..fe82122cb3afe5 100644 --- a/xla/tools/run_hlo_module_bin_test.cc +++ b/xla/tools/run_hlo_module_bin_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/hlo_parser.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/path.h" #include "tsl/platform/statusor.h" diff --git a/xla/tools/run_hlo_module_test.cc b/xla/tools/run_hlo_module_test.cc index 2ac2d2b6f12074..255563a5893657 100644 --- a/xla/tools/run_hlo_module_test.cc +++ b/xla/tools/run_hlo_module_test.cc @@ -20,8 +20,8 @@ limitations under the License. #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/tools/run_hlo_module.pb.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/test.h" diff --git a/xla/tools/xla_compile_lib_test.cc b/xla/tools/xla_compile_lib_test.cc index d9586238dd2247..101282cbc03ff5 100644 --- a/xla/tools/xla_compile_lib_test.cc +++ b/xla/tools/xla_compile_lib_test.cc @@ -33,8 +33,8 @@ limitations under the License. #include "xla/stream_executor/device_description.pb.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/env_time.h" #include "tsl/platform/path.h" diff --git a/xla/translate/hlo_to_mhlo/BUILD b/xla/translate/hlo_to_mhlo/BUILD index d9c0ddc3cfc6c9..043dbf0cc801d7 100644 --- a/xla/translate/hlo_to_mhlo/BUILD +++ b/xla/translate/hlo_to_mhlo/BUILD @@ -148,9 +148,9 @@ xla_cc_test( "//xla:shape_util", "//xla:test", "//xla:types", + "//xla/tsl/lib/core:status_test_util", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:test_main", ], ) diff --git a/xla/translate/hlo_to_mhlo/hlo_utils_test.cc b/xla/translate/hlo_to_mhlo/hlo_utils_test.cc index c5a4e1a6c0e5d6..b16e5870e99d79 100644 --- a/xla/translate/hlo_to_mhlo/hlo_utils_test.cc +++ b/xla/translate/hlo_to_mhlo/hlo_utils_test.cc @@ -27,8 +27,8 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/shape_util.h" #include "xla/test.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/xla/tsl/distributed_runtime/coordination/BUILD b/xla/tsl/distributed_runtime/coordination/BUILD index d42d285d1d9a69..6f75617aec857e 100644 --- a/xla/tsl/distributed_runtime/coordination/BUILD +++ b/xla/tsl/distributed_runtime/coordination/BUILD @@ -113,6 +113,7 @@ tsl_cc_test( ":coordination_service_impl", ":test_device_proto_cc", "//xla/tsl/distributed_runtime:call_options", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -120,7 +121,6 @@ tsl_cc_test( "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:random", "@tsl//tsl/platform:status", @@ -167,12 +167,12 @@ tsl_cc_test( ":coordination_client", ":coordination_service_agent", "//xla/tsl/distributed_runtime:call_options", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/time", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:env_impl", "@tsl//tsl/platform:status", @@ -217,13 +217,13 @@ tsl_cc_test( "//xla/tsl/distributed_runtime/rpc:async_service_interface", "//xla/tsl/distributed_runtime/rpc/coordination:grpc_coordination_client", "//xla/tsl/distributed_runtime/rpc/coordination:grpc_coordination_service_impl", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/synchronization", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env", "@tsl//tsl/platform:env_impl", "@tsl//tsl/platform:status", diff --git a/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc b/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc index 6348054527fdb8..ee2eb2348cac16 100644 --- a/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc +++ b/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc @@ -29,7 +29,7 @@ limitations under the License. #include "absl/time/time.h" #include "xla/tsl/distributed_runtime/call_options.h" #include "xla/tsl/distributed_runtime/coordination/coordination_client.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/status.h" #include "tsl/platform/test.h" diff --git a/xla/tsl/distributed_runtime/coordination/coordination_service_recoverable_job_test.cc b/xla/tsl/distributed_runtime/coordination/coordination_service_recoverable_job_test.cc index da40248891f372..3ec3290c9507e1 100644 --- a/xla/tsl/distributed_runtime/coordination/coordination_service_recoverable_job_test.cc +++ b/xla/tsl/distributed_runtime/coordination/coordination_service_recoverable_job_test.cc @@ -31,7 +31,7 @@ limitations under the License. #include "xla/tsl/distributed_runtime/rpc/async_service_interface.h" #include "xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.h" #include "xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/status.h" #include "tsl/platform/test.h" diff --git a/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc b/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc index b9b9bbf75215f4..0fb11db9b8e990 100644 --- a/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc +++ b/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc @@ -36,7 +36,7 @@ limitations under the License. #include "xla/tsl/distributed_runtime/coordination/coordination_client.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h" #include "xla/tsl/distributed_runtime/coordination/test_device.pb.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/random.h" #include "tsl/platform/status.h" diff --git a/xla/tsl/distributed_runtime/rpc/BUILD b/xla/tsl/distributed_runtime/rpc/BUILD index 0f9a93eb1a9922..817c4dc5a4c970 100644 --- a/xla/tsl/distributed_runtime/rpc/BUILD +++ b/xla/tsl/distributed_runtime/rpc/BUILD @@ -108,8 +108,8 @@ tsl_cc_test( ], deps = [ ":grpc_channel", + "//xla/tsl/lib/core:status_test_util", "//xla/tsl/util:device_name_utils", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:env_impl", "@tsl//tsl/platform:strcat", "@tsl//tsl/platform:test", diff --git a/xla/tsl/distributed_runtime/rpc/grpc_channel_test.cc b/xla/tsl/distributed_runtime/rpc/grpc_channel_test.cc index 806ea5494d90c8..80c976640fa6f1 100644 --- a/xla/tsl/distributed_runtime/rpc/grpc_channel_test.cc +++ b/xla/tsl/distributed_runtime/rpc/grpc_channel_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/util/device_name_utils.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/strcat.h" #include "tsl/platform/test.h" #include "tsl/protobuf/rpc_options.pb.h" diff --git a/xla/tsl/framework/BUILD b/xla/tsl/framework/BUILD index 8fa1ca738fabc7..1e6ae7e269bf68 100644 --- a/xla/tsl/framework/BUILD +++ b/xla/tsl/framework/BUILD @@ -462,8 +462,8 @@ tsl_cc_test( deps = [ ":device_id_impl", ":device_id_utils", + "//xla/tsl/lib/core:status_test_util", "//xla/tsl/util:device_name_utils", - "@tsl//tsl/lib/core:status_test_util", "@tsl//tsl/platform:status_matchers", "@tsl//tsl/platform:test_main", "@tsl//tsl/protobuf:error_codes_proto_impl_cc", diff --git a/xla/tsl/framework/device_id_utils_test.cc b/xla/tsl/framework/device_id_utils_test.cc index 245097e01f80d4..da12c3bb11c912 100644 --- a/xla/tsl/framework/device_id_utils_test.cc +++ b/xla/tsl/framework/device_id_utils_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include "xla/tsl/framework/device_id_manager.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/util/device_name_utils.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" namespace tsl { diff --git a/xla/tsl/util/BUILD b/xla/tsl/util/BUILD index 653d91027b2e4c..103122ea6fb58f 100644 --- a/xla/tsl/util/BUILD +++ b/xla/tsl/util/BUILD @@ -273,7 +273,7 @@ tsl_cc_test( srcs = ["device_name_utils_test.cc"], deps = [ ":device_name_utils", - "@tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:strcat", "@tsl//tsl/platform:test", diff --git a/xla/tsl/util/device_name_utils_test.cc b/xla/tsl/util/device_name_utils_test.cc index 1f5f5114550d40..1457297599d74f 100644 --- a/xla/tsl/util/device_name_utils_test.cc +++ b/xla/tsl/util/device_name_utils_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/strcat.h" #include "tsl/platform/test.h" From 8ab2aac75ca5e9ff96441334fac95662b09ccca3 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Tue, 30 Jul 2024 16:12:59 -0700 Subject: [PATCH 290/376] Remove default argument from StreamExecutor::CreateStream method, which is against the Google style guide. PiperOrigin-RevId: 657765070 --- xla/backends/interpreter/executor.h | 3 +-- xla/stream_executor/gpu/gpu_executor.h | 4 ++-- xla/stream_executor/stream_executor.h | 6 ++++-- xla/stream_executor/tpu/tpu_executor.h | 4 ++-- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/xla/backends/interpreter/executor.h b/xla/backends/interpreter/executor.h index 822537638b09cb..a899bba70073fb 100644 --- a/xla/backends/interpreter/executor.h +++ b/xla/backends/interpreter/executor.h @@ -151,8 +151,7 @@ class XlaInterpreterExecutor : public StreamExecutorCommon { } absl::StatusOr> CreateStream( - std::optional> priority = - std::nullopt) override { + std::optional> priority) override { return std::make_unique(this); } diff --git a/xla/stream_executor/gpu/gpu_executor.h b/xla/stream_executor/gpu/gpu_executor.h index bf1027d249ab58..e7926f01e0349f 100644 --- a/xla/stream_executor/gpu/gpu_executor.h +++ b/xla/stream_executor/gpu/gpu_executor.h @@ -30,6 +30,7 @@ limitations under the License. #include #include #include +#include #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" @@ -223,8 +224,7 @@ class GpuExecutor : public StreamExecutorCommon { absl::StatusOr> CreateEvent() override; absl::StatusOr> CreateStream( - std::optional> priority = - std::nullopt) override; + std::optional> priority) override; absl::StatusOr> CreateCommandBuffer( CommandBuffer::Mode mode) override; diff --git a/xla/stream_executor/stream_executor.h b/xla/stream_executor/stream_executor.h index a634dc0f901b74..d4f264a3de961d 100644 --- a/xla/stream_executor/stream_executor.h +++ b/xla/stream_executor/stream_executor.h @@ -84,8 +84,10 @@ class StreamExecutor { // Creates and initializes a Stream. virtual absl::StatusOr> CreateStream( - std::optional> priority = - std::nullopt) = 0; + std::optional> priority) = 0; + absl::StatusOr> CreateStream() { + return CreateStream(std::nullopt); + } // Creates and initializes an Event. virtual absl::StatusOr> CreateEvent() = 0; diff --git a/xla/stream_executor/tpu/tpu_executor.h b/xla/stream_executor/tpu/tpu_executor.h index 88ef4618424732..c969c6d5d59d51 100644 --- a/xla/stream_executor/tpu/tpu_executor.h +++ b/xla/stream_executor/tpu/tpu_executor.h @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include "absl/container/flat_hash_map.h" #include "absl/functional/any_invocable.h" @@ -90,8 +91,7 @@ class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface { const override; absl::StatusOr> CreateStream( - std::optional> priority = - std::nullopt) override; + std::optional> priority) override; absl::StatusOr> CreateEvent() override; From 95c79d4404904378bfef22f9930a5ed06875ebc4 Mon Sep 17 00:00:00 2001 From: Gregory Pataky Date: Tue, 30 Jul 2024 16:13:09 -0700 Subject: [PATCH 291/376] Disable ExhaustiveOpTestBase subnormal cache for binary tests PiperOrigin-RevId: 657765126 --- .../exhaustive/exhaustive_op_test_utils.cc | 44 ++++++++++++------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/xla/tests/exhaustive/exhaustive_op_test_utils.cc b/xla/tests/exhaustive/exhaustive_op_test_utils.cc index 52bdda0c6a3278..ee7207a1433ab7 100644 --- a/xla/tests/exhaustive/exhaustive_op_test_utils.cc +++ b/xla/tests/exhaustive/exhaustive_op_test_utils.cc @@ -375,17 +375,21 @@ void ExhaustiveOpTestBase::ExpectNear( OutputRangeCheck check_valid_range) { // Cache for when all components are subnormal testing values. std::vector pure_subnormal_cache; - // Since we take the cross product of all possible test values, and each - // component has kNumSubnormalSubstitutionValues possible test values, then - // the total number of different cache locations are - // kNumSubnormalSubstitutionValues raised to the num_components. - // num_components = N for the reals, and 2*N for the complex. - int64_t max_cache_size = - pow(kNumSubnormalSubstitutionValues, N * (kIsComplex ? 2 : 1)); - pure_subnormal_cache.reserve(max_cache_size); - for (int i = 0; i < max_cache_size; ++i) { - pure_subnormal_cache.push_back(CallOperation( - evaluate_op, FromCacheLocation(i))); + // TODO(b/353790524): Subnormal cache does not seem to work properly with + // more than 1 input. + if constexpr (N == 1) { + // Since we take the cross product of all possible test values, and each + // component has kNumSubnormalSubstitutionValues possible test values, then + // the total number of different cache locations are + // kNumSubnormalSubstitutionValues raised to the num_components. + // num_components = N for the reals, and 2*N for the complex. + int64_t max_cache_size = + pow(kNumSubnormalSubstitutionValues, N * (kIsComplex ? 2 : 1)); + pure_subnormal_cache.reserve(max_cache_size); + for (int i = 0; i < max_cache_size; ++i) { + pure_subnormal_cache.push_back(CallOperation( + evaluate_op, FromCacheLocation(i))); + } } NativeInputsList inputs_arr; @@ -450,13 +454,19 @@ void ExhaustiveOpTestBase::ExpectNear( for (NativeRefInputs test_value : subnormal_test_inputs) { NativeRefT result; - int cache_loc = - GetCacheLocation( - test_value); - if (cache_loc == kInvalidCacheIndex) { - result = CallOperation(evaluate_op, test_value); + // TODO(b/353790524): Subnormal cache does not seem to work properly with + // more than 1 input. + if constexpr (N == 1) { + int cache_loc = + GetCacheLocation(test_value); + if (cache_loc == kInvalidCacheIndex) { + result = CallOperation(evaluate_op, test_value); + } else { + result = pure_subnormal_cache[cache_loc]; + } } else { - result = pure_subnormal_cache[cache_loc]; + result = result = CallOperation(evaluate_op, test_value); } if (IsClose(result, static_cast(actual), error_spec)) { From 7c429df2b7f91bf1dd0bef0079fa17edcb32e933 Mon Sep 17 00:00:00 2001 From: zoranjovanovic-ns <126815388+zoranjovanovic-ns@users.noreply.github.com> Date: Tue, 30 Jul 2024 16:29:43 -0700 Subject: [PATCH 292/376] PR #14962: [ROCm] Fix an issue with Softmax. Imported from GitHub PR https://github.com/openxla/xla/pull/14962 Copybara import of the project: -- 3637d6ba4c0913d6f3d83f71d542a97234c45523 by Zoran Jovanovic : [ROCm] Fix an issue with Softmax. Merging this change closes #14962 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/14962 from ROCm:ci_rocm_softmax 3637d6ba4c0913d6f3d83f71d542a97234c45523 PiperOrigin-RevId: 657769988 --- xla/service/gpu/gpu_compiler.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index 73c3556019d6c2..96b537e066fc54 100644 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -1421,8 +1421,9 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( // ReductionDimensionGrouper, as that makes matching the softmax pattern // harder. if (debug_options.xla_gpu_enable_triton_softmax_fusion() && - cuda_cc != nullptr && - cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) { + ((cuda_cc != nullptr && + cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) || + rocm_cc != nullptr)) { // Triton compilation needs normalized operations on bf16 (i.e. converted // to f32). add_float_normalization(pipeline); From 712eb9064b91e18d26ae293c0918379e6b75241b Mon Sep 17 00:00:00 2001 From: Sergey Kozub Date: Tue, 30 Jul 2024 16:30:41 -0700 Subject: [PATCH 293/376] PR #15490: Do not fail GpuPerformanceWithCollectiveModel on Blackwell Imported from GitHub PR https://github.com/openxla/xla/pull/15490 Add bandwidth data for Blackwell in the collective model (the actual numbers are unknown at this point, so just copied from Hopper). Note that "kLowLatencyMaxBandwidths" value for Hopper was incorrect - fixed it. Additionally, if nvml doesn't support the card, return "unsupported" instead of CHECK-failing. Copybara import of the project: -- 9340a2fd656898f863c67bcf9a622067b31f9dcc by Sergey Kozub : Do not fail GpuPerformanceWithCollectiveModel on Blackwell Merging this change closes #15490 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/15490 from openxla:skozub/gpu_collective_performance_model 9340a2fd656898f863c67bcf9a622067b31f9dcc PiperOrigin-RevId: 657770290 --- .../gpu/model/gpu_collective_performance_model.cc | 8 ++++++-- .../gpu/model/gpu_collective_performance_model.h | 12 ++++++------ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/xla/service/gpu/model/gpu_collective_performance_model.cc b/xla/service/gpu/model/gpu_collective_performance_model.cc index f04771f789691a..9459d0c33a7b6a 100644 --- a/xla/service/gpu/model/gpu_collective_performance_model.cc +++ b/xla/service/gpu/model/gpu_collective_performance_model.cc @@ -111,8 +111,11 @@ float GetMaxSysBwFromGpu(const se::CudaComputeCapability cc, return bandwidths_table[1]; case se::CudaComputeCapability::HOPPER: return bandwidths_table[2]; + case se::CudaComputeCapability::BLACKWELL: + return bandwidths_table[3]; + default: + return bandwidths_table[4]; } - return -1; } } // namespace @@ -189,7 +192,8 @@ GpuPerformanceWithCollectiveModel::CheckIfNvlinkSupportsP2P() { nvmlReturn_t nvlink_cap_result = xla_nvmlDeviceGetNvLinkCapability( nvml_device, /*nvlink link number*/ 0, NVML_NVLINK_CAP_P2P_SUPPORTED, &supported_p2p); - CHECK(nvlink_cap_result == NVML_SUCCESS); + CHECK(nvlink_cap_result == NVML_SUCCESS || + nvlink_cap_result == NVML_ERROR_NOT_SUPPORTED); CHECK(ShutdownNvml()) << "NVML shutdown failed."; return supported_p2p; #else diff --git a/xla/service/gpu/model/gpu_collective_performance_model.h b/xla/service/gpu/model/gpu_collective_performance_model.h index c11a78c684e80d..49fe21a2c17919 100644 --- a/xla/service/gpu/model/gpu_collective_performance_model.h +++ b/xla/service/gpu/model/gpu_collective_performance_model.h @@ -57,16 +57,16 @@ class GpuPerformanceWithCollectiveModel : public GpuPerformanceModelBase { // Table for max system bandwidths GB/s for using NCCL's low latency // algorithm. This is used for intra-node estimate. - static constexpr std::array kLowLatencyMaxBandwidths = { - 39.0 /* Volta*/, 87.7 /* Ampere*/, 87.7 /* Hopper*/ + static constexpr std::array kLowLatencyMaxBandwidths = { + 39.0 /* Volta */, 87.7 /* Ampere */, 141.0 /* Hopper */, + 141.0 /* Blackwell */, 141.0 /* next-gen */, }; // Max bandwidth in GB/s for ring low latency 128 algorithm per channel on a // single-node - static constexpr std::array kPerChannelMaxRingLL128Bandwidths = { - 20.0 /* Volta */, - 20.0 /* Ampere */, - 36.7 /* Hopper */, + static constexpr std::array kPerChannelMaxRingLL128Bandwidths = { + 20.0 /* Volta */, 20.0 /* Ampere */, 36.7 /* Hopper */, + 36.7 /* Blackwell */, 36.7 /* next-gen */, }; // Nvlink unidirectional bandwidth for different compute cap. Note this is per From 09cc2b9b43b0883bdcb0d32325056a0fb39d9327 Mon Sep 17 00:00:00 2001 From: Yue Sheng Date: Tue, 30 Jul 2024 16:46:43 -0700 Subject: [PATCH 294/376] Fix headers in PjRt. PiperOrigin-RevId: 657775121 --- xla/pjrt/BUILD | 22 +++++++++++++++++++++- xla/pjrt/pjrt_client.h | 4 ++++ xla/pjrt/pjrt_stream_executor_client.h | 1 + xla/pjrt/tf_pjrt_client.h | 19 +++++++++++++++++++ xla/python/py_client.cc | 2 +- 5 files changed, 46 insertions(+), 2 deletions(-) diff --git a/xla/pjrt/BUILD b/xla/pjrt/BUILD index 34948f2e3cb888..1b6716882b2527 100644 --- a/xla/pjrt/BUILD +++ b/xla/pjrt/BUILD @@ -219,6 +219,8 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -226,6 +228,7 @@ cc_library( "@com_google_absl//absl/types:span", "@llvm-project//mlir:IR", "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", ], ) @@ -471,6 +474,7 @@ cc_library( deps = [ ":event_pool", ":host_callback", + ":host_memory_spaces", ":local_device_state", ":metrics", ":mlir_to_hlo", @@ -495,7 +499,6 @@ cc_library( "//xla/client:local_client", "//xla/client:xla_computation", "//xla/hlo/ir:hlo", - "//xla/pjrt:host_memory_spaces", "//xla/pjrt/distributed:protocol_proto_cc", "//xla/service:compiler", "//xla/service:computation_layout", @@ -516,6 +519,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -846,12 +850,28 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":pjrt_client", + ":pjrt_common", + ":pjrt_compiler", + ":pjrt_executable", ":pjrt_future", + "//xla:literal", + "//xla:shape_util", + "//xla:util", + "//xla/client:xla_computation", + "//xla/hlo/ir:hlo", + "//xla/service:computation_placer_hdr", + "//xla/service:hlo_cost_analysis", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:casts", "@tsl//tsl/platform:errors", ], ) diff --git a/xla/pjrt/pjrt_client.h b/xla/pjrt/pjrt_client.h index d39d40cf86f50f..2f6c7a8b515792 100644 --- a/xla/pjrt/pjrt_client.h +++ b/xla/pjrt/pjrt_client.h @@ -31,8 +31,11 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" #include "absl/functional/any_invocable.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/synchronization/notification.h" @@ -55,6 +58,7 @@ limitations under the License. #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" // API notes: // PjRt stands for "Pretty much Just another RunTime". diff --git a/xla/pjrt/pjrt_stream_executor_client.h b/xla/pjrt/pjrt_stream_executor_client.h index 4e8595591c8356..1fb61152bdab61 100644 --- a/xla/pjrt/pjrt_stream_executor_client.h +++ b/xla/pjrt/pjrt_stream_executor_client.h @@ -33,6 +33,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/functional/any_invocable.h" +#include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" diff --git a/xla/pjrt/tf_pjrt_client.h b/xla/pjrt/tf_pjrt_client.h index 363c0526f0ba56..c4299d37b0e0fe 100644 --- a/xla/pjrt/tf_pjrt_client.h +++ b/xla/pjrt/tf_pjrt_client.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef XLA_PJRT_TF_PJRT_CLIENT_H_ #define XLA_PJRT_TF_PJRT_CLIENT_H_ +#include +#include #include #include #include @@ -26,9 +28,26 @@ limitations under the License. #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "xla/client/xla_computation.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/layout.h" +#include "xla/literal.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_future.h" +#include "xla/service/computation_placer.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/shape.h" +#include "xla/util.h" +#include "tsl/platform/casts.h" #include "tsl/platform/errors.h" namespace xla { diff --git a/xla/python/py_client.cc b/xla/python/py_client.cc index 3b4ebcd9901d09..716c59ba9a2c5d 100644 --- a/xla/python/py_client.cc +++ b/xla/python/py_client.cc @@ -34,7 +34,6 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" -#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "llvm/Support/Casting.h" #include "mlir/IR/BuiltinOps.h" @@ -67,6 +66,7 @@ limitations under the License. #include "xla/python/ifrt/hlo/hlo_program.h" #include "xla/python/ifrt/host_callback.h" #include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/program.h" #include "xla/python/nb_absl_span.h" // IWYU pragma: keep #include "xla/python/nb_class_ptr.h" #include "xla/python/nb_numpy.h" From a4f6b062f534bc4a27b2d3b9d245f8cb383dca44 Mon Sep 17 00:00:00 2001 From: Heiner Date: Tue, 30 Jul 2024 16:53:28 -0700 Subject: [PATCH 295/376] Add missing target to gloo.BUILD This should've been added in https://github.com/openxla/xla/pull/15027 but was erroneously deleted while patching internally by ddunl PiperOrigin-RevId: 657776799 --- third_party/gloo/gloo.BUILD | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/third_party/gloo/gloo.BUILD b/third_party/gloo/gloo.BUILD index 2eb62cd7416136..dcd436d332d6fd 100644 --- a/third_party/gloo/gloo.BUILD +++ b/third_party/gloo/gloo.BUILD @@ -95,3 +95,14 @@ cc_library( copts = ["-fexceptions"], deps = [":gloo"], ) + +cc_library( + name = "transport_uv", + srcs = glob(["gloo/transport/uv/*.cc"]), + hdrs = glob(["gloo/transport/uv/*.h"]), + copts = ["-fexceptions"], + deps = [ + ":gloo", + "@uv", + ], +) From 2cea900ea8048383e80bf6f5df45e78e3dbc60ae Mon Sep 17 00:00:00 2001 From: Gregory Pataky Date: Tue, 30 Jul 2024 16:56:44 -0700 Subject: [PATCH 296/376] Add ErrorSpec::skip_comparison for exhaustive tests PiperOrigin-RevId: 657777584 --- .../exhaustive/exhaustive_op_test_utils.cc | 28 +++++++++++++++++++ .../exhaustive/exhaustive_op_test_utils.h | 7 +++++ 2 files changed, 35 insertions(+) diff --git a/xla/tests/exhaustive/exhaustive_op_test_utils.cc b/xla/tests/exhaustive/exhaustive_op_test_utils.cc index ee7207a1433ab7..7e8f8718b454dc 100644 --- a/xla/tests/exhaustive/exhaustive_op_test_utils.cc +++ b/xla/tests/exhaustive/exhaustive_op_test_utils.cc @@ -347,6 +347,22 @@ std::string StringifyNum(const std::array& inputs) { return absl::StrCat("(", absl::StrJoin(str_vals, ", "), ")"); } +template +void PrintSkipped(int64_t* skipped, const ErrorGenerator& err_generator) { + // We send some fixed amount of skipped messages to the log. The remainder we + // squelch unless we're at vlog level 2. + constexpr int64_t kMaxMismatchesLoggedToErr = 1000; + + (*skipped)++; + if (*skipped < kMaxMismatchesLoggedToErr || VLOG_IS_ON(2)) { + LOG(WARNING) << err_generator(); + } else if (*skipped == kMaxMismatchesLoggedToErr) { + LOG(WARNING) << "Not printing any more skipped messages; pass " + "--vmodule=exhaustive_op_test=2 to see " + "all of them."; + } +} + template void PrintMismatch(int64_t* mismatches, const ErrorGenerator& err_generator) { // We send a few mismatches to gunit so they show up nicely in test logs. @@ -366,6 +382,7 @@ void PrintMismatch(int64_t* mismatches, const ErrorGenerator& err_generator) { "all of them."; } } + } // namespace template @@ -400,6 +417,7 @@ void ExhaustiveOpTestBase::ExpectNear( absl::Span result_arr = result_literal.data(); + int64_t skipped = 0; int64_t mismatches = 0; for (int64_t i = 0; i < result_arr.size(); ++i) { @@ -416,6 +434,16 @@ void ExhaustiveOpTestBase::ExpectNear( static_cast(CallOperation(evaluate_op, inputs_ref_ty)); ErrorSpec error_spec = CallErrorSpec(error_spec_gen, inputs); + if (error_spec.skip_comparison) { + PrintSkipped(&skipped, [&] { + return absl::StrFormat( + "skipping tolerance check for input %s due to " + "ErrorSpec::skip_comparison", + StringifyNum(inputs)); + }); + continue; + } + if (check_valid_range != nullptr && !check_valid_range(inputs, actual)) { PrintMismatch(&mismatches, [&] { return absl::StrFormat( diff --git a/xla/tests/exhaustive/exhaustive_op_test_utils.h b/xla/tests/exhaustive/exhaustive_op_test_utils.h index 223ae2704e6b68..c6aece69a1d027 100644 --- a/xla/tests/exhaustive/exhaustive_op_test_utils.h +++ b/xla/tests/exhaustive/exhaustive_op_test_utils.h @@ -140,6 +140,10 @@ struct ErrorSpec { // spec; this only covers the case when both `expected` and `actual` are // equal to 0. bool strict_signed_zeros = false; + // If true, this will skip comparing the output of the test to the expected + // value. This should be used only as a last resort, since it is effectively + // turning off the test for a specific input value set. + bool skip_comparison = false; }; // Representations of the reference function passed in by the user. @@ -617,6 +621,9 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { // Testing will ignore inputs for which known_incorrect_fn_ returns true. // The argument to the function is the raw bits for the data being test, // zero extended to 64 bits if the data type is less than 64 bits. + // + // DEPRECATED: Please see ErrorSpec::skip_comparison for an easier framework + // to skip nearness checks for certain unary or binary inputs. std::function known_incorrect_fn_; // If true, allows denormals to be flushed to non-sign-preserving 0. From 9bb18711c56c5538e49f4a38d44c06d8c397a7c1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 30 Jul 2024 17:27:31 -0700 Subject: [PATCH 297/376] Broke some internal tests Reverts e3bcd73cba58284504849d26ac21bd795ca49cd6 PiperOrigin-RevId: 657786386 --- xla/translate/hlo_to_mhlo/BUILD | 1 + .../hlo_to_mhlo/hlo_function_importer.cc | 55 ----------- .../hlo_to_mhlo/hlo_function_importer.h | 7 -- .../hlo_to_mhlo/hlo_module_importer.cc | 4 - .../hlo_to_mhlo/tests/module_attributes.hlo | 13 --- xla/translate/mhlo_to_hlo/BUILD | 1 - .../mhlo_to_hlo/attribute_exporter.cc | 95 ------------------- .../mhlo_to_hlo/attribute_exporter.h | 4 - xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc | 7 -- .../mhlo_to_hlo/tests/module_attributes.mlir | 42 -------- 10 files changed, 1 insertion(+), 228 deletions(-) diff --git a/xla/translate/hlo_to_mhlo/BUILD b/xla/translate/hlo_to_mhlo/BUILD index 043dbf0cc801d7..cb3d691eb839a0 100644 --- a/xla/translate/hlo_to_mhlo/BUILD +++ b/xla/translate/hlo_to_mhlo/BUILD @@ -83,6 +83,7 @@ cc_library( "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/mlir_hlo", + "//xla/service:hlo_module_config", "//xla/service:hlo_proto_cc", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", diff --git a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc index c641270aed10c8..98fcd5172728f6 100644 --- a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc +++ b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc @@ -56,7 +56,6 @@ limitations under the License. #include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" @@ -2501,60 +2500,6 @@ absl::Status HloFunctionImporter::ConvertShapeToMlirLayout( return Internal("Couldn't convert layout."); } -// std::string FrontendAttributesToString( -// const FrontendAttributes& frontend_attributes) { -// std::vector> sorted_attributes( -// frontend_attributes.map().begin(), frontend_attributes.map().end()); -// absl::c_sort(sorted_attributes); -// const auto formatter = [](std::string* out, -// const std::pair& item) -// { -// if (LexesAsJsonDict(item.second)) { -// absl::StrAppend(out, item.first, "=", item.second); -// } else { -// absl::StrAppend(out, item.first, "=\"", item.second, "\""); -// } -// }; -// return absl::StrFormat("{%s}", -// absl::StrJoin(sorted_attributes, ",", formatter)); -// } - -mlir::Attribute ConvertInputOutputAlias(const HloInputOutputAliasConfig& alias, - mlir::Builder* builder) { - llvm::SmallVector element_attrs; - alias.ForEachAlias([&](const ShapeIndex& output_index, - const HloInputOutputAliasConfig::Alias& alias) { - std::string kindToString; - switch (alias.kind) { - case HloInputOutputAliasConfig::AliasKind::kMayAlias: - kindToString = "may_alias"; - break; - case HloInputOutputAliasConfig::AliasKind::kMustAlias: - kindToString = "must_alias"; - break; - default: - kindToString = "undefined_alias"; - } - mlir::NamedAttribute alias_named_attributes[3] = { - builder->getNamedAttr( - "parameter_index", - builder->getDenseI64ArrayAttr(ArrayRef( - alias.parameter_index.begin(), alias.parameter_index.end()))), - builder->getNamedAttr("parameter_number", builder->getI64IntegerAttr( - alias.parameter_number)), - builder->getNamedAttr("kind", builder->getStringAttr(kindToString))}; - - mlir::NamedAttribute named_attributes[2] = { - builder->getNamedAttr("output_index", - builder->getDenseI64ArrayAttr(ArrayRef( - output_index.begin(), output_index.end()))), - builder->getNamedAttr( - "alias", builder->getDictionaryAttr(alias_named_attributes))}; - element_attrs.push_back(builder->getDictionaryAttr(named_attributes)); - }); - return builder->getArrayAttr(element_attrs); -} - mlir::Attribute ConvertSharding(const HloSharding& sharding, mlir::Builder* builder) { return builder->getStringAttr(sharding.ToString(/*include_metadata=*/true)); diff --git a/xla/translate/hlo_to_mhlo/hlo_function_importer.h b/xla/translate/hlo_to_mhlo/hlo_function_importer.h index 5c5a4e309bfbf6..cb3953990f4030 100644 --- a/xla/translate/hlo_to_mhlo/hlo_function_importer.h +++ b/xla/translate/hlo_to_mhlo/hlo_function_importer.h @@ -33,7 +33,6 @@ limitations under the License. #include "mlir/IR/Operation.h" #include "mlir/IR/ValueRange.h" #include "xla/comparison_util.h" -#include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/hlo.pb.h" @@ -298,12 +297,6 @@ class HloFunctionImporter { bool flatten_computation_args_result_; }; -// Returns a StringAttr that carries a prettyprinted representation of the -// given HLO C++ input_output_alias_config. -// Always succeeds and returns a non-empty attribute. -mlir::Attribute ConvertInputOutputAlias(const HloInputOutputAliasConfig& alias, - mlir::Builder* builder); - // Returns a StringAttr that carries a prettyprinted representation of the // given HLO C++ sharding. // Always succeeds and returns a non-empty attribute. diff --git a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc index 76037442d52099..1f2ea997c81e8a 100644 --- a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc +++ b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc @@ -122,10 +122,6 @@ absl::Status HloModuleImporter::Import(const HloModule& hlo_module) { ConvertSharding(hlo_module.spmd_output_sharding(), &builder_)); } - module->setAttr("mhlo.input_output_alias", - ConvertInputOutputAlias( - hlo_module.input_output_alias_config(), &builder_)); - if (hlo_module.has_spmd_parameters_shardings()) { llvm::SmallVector parameter_shardings; parameter_shardings.reserve(hlo_module.spmd_parameters_shardings().size()); diff --git a/xla/translate/hlo_to_mhlo/tests/module_attributes.hlo b/xla/translate/hlo_to_mhlo/tests/module_attributes.hlo index d3433dce372cbf..74eaaea5a0e8fe 100644 --- a/xla/translate/hlo_to_mhlo/tests/module_attributes.hlo +++ b/xla/translate/hlo_to_mhlo/tests/module_attributes.hlo @@ -5,18 +5,6 @@ # FLATTEN-CHECK-LABEL: module @main attributes { hlo_module { name: "main" - input_output_alias { - entries { - output_shape_index: 0 - parameter_number: 0 - kind: MAY_ALIAS - } - entries { - output_shape_index: 1 - parameter_number: 1 - kind: MAY_ALIAS - } - } entry_computation_name: "main.5" computations { name: "main.5" @@ -229,7 +217,6 @@ hlo_module { value: "attr_value" } } -# CHECK-SAME: mhlo.input_output_alias = [{alias = {kind = "may_alias", parameter_index = array, parameter_number = 0 : i64}, output_index = array}, {alias = {kind = "may_alias", parameter_index = array, parameter_number = 1 : i64}, output_index = array}] # CHECK-SAME: mhlo.is_dynamic = true is_dynamic: true # CHECK-SAME: mhlo.use_auto_spmd_partitioning = true diff --git a/xla/translate/mhlo_to_hlo/BUILD b/xla/translate/mhlo_to_hlo/BUILD index 3de8007804af4b..92b7265298f6e7 100644 --- a/xla/translate/mhlo_to_hlo/BUILD +++ b/xla/translate/mhlo_to_hlo/BUILD @@ -23,7 +23,6 @@ cc_library( "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", "//xla/mlir_hlo", "//xla/service:hlo_parser", "//xla/service:hlo_proto_cc", diff --git a/xla/translate/mhlo_to_hlo/attribute_exporter.cc b/xla/translate/mhlo_to_hlo/attribute_exporter.cc index 8d54de6a5b9322..a492861b28d831 100644 --- a/xla/translate/mhlo_to_hlo/attribute_exporter.cc +++ b/xla/translate/mhlo_to_hlo/attribute_exporter.cc @@ -185,99 +185,4 @@ std::optional ConvertSharding(llvm::StringRef sharding) { return std::nullopt; } -std::optional ConvertInputOutputAlias( - llvm::ArrayRef aliasing) { - if (aliasing.empty()) return std::nullopt; - - xla::HloInputOutputAliasProto input_output_alias_proto; - for (auto attr : aliasing) { - auto entry_attr = mlir::cast(attr); - auto alias_attr = mlir::cast(entry_attr.get("alias")); - mlir::ArrayRef output_index = - mlir::cast(entry_attr.get("output_index")) - .asArrayRef(); - mlir::ArrayRef parameter_index = - mlir::cast(alias_attr.get("parameter_index")) - .asArrayRef(); - HloInputOutputAliasProto::AliasEntryProto entry; - entry.mutable_output_shape_index()->Add(output_index.begin(), - output_index.end()); - entry.set_parameter_number( - mlir::cast(alias_attr.get("parameter_number")) - .getInt()); - entry.mutable_parameter_shape_index()->Add(parameter_index.begin(), - parameter_index.end()); - mlir::StringRef kind = - mlir::cast(alias_attr.get("kind")).getValue(); - if (kind == "may_alias") - entry.set_kind(xla::Kind::MAY_ALIAS); - else if (kind == "must_alias") - entry.set_kind(xla::Kind::MUST_ALIAS); - else - entry.set_kind(xla::Kind::UNDEFINED_ALIAS); - input_output_alias_proto.add_entries()->Swap(&entry); - } - return input_output_alias_proto; -} - -DotDimensionNumbers ConvertDotDimensionNumbers( - mlir::mhlo::DotDimensionNumbersAttr input) { - DotDimensionNumbers output; - - for (auto v : input.getLhsBatchingDimensions()) { - output.add_lhs_batch_dimensions(v); - } - - for (auto v : input.getRhsBatchingDimensions()) { - output.add_rhs_batch_dimensions(v); - } - - for (auto v : input.getLhsContractingDimensions()) { - output.add_lhs_contracting_dimensions(v); - } - - for (auto v : input.getRhsContractingDimensions()) { - output.add_rhs_contracting_dimensions(v); - } - - return output; -} - -DotDimensionNumbers ConvertDotDimensionNumbers( - absl::Span lhs_batch, absl::Span lhs_contract, - absl::Span rhs_batch, - absl::Span rhs_contract) { - DotDimensionNumbers output; - for (auto v : lhs_batch) { - output.add_lhs_batch_dimensions(v); - } - - for (auto v : rhs_batch) { - output.add_rhs_batch_dimensions(v); - } - - for (auto v : lhs_contract) { - output.add_lhs_contracting_dimensions(v); - } - - for (auto v : rhs_contract) { - output.add_rhs_contracting_dimensions(v); - } - - return output; -} - -absl::StatusOr> ConvertMlirArrayAttrToInt64Array( - const mlir::ArrayAttr& array) { - int rank = array.size(); - std::vector converted_array(rank); - for (int i = 0; i < rank; i++) { - mlir::IntegerAttr attr = mlir::dyn_cast(array[i]); - if (!attr) { - return Internal("Type Error: Expected layout integer attribute"); - } - converted_array[i] = attr.getInt(); - } - return converted_array; -} } // namespace xla diff --git a/xla/translate/mhlo_to_hlo/attribute_exporter.h b/xla/translate/mhlo_to_hlo/attribute_exporter.h index 49daefe6935650..e0e0dc9821d21e 100644 --- a/xla/translate/mhlo_to_hlo/attribute_exporter.h +++ b/xla/translate/mhlo_to_hlo/attribute_exporter.h @@ -20,7 +20,6 @@ limitations under the License. #include "absl/status/statusor.h" #include "mlir/IR/Attributes.h" -#include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/hlo.pb.h" #include "xla/shape_util.h" @@ -60,8 +59,5 @@ ConvertOutputOperandAliasing(mlir::ArrayAttr aliasArrayAttr); // Will fail if both attempts at parsing failed. std::optional ConvertSharding(mlir::StringRef sharding); -std::optional ConvertInputOutputAlias( - llvm::ArrayRef aliasing); - } // namespace xla #endif // XLA_TRANSLATE_MHLO_TO_HLO_ATTRIBUTE_EXPORTER_H_ diff --git a/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc index 90eb1a902127bc..623080e11fd60d 100644 --- a/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc +++ b/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc @@ -3736,13 +3736,6 @@ absl::Status ConvertMlirHloToHlo(mlir::ModuleOp module, *hlo_module.mutable_spmd_output_sharding() = *xla::ConvertSharding(spmd_output_sharding.getValue()); } - if (auto input_output_alias = - module->getAttrOfType("mhlo.input_output_alias")) { - if (std::optional input_output_alias_proto = - xla::ConvertInputOutputAlias(input_output_alias.getValue())) { - *hlo_module.mutable_input_output_alias() = *input_output_alias_proto; - } - } if (auto spmd_parameters_sharding = module->getAttrOfType( "mhlo.spmd_parameters_shardings")) { for (const auto& sharding : spmd_parameters_sharding.getValue()) { diff --git a/xla/translate/mhlo_to_hlo/tests/module_attributes.mlir b/xla/translate/mhlo_to_hlo/tests/module_attributes.mlir index 6ad08374e5d2e6..049456bb09e6f7 100644 --- a/xla/translate/mhlo_to_hlo/tests/module_attributes.mlir +++ b/xla/translate/mhlo_to_hlo/tests/module_attributes.mlir @@ -100,45 +100,3 @@ module @ModuleWithFrontendAttributes attributes { func.return %arg0 : tensor<1xf32> } } - - - -// ----- - -module attributes { -// CHECK: input_output_alias { -// CHECK-NEXT: entries { -// CHECK-NEXT: output_shape_index: 0 -// CHECK-NEXT: kind: MAY_ALIAS -// CHECK-NEXT: } -// CHECK-NEXT: entries { -// CHECK-NEXT: output_shape_index: 1 -// CHECK-NEXT: parameter_number: 1 -// CHECK-NEXT: kind: MAY_ALIAS -// CHECK-NEXT: } -// CHECK-NEXT: } - mhlo.input_output_alias = [ - { - alias = - { - kind = "may_alias", - parameter_index = array, - parameter_number = 0 : i64 - }, - output_index = array - }, - { - alias = - { - kind = "may_alias", - parameter_index = array, - parameter_number = 1 : i64 - }, - output_index = array - } -] -} { - func.func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32> ) -> (tensor<1xf32>, tensor<1xf32>) { - func.return %arg0, %arg1: tensor<1xf32>, tensor<1xf32> - } -} \ No newline at end of file From 3c8baf1e2edc6fd29304569883d9b9338a6afde4 Mon Sep 17 00:00:00 2001 From: Amit Sabne Date: Tue, 30 Jul 2024 17:34:42 -0700 Subject: [PATCH 298/376] Reduce search depth for dots (like we did for convs) PiperOrigin-RevId: 657788852 --- xla/service/space_to_batch_converter.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/xla/service/space_to_batch_converter.cc b/xla/service/space_to_batch_converter.cc index 751b6d11dc979c..e07ee54b64fccf 100644 --- a/xla/service/space_to_batch_converter.cc +++ b/xla/service/space_to_batch_converter.cc @@ -3746,8 +3746,9 @@ bool ConvolutionVisitor::DoesConvolutionFeedUnpropagatableOp( } int64_t depth_to_use = depth; - // When we see a convolution, we reduce the depth to look further for. - if (user->opcode() == HloOpcode::kConvolution) { + // When we see a convolution/dot, we reduce the depth to look further for. + if (user->opcode() == HloOpcode::kConvolution || + user->opcode() == HloOpcode::kDot) { depth_to_use--; } From bd98c8ed91908e9bb44e7068de1a15378dbceb1d Mon Sep 17 00:00:00 2001 From: Fangrui Song Date: Tue, 30 Jul 2024 17:52:06 -0700 Subject: [PATCH 299/376] Integrate LLVM at llvm/llvm-project@51681409aeb0 Updates LLVM usage to match [51681409aeb0](https://github.com/llvm/llvm-project/commit/51681409aeb0) PiperOrigin-RevId: 657793017 --- third_party/llvm/generated.patch | 18 ------------------ third_party/llvm/workspace.bzl | 4 ++-- third_party/shardy/workspace.bzl | 4 ++-- .../tsl/third_party/llvm/generated.patch | 18 ------------------ third_party/tsl/third_party/llvm/workspace.bzl | 4 ++-- 5 files changed, 6 insertions(+), 42 deletions(-) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index dc682861fbb165..509398da979e83 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,19 +1 @@ Auto generated patch. Do not edit or delete it, even if empty. -diff -ruN --strip-trailing-cr a/clang/test/CodeGenCoroutines/coro-elide-thinlto.cpp b/clang/test/CodeGenCoroutines/coro-elide-thinlto.cpp ---- a/clang/test/CodeGenCoroutines/coro-elide-thinlto.cpp -+++ b/clang/test/CodeGenCoroutines/coro-elide-thinlto.cpp -@@ -3,10 +3,10 @@ - // This test is adapted from coro-elide.cpp and splits functions into two files. - // - // RUN: split-file %s %t --// RUN: %clang --target=x86_64-linux -std=c++20 -O2 -flto=thin -I %S -c %t/coro-elide-callee.cpp -o coro-elide-callee.bc --// RUN: %clang --target=x86_64-linux -std=c++20 -O2 -flto=thin -I %S -c %t/coro-elide-caller.cpp -o coro-elide-caller.bc --// RUN: llvm-lto --thinlto coro-elide-callee.bc coro-elide-caller.bc -o summary --// RUN: %clang_cc1 -O2 -x ir coro-elide-caller.bc -fthinlto-index=summary.thinlto.bc -emit-llvm -o - | FileCheck %s -+// RUN: %clang --target=x86_64-linux -std=c++20 -O2 -flto=thin -I %S -c %t/coro-elide-callee.cpp -o %t/coro-elide-callee.bc -+// RUN: %clang --target=x86_64-linux -std=c++20 -O2 -flto=thin -I %S -c %t/coro-elide-caller.cpp -o %t/coro-elide-caller.bc -+// RUN: llvm-lto --thinlto %t/coro-elide-callee.bc %t/coro-elide-caller.bc -o %t/summary -+// RUN: %clang_cc1 -O2 -x ir %t/coro-elide-caller.bc -fthinlto-index=%t/summary.thinlto.bc -emit-llvm -o - | FileCheck %s - - //--- coro-elide-task.h - #pragma once diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index f5dd4fdd0bd288..76a13a425d79ed 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "0e6f64cd5e5a06bd78542d5541a762154546ced3" - LLVM_SHA256 = "d3b426b13175ac771a05a0908e11391be46913fc1ab7c459ae906b07b77474c0" + LLVM_COMMIT = "51681409aeb081c8dfe241e0d8e8c71f8bf0a4f4" + LLVM_SHA256 = "347cc44fc5bba17b2a6ac26a253803434790a2996b77e8b6fbbeee9b8a367ec8" tf_http_archive( name = name, diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index 129c3f2e9bf708..7b492bca9f78aa 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "0458df554c1d569c034c10986069ec8fc1d58828" - SHARDY_SHA256 = "20b84eec31de9728b91901bf57aadf9faa9942a8b0383bd4bde9d588b51beeb1" + SHARDY_COMMIT = "e3462e37ae3fbf0d15836f2c657294028e1b9075" + SHARDY_SHA256 = "e9a98a5257bedbcf80a14ad1014e6645c24e3ba9352bff39b89cad3366188afb" tf_http_archive( name = "shardy", diff --git a/third_party/tsl/third_party/llvm/generated.patch b/third_party/tsl/third_party/llvm/generated.patch index dc682861fbb165..509398da979e83 100644 --- a/third_party/tsl/third_party/llvm/generated.patch +++ b/third_party/tsl/third_party/llvm/generated.patch @@ -1,19 +1 @@ Auto generated patch. Do not edit or delete it, even if empty. -diff -ruN --strip-trailing-cr a/clang/test/CodeGenCoroutines/coro-elide-thinlto.cpp b/clang/test/CodeGenCoroutines/coro-elide-thinlto.cpp ---- a/clang/test/CodeGenCoroutines/coro-elide-thinlto.cpp -+++ b/clang/test/CodeGenCoroutines/coro-elide-thinlto.cpp -@@ -3,10 +3,10 @@ - // This test is adapted from coro-elide.cpp and splits functions into two files. - // - // RUN: split-file %s %t --// RUN: %clang --target=x86_64-linux -std=c++20 -O2 -flto=thin -I %S -c %t/coro-elide-callee.cpp -o coro-elide-callee.bc --// RUN: %clang --target=x86_64-linux -std=c++20 -O2 -flto=thin -I %S -c %t/coro-elide-caller.cpp -o coro-elide-caller.bc --// RUN: llvm-lto --thinlto coro-elide-callee.bc coro-elide-caller.bc -o summary --// RUN: %clang_cc1 -O2 -x ir coro-elide-caller.bc -fthinlto-index=summary.thinlto.bc -emit-llvm -o - | FileCheck %s -+// RUN: %clang --target=x86_64-linux -std=c++20 -O2 -flto=thin -I %S -c %t/coro-elide-callee.cpp -o %t/coro-elide-callee.bc -+// RUN: %clang --target=x86_64-linux -std=c++20 -O2 -flto=thin -I %S -c %t/coro-elide-caller.cpp -o %t/coro-elide-caller.bc -+// RUN: llvm-lto --thinlto %t/coro-elide-callee.bc %t/coro-elide-caller.bc -o %t/summary -+// RUN: %clang_cc1 -O2 -x ir %t/coro-elide-caller.bc -fthinlto-index=%t/summary.thinlto.bc -emit-llvm -o - | FileCheck %s - - //--- coro-elide-task.h - #pragma once diff --git a/third_party/tsl/third_party/llvm/workspace.bzl b/third_party/tsl/third_party/llvm/workspace.bzl index f5dd4fdd0bd288..76a13a425d79ed 100644 --- a/third_party/tsl/third_party/llvm/workspace.bzl +++ b/third_party/tsl/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "0e6f64cd5e5a06bd78542d5541a762154546ced3" - LLVM_SHA256 = "d3b426b13175ac771a05a0908e11391be46913fc1ab7c459ae906b07b77474c0" + LLVM_COMMIT = "51681409aeb081c8dfe241e0d8e8c71f8bf0a4f4" + LLVM_SHA256 = "347cc44fc5bba17b2a6ac26a253803434790a2996b77e8b6fbbeee9b8a367ec8" tf_http_archive( name = name, From a331af0fda577e14b13c3620293529bb999c01c9 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 30 Jul 2024 18:31:50 -0700 Subject: [PATCH 300/376] [xla:cpu] Pre-initialize KernelThunk arguments in constructor PiperOrigin-RevId: 657803003 --- xla/service/cpu/runtime/BUILD | 2 +- xla/service/cpu/runtime/kernel_thunk.cc | 33 ++++++++++++++++--------- xla/service/cpu/runtime/kernel_thunk.h | 5 ++++ 3 files changed, 28 insertions(+), 12 deletions(-) diff --git a/xla/service/cpu/runtime/BUILD b/xla/service/cpu/runtime/BUILD index b926b8539f7aaa..0592314e6bb20b 100644 --- a/xla/service/cpu/runtime/BUILD +++ b/xla/service/cpu/runtime/BUILD @@ -777,6 +777,7 @@ cc_library( "//xla/stream_executor/host:host_kernel_c_api", "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/numeric:bits", "@com_google_absl//absl/status", @@ -785,7 +786,6 @@ cc_library( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", - "@llvm-project//llvm:Support", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:numbers", diff --git a/xla/service/cpu/runtime/kernel_thunk.cc b/xla/service/cpu/runtime/kernel_thunk.cc index 7ac73162df4d9b..3455f245a3338d 100644 --- a/xla/service/cpu/runtime/kernel_thunk.cc +++ b/xla/service/cpu/runtime/kernel_thunk.cc @@ -15,6 +15,8 @@ limitations under the License. #include "xla/service/cpu/runtime/kernel_thunk.h" +#include + #define EIGEN_USE_THREADS #include @@ -25,6 +27,7 @@ limitations under the License. #include #include "absl/base/optimization.h" +#include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/numeric/bits.h" #include "absl/status/status.h" @@ -32,7 +35,6 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "unsupported/Eigen/CXX11/Tensor" -#include "llvm/ADT/SmallVector.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" #include "xla/service/cpu/runtime/buffer_allocations.h" @@ -78,7 +80,19 @@ KernelThunk::KernelThunk( thread_dim_(thread_dim), min_alignment_(min_alignment), call_once_(thread_dim_ == se::ThreadDim()), - kernel_ptr_(nullptr) {} + kernel_ptr_(nullptr) { + // Initialize kernel arguments with null pointers and known buffer sizes. + // We'll use them as a template to resolve buffer addresses at run time. + kernel_args_.reserve(num_kernel_args_); + for (const BufferAllocation::Slice& buffer : arguments_buffers_) { + kernel_args_.emplace_back( + SE_HOST_KernelArg{nullptr, static_cast(buffer.size())}); + } + for (const BufferAllocation::Slice& buffer : results_buffers_) { + kernel_args_.emplace_back( + SE_HOST_KernelArg{nullptr, static_cast(buffer.size())}); + } +} tsl::AsyncValueRef KernelThunk::Execute( const ExecuteParams& params) { @@ -90,31 +104,28 @@ tsl::AsyncValueRef KernelThunk::Execute( kernel_name_, arguments_buffers_.size(), results_buffers_.size(), thread_dim_.ToString()); - // We use `llvm::SmallVector` instead of `absl::InlinedVector` because - // it allows to resize a vector without zero-initializing storage. - llvm::SmallVector kernel_args; - kernel_args.resize_for_overwrite(num_kernel_args_); - + absl::InlinedVector kernel_args = kernel_args_; SE_HOST_KernelArg* kernel_args_ptr = kernel_args.data(); + const BufferAllocations* allocations = params.buffer_allocations; for (BufferAllocation::Slice& buffer : arguments_buffers_) { if constexpr (ShouldCheckBufferSlices()) { TF_ASSIGN_OR_RETURN(auto mem, allocations->GetDeviceAddress(buffer)); - *kernel_args_ptr++ = SE_HOST_KernelArg{mem.opaque(), mem.size()}; + kernel_args_ptr++->data = mem.opaque(); } else { auto mem = allocations->GetDeviceAddressUnchecked(buffer); - *kernel_args_ptr++ = SE_HOST_KernelArg{mem.opaque(), mem.size()}; + kernel_args_ptr++->data = mem.opaque(); } } for (BufferAllocation::Slice& buffer : results_buffers_) { if constexpr (ShouldCheckBufferSlices()) { TF_ASSIGN_OR_RETURN(auto mem, allocations->GetDeviceAddress(buffer)); - *kernel_args_ptr++ = SE_HOST_KernelArg{mem.opaque(), mem.size()}; + kernel_args_ptr++->data = mem.opaque(); } else { auto mem = allocations->GetDeviceAddressUnchecked(buffer); - *kernel_args_ptr++ = SE_HOST_KernelArg{mem.opaque(), mem.size()}; + kernel_args_ptr++->data = mem.opaque(); } } diff --git a/xla/service/cpu/runtime/kernel_thunk.h b/xla/service/cpu/runtime/kernel_thunk.h index 80bf16a4573916..871176ba73ec5b 100644 --- a/xla/service/cpu/runtime/kernel_thunk.h +++ b/xla/service/cpu/runtime/kernel_thunk.h @@ -25,6 +25,7 @@ limitations under the License. #include #include "absl/base/thread_annotations.h" +#include "absl/container/inlined_vector.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" @@ -84,6 +85,10 @@ class KernelThunk final : public Thunk { absl::Mutex mutex_; std::optional kernel_ ABSL_GUARDED_BY(mutex_); std::atomic kernel_ptr_; // pointer to `kernel_` + + // Pre-initialized kernel arguments that are updated with memory addresses + // before the kernel launch. + absl::InlinedVector kernel_args_; }; } // namespace xla::cpu From a6e60a51fcc485c179a7448c431b432914caae7e Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 30 Jul 2024 19:27:10 -0700 Subject: [PATCH 301/376] [xla:cpu] Implement Thunk::OkExecuteEvent in header file MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit name old cpu/op new cpu/op delta BM_SelectAndScatterF32/128/process_time 395µs ± 2% 379µs ± 2% -3.89% BM_SelectAndScatterF32/256/process_time 1.65ms ± 5% 1.55ms ± 1% -5.84% BM_SelectAndScatterF32/512/process_time 7.42ms ± 4% 7.14ms ± 5% -3.85% PiperOrigin-RevId: 657815613 --- xla/service/cpu/runtime/kernel_thunk.cc | 2 +- xla/service/cpu/runtime/thunk.cc | 6 +++--- xla/service/cpu/runtime/thunk.h | 6 +++++- xla/tsl/concurrency/async_value_ref.h | 4 +++- 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/xla/service/cpu/runtime/kernel_thunk.cc b/xla/service/cpu/runtime/kernel_thunk.cc index 3455f245a3338d..88847041c3b51f 100644 --- a/xla/service/cpu/runtime/kernel_thunk.cc +++ b/xla/service/cpu/runtime/kernel_thunk.cc @@ -140,7 +140,7 @@ tsl::AsyncValueRef KernelThunk::Execute( // TODO(ezhulenev): Kernel ptr should be loaded as a part of Thunk // initialization stage. - se::host::HostKernel* kernel = kernel_ptr_.load(); + se::host::HostKernel* kernel = kernel_ptr_.load(std::memory_order_relaxed); // Because thunks are owned by a parent CpuExecutable, we can safely assume // that kernel pointer will not change after we find it the first time. diff --git a/xla/service/cpu/runtime/thunk.cc b/xla/service/cpu/runtime/thunk.cc index 455c940e264f3b..5438e60b33d844 100644 --- a/xla/service/cpu/runtime/thunk.cc +++ b/xla/service/cpu/runtime/thunk.cc @@ -150,13 +150,13 @@ Thunk::CustomCallExecuteParams::CustomCallExecuteParams( allocator(allocator), ffi_execution_context(ffi_execution_context) {} -tsl::AsyncValueRef Thunk::OkExecuteEvent() { - static tsl::AsyncValueOwningRef* event = [] { +const tsl::AsyncValueOwningRef* Thunk::OkEvent() { + static tsl::AsyncValueOwningRef* owner = [] { auto* storage = new tsl::internal::AsyncValueStorage(); return new tsl::AsyncValueOwningRef( tsl::MakeAvailableAsyncValueRef(*storage)); }(); - return event->AsRef(); + return owner; } Thunk::ExecuteState::ExecuteState(int64_t num_tasks) diff --git a/xla/service/cpu/runtime/thunk.h b/xla/service/cpu/runtime/thunk.h index 210d19937b2173..67ed37c43e3917 100644 --- a/xla/service/cpu/runtime/thunk.h +++ b/xla/service/cpu/runtime/thunk.h @@ -288,7 +288,9 @@ class Thunk { // Returns non-reference-counted async value ref for thunks executed in the // caller thread to avoid reference counting overhead. - static tsl::AsyncValueRef OkExecuteEvent(); + static tsl::AsyncValueRef OkExecuteEvent() { + return OkEvent()->AsRef(); + } // Thunk execution must be asynchronous and never block the caller thread, // especially waiting for work submitted into the `intra_op_threadpool`, @@ -329,6 +331,8 @@ class Thunk { } private: + static const tsl::AsyncValueOwningRef* OkEvent(); + Kind kind_; Info info_; }; diff --git a/xla/tsl/concurrency/async_value_ref.h b/xla/tsl/concurrency/async_value_ref.h index ca1f4133dad564..625c9085323b5d 100644 --- a/xla/tsl/concurrency/async_value_ref.h +++ b/xla/tsl/concurrency/async_value_ref.h @@ -423,7 +423,9 @@ class AsyncValuePtr { T& operator*() const { return get(); } explicit operator bool() const { return value_ != nullptr; } - bool operator!=(std::nullptr_t) const { return value_ != nullptr; } + bool operator==(const AsyncValuePtr& p) const { return value_ == p.value_; } + bool operator!=(const AsyncValuePtr& p) const { return value_ != p.value_; } + AsyncValuePtr& operator=(std::nullptr_t) { value_ = nullptr; return *this; From 55c0e59e58205bc197fc9909429274925f7d8e3a Mon Sep 17 00:00:00 2001 From: Ionel Gog Date: Tue, 30 Jul 2024 19:36:58 -0700 Subject: [PATCH 302/376] [IFRT] Improve ifrt-verify-donation pass to catch more cases of use after donation. Before this change the pass would not emit an error when an array was donated, and later used as not donated. Moreover, the change improves error logs by including op callee name and op locations. PiperOrigin-RevId: 657817957 --- .../ifrt/ir/tests/ifrt_verify_donation.mlir | 90 +++++++++++++++++-- .../transforms/ifrt_verify_donation_pass.cc | 53 +++++++---- 2 files changed, 119 insertions(+), 24 deletions(-) diff --git a/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir b/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir index 3f6050206495cf..edaff924361072 100644 --- a/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir +++ b/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir @@ -46,7 +46,7 @@ module @donate_to_two_calls_error { attributes {ifrt.function} { %0, %ctrl_0 = ifrt.Call @identity(%arg0) on devices [0,1] {io_aliases=[array]} : (!array) -> !array - // expected-error @+1 {{'ifrt.Call' op input #0 already donated.}} + // expected-error @+1 {{'ifrt.Call' op input #0 of @identity was already donated}} %1, %ctrl_1 = ifrt.Call @identity(%arg0) on devices [0,1] {io_aliases=[array]} : (!array) -> !array return %0, %1 : !array, !array @@ -78,7 +78,7 @@ module @program_arg_not_donated_error { module @arg_both_donated_and_not_donated_error { func.func @main(%arg0: !array0 {ifrt.donated}) -> !array0 attributes {ifrt.function} { - // expected-error @+1 {{'ifrt.Call' op input #0 is both donated and not donated.}} + // expected-error @+1 {{'ifrt.Call' op input #0 of @add_two_args was already donated}} %0, %ctrl_0 = ifrt.Call @add_two_args(%arg0, %arg0) on devices [0,1] {io_aliases=[array]} : (!array0, !array0) -> !array0 return %0 : !array0 @@ -101,7 +101,23 @@ module @donate_to_two_reshards_error { func.func @main(%arg0: !array0 {ifrt.donated}) -> (!array1, !array1) attributes {ifrt.function} { %0, %ctrl_0 = ifrt.Reshard(%arg0) {donated=true} : (!array0) -> !array1 - // expected-error @+1 {{'ifrt.Reshard' op input #0 already donated.}} + // expected-error @+1 {{'ifrt.Reshard' op input #0 of op}} + %1, %ctrl_1 = ifrt.Reshard(%arg0) {donated=true} : (!array0) -> !array1 + return %0, %1 : !array1, !array1 + } +} + +// ----- + +!array0 = !ifrt.array, + #ifrt.sharding_param<2 to [0] on 2>, [0, 1]> +!array1 = !ifrt.array, + #ifrt.sharding_param<2 to [0] on 2>, [2, 3]> +module @donate_to_two_reshards_error { + func.func @main(%arg0: !array0 {ifrt.donated}) -> (!array1, !array1) + attributes {ifrt.function} { + %0, %ctrl_0 = ifrt.Reshard(%arg0) {donated=true} : (!array0) -> !array1 + // expected-error @+1 {{'ifrt.Reshard' op input #0 of op}} %1, %ctrl_1 = ifrt.Reshard(%arg0) {donated=true} : (!array0) -> !array1 return %0, %1 : !array1, !array1 } @@ -118,7 +134,7 @@ module @donate_to_reshard_and_call_error { attributes {ifrt.function} { %0, %ctrl_0 = ifrt.Call @identity(%arg0) on devices [0,1] {io_aliases=[array]} : (!array0) -> !array0 - // expected-error @+1 {{'ifrt.Reshard' op input #0 already donated.}} + // expected-error @+1 {{'ifrt.Reshard' op input #0 of op}} %1, %ctrl_1 = ifrt.Reshard(%arg0) {donated=true} : (!array0) -> !array1 return %0, %1 : !array0, !array1 } @@ -138,7 +154,7 @@ module @donate_to_two_copy_arrays_error { func.func @main(%arg0: !array0 {ifrt.donated}) -> (!array1, !array1) attributes {ifrt.function} { %0, %ctrl_0 = ifrt.CopyArrays(%arg0) {donated=true} : (!array0) -> !array1 - // expected-error @+1 {{'ifrt.CopyArrays' op input #0 already donated.}} + // expected-error @+1 {{'ifrt.CopyArrays' op input #0 of op}} %1, %ctrl_1 = ifrt.CopyArrays(%arg0) {donated=true} : (!array0) -> !array1 return %0, %1 : !array1, !array1 } @@ -169,7 +185,7 @@ module @donate_to_reshard_and_call_error { attributes {ifrt.function} { %0, %ctrl_0 = ifrt.Call @identity(%arg0) on devices [0,1] {io_aliases=[array]} : (!array) -> !array - // expected-error @+1 {{'ifrt.RemapArrays' op input #1 already donated.}} + // expected-error @+1 {{'ifrt.RemapArrays' op input #1 of op}} %1 = ifrt.RemapArrays(%0, %arg0) mappings=[#ifrt.array_mapping<0, 0, [#ifrt.mapping<[0:1:1] to [0:1:1]>]>, #ifrt.array_mapping<1, 0, [#ifrt.mapping<[0:1:1] to [1:2:1]>]>] @@ -181,3 +197,65 @@ module @donate_to_reshard_and_call_error { return %arg0 : tensor<2xi32> } } + +// ----- + +!array = !ifrt.array, #ifrt.sharding_param<2 to [0] on 2>, [0, 1]> +module @call_after_donation_error { + func.func @main(%arg0: !array {ifrt.donated}) -> (!array, !array) + attributes {ifrt.function} { + %0, %ctrl_0 = ifrt.Call @identity(%arg0) on devices [0,1] + {io_aliases=[array]} : (!array) -> !array + // expected-error @+1 {{'ifrt.Call' op input #0 of @identity was already donated}} + %1, %ctrl_1 = ifrt.Call @identity(%arg0) on devices [0,1] + : (!array) -> !array + return %0, %1 : !array, !array + } + + func.func private @identity(%arg0: tensor<2xi32>) -> tensor<2xi32> { + return %arg0 : tensor<2xi32> + } +} + +// ----- + +!array0 = !ifrt.array, + #ifrt.sharding_param<2 to [0] on 2>, [0, 1]> +!array1 = !ifrt.array, + #ifrt.sharding_param<2 to [0] on 2>, [2, 3]> +module @reshard_with_already_donated_array_error { + func.func @main(%arg0: !array0 {ifrt.donated}) -> (!array0, !array1) + attributes {ifrt.function} { + %0, %ctrl_0 = ifrt.Call @identity(%arg0) on devices [0,1] + {io_aliases=[array]} : (!array0) -> !array0 + // expected-error @+1 {{'ifrt.Reshard' op input #0 of op}} + %1, %ctrl_1 = ifrt.Reshard(%arg0) : (!array0) -> !array1 + return %0, %1 : !array0, !array1 + } + + func.func private @identity(%arg0: tensor<2xi32>) -> tensor<2xi32> { + return %arg0 : tensor<2xi32> + } +} + + +// ----- + +!array0 = !ifrt.array, + #ifrt.sharding_param<2 to [0] on 2>, [0, 1]> +!array1 = !ifrt.array, + #ifrt.sharding_param<2 to [0] on 2>, [2, 3]> +module @copy_arrays_with_already_donated_array_error { + func.func @main(%arg0: !array0 {ifrt.donated}) -> (!array0, !array1) + attributes {ifrt.function} { + %0, %ctrl_0 = ifrt.Call @identity(%arg0) on devices [0,1] + {io_aliases=[array]} : (!array0) -> !array0 + // expected-error @+1 {{'ifrt.CopyArrays' op input #0 of op}} + %1, %ctrl_1 = ifrt.CopyArrays(%arg0) : (!array0) -> !array1 + return %0, %1 : !array0, !array1 + } + + func.func private @identity(%arg0: tensor<2xi32>) -> tensor<2xi32> { + return %arg0 : tensor<2xi32> + } +} diff --git a/xla/python/ifrt/ir/transforms/ifrt_verify_donation_pass.cc b/xla/python/ifrt/ir/transforms/ifrt_verify_donation_pass.cc index be11610d826884..bdfcf787d61d87 100644 --- a/xla/python/ifrt/ir/transforms/ifrt_verify_donation_pass.cc +++ b/xla/python/ifrt/ir/transforms/ifrt_verify_donation_pass.cc @@ -20,6 +20,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" #include "mlir/IR/Value.h" #include "mlir/IR/Visitors.h" #include "mlir/Pass/Pass.h" @@ -63,7 +64,7 @@ class IfrtVerifyDonationPass void IfrtVerifyDonationPass::runOnOperation() { mlir::ModuleOp module_op = getOperation(); - llvm::DenseSet donated_values; + llvm::DenseMap donated_value_to_op; mlir::WalkResult result = module_op.walk([&](mlir::Operation* op) -> mlir::WalkResult { auto result = @@ -78,44 +79,60 @@ void IfrtVerifyDonationPass::runOnOperation() { io_alias.asArrayRef(); donated_input_idxs.insert(io_alias_as_array[0]); auto donated_value = op.getInputs()[io_alias_as_array[0]]; - if (!donated_values.insert(donated_value).second) { + auto donated_it = + donated_value_to_op.try_emplace(donated_value, op); + if (!donated_it.second) { op.emitOpError() << "input #" << io_alias_as_array[0] - << " already donated."; + << " of " << op.getCalleeAttr() + << " was already donated to the op at " + << donated_it.first->second->getLoc(); return mlir::failure(); } - if (mlir::failed( VerifyIfInputAndDonated(op, donated_value))) { return mlir::failure(); } } - // Verify that an input is not both donated and not donated. + // Verify non-donated inputs after donated inputs have been + // added to also catch instances such as + // `ifrt.Call(%arg0 {ifrt.donated}, %arg0})`. for (const auto [idx, input] : llvm::enumerate(op.getInputs())) { - if (donated_values.contains(input) && - !donated_input_idxs.contains(idx)) { - op.emitOpError() << "input #" << idx - << " is both donated and not donated."; - return mlir::failure(); + if (!donated_input_idxs.contains(idx)) { + auto donated_it = donated_value_to_op.find(input); + if (donated_it != donated_value_to_op.end()) { + op.emitOpError() + << "input #" << idx << " of " << op.getCalleeAttr() + << " was already donated to the op at " + << donated_it->second->getLoc(); + return mlir::failure(); + } } } return mlir::success(); }) .Case([&](auto& op) { + // Verify that no inputs have already been donated. + for (const auto [idx, input] : llvm::enumerate(op.getInputs())) { + auto donated_it = donated_value_to_op.find(input); + if (donated_it != donated_value_to_op.end()) { + op.emitOpError() + << "input #" << idx << " of op at " << op.getLoc() + << " was already donated to the op at " + << donated_it->second->getLoc(); + return mlir::failure(); + } + } if (op.getDonated()) { - for (const auto [idx, input] : - llvm::enumerate(op.getInputs())) { - if (donated_values.contains(input)) { - op.emitOpError() << "input #" << idx << " already donated."; - return mlir::failure(); - } + // Add the donated inputs to the map and verify that all the + // donated inputs are also donated to the main func. + for (const auto input : op.getInputs()) { + donated_value_to_op.try_emplace(input, op); if (mlir::failed(VerifyIfInputAndDonated(op, input))) { return mlir::failure(); } } - donated_values.insert(op.getInputs().begin(), - op.getInputs().end()); } return mlir::success(); }) From 738da13cf13ebf1e5a87b6fcd9fa62db9bf61c0e Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 30 Jul 2024 21:25:34 -0700 Subject: [PATCH 303/376] [xla:cpu] Use fast IsOkExecuteEvent check in ThunkExecutor MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit name old cpu/op new cpu/op delta BM_SelectAndScatterF32/128/process_time 376µs ± 1% 374µs ± 2% ~ BM_SelectAndScatterF32/256/process_time 1.56ms ± 3% 1.53ms ± 3% -1.58% BM_SelectAndScatterF32/512/process_time 7.07ms ± 4% 6.98ms ± 1% -1.30% PiperOrigin-RevId: 657849123 --- xla/service/cpu/runtime/thunk.h | 8 ++++++++ xla/service/cpu/runtime/thunk_executor.cc | 10 ++++++++++ 2 files changed, 18 insertions(+) diff --git a/xla/service/cpu/runtime/thunk.h b/xla/service/cpu/runtime/thunk.h index 67ed37c43e3917..5bf8cfb8baf01d 100644 --- a/xla/service/cpu/runtime/thunk.h +++ b/xla/service/cpu/runtime/thunk.h @@ -292,6 +292,14 @@ class Thunk { return OkEvent()->AsRef(); } + static bool IsOkExecuteEvent(tsl::AsyncValuePtr event) { + return event == OkEvent()->AsPtr(); + } + + static bool IsOkExecuteEvent(const tsl::AsyncValueRef& event) { + return IsOkExecuteEvent(event.AsPtr()); + } + // Thunk execution must be asynchronous and never block the caller thread, // especially waiting for work submitted into the `intra_op_threadpool`, // because thunks themselves are executed on the same thread pool. diff --git a/xla/service/cpu/runtime/thunk_executor.cc b/xla/service/cpu/runtime/thunk_executor.cc index 4173f687c13eb5..c3e8971e1a722d 100644 --- a/xla/service/cpu/runtime/thunk_executor.cc +++ b/xla/service/cpu/runtime/thunk_executor.cc @@ -168,6 +168,11 @@ ThunkExecutor::ExecuteSequential(const Thunk::ExecuteParams& params) { Thunk& thunk = *thunk_sequence_[i]; auto execute_event = thunk.Execute(params); + // Fast path for thunks executed inline and returned OkExecuteEvent. + if (ABSL_PREDICT_TRUE(Thunk::IsOkExecuteEvent(execute_event))) { + continue; + } + // If thunk execution is not completed yet, attach a continuation to // resume sequential execution starting from the next thunk. if (ABSL_PREDICT_FALSE(!execute_event.IsAvailable())) { @@ -200,6 +205,11 @@ void ThunkExecutor::ResumeExecuteSequential( Thunk& thunk = *thunk_sequence_[i]; auto execute_event = thunk.Execute(params); + // Fast path for thunks executed inline and returned OkExecuteEvent. + if (ABSL_PREDICT_TRUE(Thunk::IsOkExecuteEvent(execute_event))) { + continue; + } + // If thunk execution is not completed yet, attach a continuation to // resume sequential execution starting from the next thunk. if (ABSL_PREDICT_FALSE(!execute_event.IsAvailable())) { From 40f1d3ccb288ab91650c35b3ddadf0d4ca527d7b Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 30 Jul 2024 21:39:57 -0700 Subject: [PATCH 304/376] [xla:cpu] Update ThunkExecutor to sequential execution mode if all thunks use small buffers PiperOrigin-RevId: 657852861 --- xla/service/cpu/runtime/thunk_executor.cc | 22 +++++++-- xla/service/cpu/runtime/thunk_executor.h | 20 +++++++- .../cpu/runtime/thunk_executor_test.cc | 46 +++++++++++++------ 3 files changed, 66 insertions(+), 22 deletions(-) diff --git a/xla/service/cpu/runtime/thunk_executor.cc b/xla/service/cpu/runtime/thunk_executor.cc index c3e8971e1a722d..26c084e7e8c5e4 100644 --- a/xla/service/cpu/runtime/thunk_executor.cc +++ b/xla/service/cpu/runtime/thunk_executor.cc @@ -41,8 +41,10 @@ limitations under the License. namespace xla::cpu { ThunkExecutor::ThunkExecutor(ThunkSequence thunk_sequence, - std::vector nodes_defs) + std::vector nodes_defs, + const ThunkExecutor::Options& options) : thunk_sequence_(std::move(thunk_sequence)), + options_(options), nodes_defs_(std::move(nodes_defs)), is_sequential_(true) { for (NodeId i = 0; i < nodes_defs_.size(); ++i) { @@ -66,11 +68,21 @@ ThunkExecutor::ThunkExecutor(ThunkSequence thunk_sequence, is_sequential_ &= (absl::c_count(nodes_defs_[i].in_edges, i - 1) != 0); } + // Maybe mark execution as sequential if all thunks use small buffers. + auto uses_small_buffers = [&](const std::unique_ptr& thunk) { + return absl::c_all_of(thunk->buffer_uses(), [&](const BufferUse& use) { + return use.slice().size() <= options.execute_sequential_buffer_threshold; + }); + }; + + bool small_buffers = absl::c_all_of(thunk_sequence_, uses_small_buffers); + is_sequential_ |= small_buffers; + VLOG(2) << absl::StreamFormat( "Constructed ThunkExecutor with %d nodes: #source_nodes=%d " - "#sink_nodes=%d, #erased_edges=%d, is_sequential=%v", + "#sink_nodes=%d, #erased_edges=%d, is_sequential=%v, small_buffers=%v", nodes_defs_.size(), source_.size(), sink_.size(), num_erased_edges, - is_sequential_); + is_sequential_, small_buffers); // Sanity check that all vectors are empty or all vectors are non-empty. DCHECK((!source_.empty() && !sink_.empty() && !thunk_sequence_.empty()) || @@ -78,7 +90,7 @@ ThunkExecutor::ThunkExecutor(ThunkSequence thunk_sequence, } absl::StatusOr ThunkExecutor::Create( - ThunkSequence thunk_sequence) { + ThunkSequence thunk_sequence, const ThunkExecutor::Options& options) { std::vector defs(thunk_sequence.size()); std::vector buffer_rwsets(thunk_sequence.size()); @@ -106,7 +118,7 @@ absl::StatusOr ThunkExecutor::Create( } } - return ThunkExecutor(std::move(thunk_sequence), std::move(defs)); + return ThunkExecutor(std::move(thunk_sequence), std::move(defs), options); } ThunkExecutor::ExecuteState::ExecuteState(ThunkExecutor* executor, diff --git a/xla/service/cpu/runtime/thunk_executor.h b/xla/service/cpu/runtime/thunk_executor.h index a48dd843871d4c..8965a7a51652a4 100644 --- a/xla/service/cpu/runtime/thunk_executor.h +++ b/xla/service/cpu/runtime/thunk_executor.h @@ -36,6 +36,17 @@ limitations under the License. namespace xla::cpu { +namespace internal { +// Clang does not allow defining a nested struct with member initializer, as +// a workaround we define a struct in internal namespace and create an alias. +struct ThunkExecutorOptions { + // If all thunks in a sequence use buffers of size less than or equal to + // `execute_sequential_buffer_threshold`, we mark execution as sequential, as + // concurrency overheads will likely dominate the overall execution time. + size_t execute_sequential_buffer_threshold = 512; +}; +} // namespace internal + // A dataflow-style (run when ready) executor for a ThunkSequence that depends // on buffer uses to build a DAG defining execution order. At run time executes // thunks concurrently in a given thread pool. @@ -44,6 +55,7 @@ class ThunkExecutor { using BufferUses = Thunk::BufferUses; using ResourceUses = Thunk::ResourceUses; using ExecuteEvent = Thunk::ExecuteEvent; + using Options = internal::ThunkExecutorOptions; // Nodes identified by their index in the captured ThunkSequence. using NodeId = int64_t; @@ -53,7 +65,8 @@ class ThunkExecutor { ThunkExecutor(ThunkExecutor&&) = default; ThunkExecutor& operator=(ThunkExecutor&&) = default; - static absl::StatusOr Create(ThunkSequence thunk_sequence); + static absl::StatusOr Create( + ThunkSequence thunk_sequence, const Options& options = Options()); // NodeDef defines an execution order for all thunks in a sequence. struct NodeDef { @@ -123,7 +136,8 @@ class ThunkExecutor { absl::Status abort_status ABSL_GUARDED_BY(abort_mutex); }; - ThunkExecutor(ThunkSequence thunk_sequence, std::vector nodes_defs); + ThunkExecutor(ThunkSequence thunk_sequence, std::vector nodes_defs, + const Options& options); // Executes thunks sequentially starting from the first thunk in the sequence. tsl::AsyncValueRef ExecuteSequential( @@ -157,6 +171,8 @@ class ThunkExecutor { int64_t TransitiveReduction(); ThunkSequence thunk_sequence_; + Options options_; + std::vector nodes_defs_; std::vector source_; diff --git a/xla/service/cpu/runtime/thunk_executor_test.cc b/xla/service/cpu/runtime/thunk_executor_test.cc index 26ba2e553a2b4d..2bbb932a4a432c 100644 --- a/xla/service/cpu/runtime/thunk_executor_test.cc +++ b/xla/service/cpu/runtime/thunk_executor_test.cc @@ -211,6 +211,13 @@ AddI32Thunk::ResourceUses AddI32Thunk::resource_uses() const { : ResourceUses{}; } +static ThunkExecutor::Options OptionsForTest() { + // Override small buffers threshold to make sure that we test all execution + // paths, because in test we always use small buffers below the default + // threshold of `512`. + return ThunkExecutor::Options{/*execute_sequential_buffer_threshold=*/0}; +} + TEST(ThunkExecutorTest, DependencyOrdering) { BufferAllocation alloc(/*index=*/0, /*size=*/80, /*color=*/0); @@ -223,8 +230,9 @@ TEST(ThunkExecutorTest, DependencyOrdering) { sequence.push_back(AddI32Thunk::Create("b", {slice1}, {slice1})); sequence.push_back(AddI32Thunk::Create("c", {slice2}, {slice2})); - TF_ASSERT_OK_AND_ASSIGN(ThunkExecutor executor, - ThunkExecutor::Create(std::move(sequence))); + TF_ASSERT_OK_AND_ASSIGN( + ThunkExecutor executor, + ThunkExecutor::Create(std::move(sequence), OptionsForTest())); EXPECT_FALSE(executor.is_sequential()); EXPECT_THAT(executor.source(), ElementsAre(0, 1)); @@ -240,8 +248,9 @@ TEST(ThunkExecutorTest, SequentialOrdering) { sequence.push_back(AddI32Thunk::Create("b", {slice}, {slice})); sequence.push_back(AddI32Thunk::Create("c", {slice}, {slice})); - TF_ASSERT_OK_AND_ASSIGN(ThunkExecutor executor, - ThunkExecutor::Create(std::move(sequence))); + TF_ASSERT_OK_AND_ASSIGN( + ThunkExecutor executor, + ThunkExecutor::Create(std::move(sequence), OptionsForTest())); EXPECT_TRUE(executor.is_sequential()); EXPECT_THAT(executor.source(), ElementsAre(0)); @@ -262,8 +271,9 @@ TEST(ThunkExecutorTest, ResourceOrdering) { /*trace=*/nullptr, /*use_shared_resource=*/true)); - TF_ASSERT_OK_AND_ASSIGN(ThunkExecutor executor, - ThunkExecutor::Create(std::move(sequence))); + TF_ASSERT_OK_AND_ASSIGN( + ThunkExecutor executor, + ThunkExecutor::Create(std::move(sequence), OptionsForTest())); EXPECT_TRUE(executor.is_sequential()); EXPECT_THAT(executor.source(), ElementsAre(0)); @@ -279,8 +289,9 @@ TEST(ThunkExecutorTest, TransitiveReduction) { sequence.push_back(AddI32Thunk::Create("b", {slice}, {slice})); sequence.push_back(AddI32Thunk::Create("c", {slice}, {slice})); - TF_ASSERT_OK_AND_ASSIGN(ThunkExecutor executor, - ThunkExecutor::Create(std::move(sequence))); + TF_ASSERT_OK_AND_ASSIGN( + ThunkExecutor executor, + ThunkExecutor::Create(std::move(sequence), OptionsForTest())); EXPECT_THAT(executor.source(), ElementsAre(0)); EXPECT_THAT(executor.sink(), ElementsAre(2)); @@ -305,8 +316,9 @@ TEST(ThunkExecutorTest, Execute) { sequence.push_back(AddI32Thunk::Create("b", {slice1}, {slice1}, &trace)); sequence.push_back(AddI32Thunk::Create("c", {slice2}, {slice2}, &trace)); - TF_ASSERT_OK_AND_ASSIGN(ThunkExecutor executor, - ThunkExecutor::Create(std::move(sequence))); + TF_ASSERT_OK_AND_ASSIGN( + ThunkExecutor executor, + ThunkExecutor::Create(std::move(sequence), OptionsForTest())); std::vector data(20, 1); // shared src and dst allocation @@ -472,8 +484,9 @@ TEST_P(ThunkExecutorStressTest, Execute) { GenerateThunkSequence(/*num_elements=*/1024, num_thunks, shared_resource_use, inject_errors)); - TF_ASSERT_OK_AND_ASSIGN(ThunkExecutor executor, - ThunkExecutor::Create(std::move(g->sequence))); + TF_ASSERT_OK_AND_ASSIGN( + ThunkExecutor executor, + ThunkExecutor::Create(std::move(g->sequence), OptionsForTest())); BufferAllocations allocations(g->buffers); Thunk::ExecuteParams params = {nullptr, &allocations, nullptr, device(), @@ -517,7 +530,8 @@ static void BM_SequentialThunkExecutor(benchmark::State& state) { /*shared_resource_use=*/SharedResourceUse::kAll, /*inject_errors=*/false) .value(); - auto e = ThunkExecutor::Create(std::move(g->sequence)).value(); + auto e = + ThunkExecutor::Create(std::move(g->sequence), OptionsForTest()).value(); BufferAllocations allocations(g->buffers); Thunk::ExecuteParams params = {nullptr, &allocations}; @@ -536,7 +550,8 @@ static void BM_SyncThunkExecutor(benchmark::State& state) { /*shared_resource_use=*/SharedResourceUse::kNo, /*inject_errors=*/false) .value(); - auto e = ThunkExecutor::Create(std::move(g->sequence)).value(); + auto e = + ThunkExecutor::Create(std::move(g->sequence), OptionsForTest()).value(); BufferAllocations allocations(g->buffers); Thunk::ExecuteParams params = {nullptr, &allocations}; @@ -559,7 +574,8 @@ static void BM_AsyncThunkExecutor(benchmark::State& state) { /*shared_resource_use=*/SharedResourceUse::kNo, /*inject_errors=*/false) .value(); - auto e = ThunkExecutor::Create(std::move(g->sequence)).value(); + auto e = + ThunkExecutor::Create(std::move(g->sequence), OptionsForTest()).value(); BufferAllocations allocations(g->buffers); From 74e77b4b45f27891b1718839caabb447e688c765 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 30 Jul 2024 22:41:31 -0700 Subject: [PATCH 305/376] Automated Code Change PiperOrigin-RevId: 657866327 --- xla/python/py_array.cc | 1 - xla/python/py_array.h | 1 + xla/python/refine_polymorphic_shapes.cc | 1 + 3 files changed, 2 insertions(+), 1 deletion(-) diff --git a/xla/python/py_array.cc b/xla/python/py_array.cc index 8d00206e38ee30..8fd0aa4e6d7370 100644 --- a/xla/python/py_array.cc +++ b/xla/python/py_array.cc @@ -74,7 +74,6 @@ limitations under the License. #include "xla/python/pjrt_ifrt/pjrt_array.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/python/pjrt_ifrt/pjrt_device.h" -#include "xla/python/pjrt_ifrt/xla_sharding.h" #include "xla/python/py_client.h" #include "xla/python/py_device.h" #include "xla/python/py_values.h" diff --git a/xla/python/py_array.h b/xla/python/py_array.h index cb76a5fcb90272..015c61c391146a 100644 --- a/xla/python/py_array.h +++ b/xla/python/py_array.h @@ -27,6 +27,7 @@ limitations under the License. #include // placeholder for index annotation headers +#include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" diff --git a/xla/python/refine_polymorphic_shapes.cc b/xla/python/refine_polymorphic_shapes.cc index cbd42e928ef576..e358b2e549a8e2 100644 --- a/xla/python/refine_polymorphic_shapes.cc +++ b/xla/python/refine_polymorphic_shapes.cc @@ -23,6 +23,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/Regex.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Bytecode/BytecodeWriter.h" From 4bc4e399e708893c9bfa9adc73e4e18e057bfde6 Mon Sep 17 00:00:00 2001 From: Penporn Koanantakool Date: Tue, 30 Jul 2024 23:29:47 -0700 Subject: [PATCH 306/376] [xla:cpu-oneDNN] Fix crashes when oneDNN matmul/convolution/layernorm tests were run with libc++ hardened mode. operands_stack_alloca arrays in the emitters weren't initialized properly. + Minor refactoring. PiperOrigin-RevId: 657877501 --- xla/service/cpu/ir_emitter.cc | 58 +++++++++------------ xla/service/cpu/ir_emitter.h | 7 +++ xla/service/cpu/tests/onednn_matmul_test.cc | 2 +- 3 files changed, 32 insertions(+), 35 deletions(-) diff --git a/xla/service/cpu/ir_emitter.cc b/xla/service/cpu/ir_emitter.cc index 13ede90a9192ab..d8bae800533f27 100644 --- a/xla/service/cpu/ir_emitter.cc +++ b/xla/service/cpu/ir_emitter.cc @@ -2643,6 +2643,22 @@ absl::Status IrEmitter::HandleTopK(HloInstruction* hlo) { } #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) + +// Emits operands alloca vector for oneDNN custom calls. +std::vector IrEmitter::EmitOneDnnOperandsAlloca( + HloInstruction* custom_call, llvm::Value*& args_val, int& arg_indx) { + std::vector operands_stack_alloca; + const int num_operands = custom_call->operand_count(); + operands_stack_alloca.reserve(num_operands); + for (int i = 0; i < num_operands; ++i) { + llvm_ir::IrArray ir_array(GetIrArrayFor(custom_call->operand(i))); + StackAlloca stack_alloca = GetAllocaAndEmitMemrefInfo(*b(), ir_array); + args_val = b()->CreateInsertValue(args_val, stack_alloca.value, arg_indx++); + operands_stack_alloca.push_back(std::move(stack_alloca)); + } + return operands_stack_alloca; +} + absl::Status IrEmitter::HandleOneDnnMatMulCalls( HloInstruction* custom_call, std::string runtime_symbol_name) { // We would like to emit LLVM IR for the following function call @@ -2684,7 +2700,6 @@ absl::Status IrEmitter::HandleOneDnnMatMulCalls( args_val = b()->CreateInsertValue(args_val, run_opts_val, arg_indx++); // Insert OneDnnMatMulConfig. - auto typed_custom_call = Cast(custom_call); auto backend_config = typed_custom_call->backend_config(); OneDnnMatMulConfig matmul_config; @@ -2696,17 +2711,8 @@ absl::Status IrEmitter::HandleOneDnnMatMulCalls( args_val = b()->CreateInsertValue(args_val, matmul_config_val, arg_indx++); // Insert operands. - std::vector operands_stack_alloca; - operands_stack_alloca.reserve(num_operands); - absl::c_transform(custom_call->operands(), operands_stack_alloca.begin(), - [this](HloInstruction* instr) { - llvm_ir::IrArray ir_array(GetIrArrayFor(instr)); - return GetAllocaAndEmitMemrefInfo(*b(), ir_array); - }); - for (int i = 0; i < num_operands; ++i) { - args_val = b()->CreateInsertValue(args_val, operands_stack_alloca[i].value, - arg_indx++); - } + auto operands_stack_alloca = + EmitOneDnnOperandsAlloca(custom_call, args_val, arg_indx); TF_RET_CHECK(nargs == arg_indx) << "Number of arguments don't equal the last argument index."; @@ -2812,17 +2818,8 @@ absl::Status IrEmitter::HandleOneDnnConvolution(HloInstruction* custom_call) { b()->CreateGlobalStringPtr(llvm_ir::AsStringRef(str_config)); args_val = b()->CreateInsertValue(args_val, conv_config_val, arg_indx++); - std::vector operands_stack_alloca; - operands_stack_alloca.reserve(num_operands); - absl::c_transform(custom_call->operands(), operands_stack_alloca.begin(), - [this](HloInstruction* instr) { - llvm_ir::IrArray ir_array(GetIrArrayFor(instr)); - return GetAllocaAndEmitMemrefInfo(*b(), ir_array); - }); - for (int i = 0; i < num_operands; ++i) { - args_val = b()->CreateInsertValue(args_val, operands_stack_alloca[i].value, - arg_indx++); - } + auto operands_stack_alloca = + EmitOneDnnOperandsAlloca(custom_call, args_val, arg_indx); TF_RET_CHECK(nargs == arg_indx) << "Number of arguments don't equal the last argument index."; @@ -2891,17 +2888,10 @@ absl::Status IrEmitter::HandleOneDnnLayerNorm(HloInstruction* custom_call) { args_val = b()->CreateInsertValue(args_val, ln_config_val, arg_indx++); // Insert operands. - std::vector operands_stack_alloca; - operands_stack_alloca.reserve(num_operands); - absl::c_transform(custom_call->operands(), operands_stack_alloca.begin(), - [this](HloInstruction* instr) { - llvm_ir::IrArray ir_array(GetIrArrayFor(instr)); - return GetAllocaAndEmitMemrefInfo(*b(), ir_array); - }); - for (int i = 0; i < num_operands; ++i) { - args_val = b()->CreateInsertValue(args_val, operands_stack_alloca[i].value, - arg_indx++); - } + auto operands_stack_alloca = + EmitOneDnnOperandsAlloca(custom_call, args_val, arg_indx); + TF_RET_CHECK(nargs == arg_indx) + << "Number of arguments don't equal the last argument index."; llvm::Value* args_ptr = llvm_ir::EmitAllocaAtFunctionEntry(ptr_array_type, "layernorm.args", b()); diff --git a/xla/service/cpu/ir_emitter.h b/xla/service/cpu/ir_emitter.h index 4ed7854f48a610..0e6b1a3d6c0269 100644 --- a/xla/service/cpu/ir_emitter.h +++ b/xla/service/cpu/ir_emitter.h @@ -59,6 +59,10 @@ limitations under the License. #include "xla/service/name_uniquer.h" #include "xla/xla_data.pb.h" +#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) +#include "xla/service/cpu/onednn_memory_util.h" +#endif + namespace xla { namespace cpu { @@ -320,6 +324,9 @@ class IrEmitter : public DfsHloVisitorWithDefault, absl::Status HandleAllReduceSingleReplica(HloInstruction* crs); absl::Status HandleAllReduceMultipleReplica(HloInstruction* crs); #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) + std::vector EmitOneDnnOperandsAlloca(HloInstruction* custom_call, + llvm::Value*& args_val, + int& arg_indx); absl::Status HandleOneDnnMatMulCalls(HloInstruction* hlo, std::string runtime_symbol_name); absl::Status HandleOneDnnSoftmax(HloInstruction* hlo); diff --git a/xla/service/cpu/tests/onednn_matmul_test.cc b/xla/service/cpu/tests/onednn_matmul_test.cc index 389716c4ddef95..c31ed5c2fc6ac7 100644 --- a/xla/service/cpu/tests/onednn_matmul_test.cc +++ b/xla/service/cpu/tests/onednn_matmul_test.cc @@ -1541,7 +1541,7 @@ TEST_F(MatmulTest, ConsecutiveBinaryAdd) { TEST_F(MatmulTest, BroadcastedAddAfterFusion) { const char* matmul_module_str = R"( - HloModule matmul.nonscalar.test.1 + HloModule matmul.nonscalar.test ENTRY matmul.nonscalar.test.f32 { arg.0 = f32[16,400,500] parameter(0) arg.1 = f32[16,500,3] parameter(1) From d73be28f460bb673da599a089545c89057bac751 Mon Sep 17 00:00:00 2001 From: Fangrui Song Date: Tue, 30 Jul 2024 23:56:45 -0700 Subject: [PATCH 307/376] Integrate LLVM at llvm/llvm-project@d92a484e6f5c Updates LLVM usage to match [d92a484e6f5c](https://github.com/llvm/llvm-project/commit/d92a484e6f5c) PiperOrigin-RevId: 657883563 --- third_party/llvm/workspace.bzl | 4 ++-- third_party/shardy/temporary.patch | 15 +++++++++++++++ third_party/shardy/workspace.bzl | 4 ++-- third_party/tsl/third_party/llvm/workspace.bzl | 4 ++-- 4 files changed, 21 insertions(+), 6 deletions(-) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 76a13a425d79ed..9345d8db8d67ef 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "51681409aeb081c8dfe241e0d8e8c71f8bf0a4f4" - LLVM_SHA256 = "347cc44fc5bba17b2a6ac26a253803434790a2996b77e8b6fbbeee9b8a367ec8" + LLVM_COMMIT = "d92a484e6f5c9063d82ca79405bb3557d88ad575" + LLVM_SHA256 = "0e6cce920f7344248ed747443fc16c316faf398e33f6a7f9f11f41ede861f824" tf_http_archive( name = name, diff --git a/third_party/shardy/temporary.patch b/third_party/shardy/temporary.patch index e69de29bb2d1d6..9bf90881410570 100644 --- a/third_party/shardy/temporary.patch +++ b/third_party/shardy/temporary.patch @@ -0,0 +1,15 @@ +diff --git i/third_party/llvm/workspace.bzl w/third_party/llvm/workspace.bzl +index 76a13a4..9345d8d 100644 +--- i/third_party/llvm/workspace.bzl ++++ w/third_party/llvm/workspace.bzl +@@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") + + def repo(name): + """Imports LLVM.""" +- LLVM_COMMIT = "51681409aeb081c8dfe241e0d8e8c71f8bf0a4f4" +- LLVM_SHA256 = "347cc44fc5bba17b2a6ac26a253803434790a2996b77e8b6fbbeee9b8a367ec8" ++ LLVM_COMMIT = "d92a484e6f5c9063d82ca79405bb3557d88ad575" ++ LLVM_SHA256 = "0e6cce920f7344248ed747443fc16c316faf398e33f6a7f9f11f41ede861f824" + + tf_http_archive( + name = name, diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index 7b492bca9f78aa..3f82df3eee7669 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "e3462e37ae3fbf0d15836f2c657294028e1b9075" - SHARDY_SHA256 = "e9a98a5257bedbcf80a14ad1014e6645c24e3ba9352bff39b89cad3366188afb" + SHARDY_COMMIT = "c87ce5b404305927c7a169b305ba0dc1c304e4ce" + SHARDY_SHA256 = "2fa411cfb31f351f2cdad997db0ccb8f9898bad3421f2a78889703bb75bd054c" tf_http_archive( name = "shardy", diff --git a/third_party/tsl/third_party/llvm/workspace.bzl b/third_party/tsl/third_party/llvm/workspace.bzl index 76a13a425d79ed..9345d8db8d67ef 100644 --- a/third_party/tsl/third_party/llvm/workspace.bzl +++ b/third_party/tsl/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "51681409aeb081c8dfe241e0d8e8c71f8bf0a4f4" - LLVM_SHA256 = "347cc44fc5bba17b2a6ac26a253803434790a2996b77e8b6fbbeee9b8a367ec8" + LLVM_COMMIT = "d92a484e6f5c9063d82ca79405bb3557d88ad575" + LLVM_SHA256 = "0e6cce920f7344248ed747443fc16c316faf398e33f6a7f9f11f41ede861f824" tf_http_archive( name = name, From 2b72d4f65b47caad3d611f7d9a747b6beec0f1ae Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 31 Jul 2024 00:15:06 -0700 Subject: [PATCH 308/376] Automated Code Change PiperOrigin-RevId: 657888564 --- xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h | 6 +++--- xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.h | 5 +++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h b/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h index 293bc4016ed803..7c06b26c8aa785 100644 --- a/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h +++ b/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h @@ -62,8 +62,8 @@ class MarkEventReadyOnExit { MarkEventReadyOnExit(const MarkEventReadyOnExit&) = delete; MarkEventReadyOnExit& operator=(const MarkEventReadyOnExit&) = delete; - MarkEventReadyOnExit(MarkEventReadyOnExit&&) = default; - MarkEventReadyOnExit& operator=(MarkEventReadyOnExit&&) = default; + MarkEventReadyOnExit(MarkEventReadyOnExit&&) noexcept = default; + MarkEventReadyOnExit& operator=(MarkEventReadyOnExit&&) noexcept = default; ~MarkEventReadyOnExit() { if (event_) event_.SetStateConcrete(); @@ -163,7 +163,7 @@ class AbstractTfrtCpuBuffer : public PjRtBuffer { DonationTransaction(const DonationTransaction&) = delete; DonationTransaction& operator=(const DonationTransaction&) = delete; DonationTransaction(DonationTransaction&&) = default; - DonationTransaction& operator=(DonationTransaction&& other) { + DonationTransaction& operator=(DonationTransaction&& other) noexcept { Abort(); buffer_ = other.buffer_; diff --git a/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.h b/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.h index 4bfc1c57aed269..8d22bd891e6faf 100644 --- a/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.h +++ b/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.h @@ -128,8 +128,9 @@ class TrackedTfrtCpuDeviceBuffer { absl::AnyInvocable on_delete_callback = nullptr); // Move-only. - TrackedTfrtCpuDeviceBuffer(TrackedTfrtCpuDeviceBuffer&&) = default; - TrackedTfrtCpuDeviceBuffer& operator=(TrackedTfrtCpuDeviceBuffer&&) = default; + TrackedTfrtCpuDeviceBuffer(TrackedTfrtCpuDeviceBuffer&&) noexcept = default; + TrackedTfrtCpuDeviceBuffer& operator=(TrackedTfrtCpuDeviceBuffer&&) noexcept = + default; TrackedTfrtCpuDeviceBuffer(const TrackedTfrtCpuDeviceBuffer&) = delete; TrackedTfrtCpuDeviceBuffer& operator=(const TrackedTfrtCpuDeviceBuffer&) = delete; From 143a0b36425a620da4d433f0ce71b777e11c35cb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 31 Jul 2024 00:55:59 -0700 Subject: [PATCH 309/376] Reverts 9bb18711c56c5538e49f4a38d44c06d8c397a7c1 PiperOrigin-RevId: 657897820 --- xla/translate/hlo_to_mhlo/BUILD | 1 - .../hlo_to_mhlo/hlo_function_importer.cc | 37 +++++++ .../hlo_to_mhlo/hlo_function_importer.h | 7 ++ .../hlo_to_mhlo/hlo_module_importer.cc | 4 + .../hlo_to_mhlo/tests/module_attributes.hlo | 13 +++ xla/translate/mhlo_to_hlo/BUILD | 1 + .../mhlo_to_hlo/attribute_exporter.cc | 97 +++++++++++++++++++ .../mhlo_to_hlo/attribute_exporter.h | 4 + xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc | 7 ++ .../mhlo_to_hlo/tests/module_attributes.mlir | 42 ++++++++ 10 files changed, 212 insertions(+), 1 deletion(-) diff --git a/xla/translate/hlo_to_mhlo/BUILD b/xla/translate/hlo_to_mhlo/BUILD index cb3d691eb839a0..043dbf0cc801d7 100644 --- a/xla/translate/hlo_to_mhlo/BUILD +++ b/xla/translate/hlo_to_mhlo/BUILD @@ -83,7 +83,6 @@ cc_library( "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/mlir_hlo", - "//xla/service:hlo_module_config", "//xla/service:hlo_proto_cc", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", diff --git a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc index 98fcd5172728f6..7d2a929822b1fb 100644 --- a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc +++ b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc @@ -56,6 +56,7 @@ limitations under the License. #include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" @@ -2500,6 +2501,42 @@ absl::Status HloFunctionImporter::ConvertShapeToMlirLayout( return Internal("Couldn't convert layout."); } +mlir::Attribute ConvertInputOutputAlias(const HloInputOutputAliasConfig& alias, + mlir::Builder* builder) { + llvm::SmallVector element_attrs; + alias.ForEachAlias([&](const ShapeIndex& output_index, + const HloInputOutputAliasConfig::Alias& alias) { + std::string kindToString; + switch (alias.kind) { + case HloInputOutputAliasConfig::AliasKind::kMayAlias: + kindToString = "may_alias"; + break; + case HloInputOutputAliasConfig::AliasKind::kMustAlias: + kindToString = "must_alias"; + break; + default: + kindToString = "undefined_alias"; + } + mlir::NamedAttribute alias_named_attributes[3] = { + builder->getNamedAttr( + "parameter_index", + builder->getDenseI64ArrayAttr(ArrayRef( + alias.parameter_index.begin(), alias.parameter_index.end()))), + builder->getNamedAttr("parameter_number", builder->getI64IntegerAttr( + alias.parameter_number)), + builder->getNamedAttr("kind", builder->getStringAttr(kindToString))}; + + mlir::NamedAttribute named_attributes[2] = { + builder->getNamedAttr("output_index", + builder->getDenseI64ArrayAttr(ArrayRef( + output_index.begin(), output_index.end()))), + builder->getNamedAttr( + "alias", builder->getDictionaryAttr(alias_named_attributes))}; + element_attrs.push_back(builder->getDictionaryAttr(named_attributes)); + }); + return builder->getArrayAttr(element_attrs); +} + mlir::Attribute ConvertSharding(const HloSharding& sharding, mlir::Builder* builder) { return builder->getStringAttr(sharding.ToString(/*include_metadata=*/true)); diff --git a/xla/translate/hlo_to_mhlo/hlo_function_importer.h b/xla/translate/hlo_to_mhlo/hlo_function_importer.h index cb3953990f4030..5c5a4e309bfbf6 100644 --- a/xla/translate/hlo_to_mhlo/hlo_function_importer.h +++ b/xla/translate/hlo_to_mhlo/hlo_function_importer.h @@ -33,6 +33,7 @@ limitations under the License. #include "mlir/IR/Operation.h" #include "mlir/IR/ValueRange.h" #include "xla/comparison_util.h" +#include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/hlo.pb.h" @@ -297,6 +298,12 @@ class HloFunctionImporter { bool flatten_computation_args_result_; }; +// Returns a StringAttr that carries a prettyprinted representation of the +// given HLO C++ input_output_alias_config. +// Always succeeds and returns a non-empty attribute. +mlir::Attribute ConvertInputOutputAlias(const HloInputOutputAliasConfig& alias, + mlir::Builder* builder); + // Returns a StringAttr that carries a prettyprinted representation of the // given HLO C++ sharding. // Always succeeds and returns a non-empty attribute. diff --git a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc index 1f2ea997c81e8a..76037442d52099 100644 --- a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc +++ b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc @@ -122,6 +122,10 @@ absl::Status HloModuleImporter::Import(const HloModule& hlo_module) { ConvertSharding(hlo_module.spmd_output_sharding(), &builder_)); } + module->setAttr("mhlo.input_output_alias", + ConvertInputOutputAlias( + hlo_module.input_output_alias_config(), &builder_)); + if (hlo_module.has_spmd_parameters_shardings()) { llvm::SmallVector parameter_shardings; parameter_shardings.reserve(hlo_module.spmd_parameters_shardings().size()); diff --git a/xla/translate/hlo_to_mhlo/tests/module_attributes.hlo b/xla/translate/hlo_to_mhlo/tests/module_attributes.hlo index 74eaaea5a0e8fe..d3433dce372cbf 100644 --- a/xla/translate/hlo_to_mhlo/tests/module_attributes.hlo +++ b/xla/translate/hlo_to_mhlo/tests/module_attributes.hlo @@ -5,6 +5,18 @@ # FLATTEN-CHECK-LABEL: module @main attributes { hlo_module { name: "main" + input_output_alias { + entries { + output_shape_index: 0 + parameter_number: 0 + kind: MAY_ALIAS + } + entries { + output_shape_index: 1 + parameter_number: 1 + kind: MAY_ALIAS + } + } entry_computation_name: "main.5" computations { name: "main.5" @@ -217,6 +229,7 @@ hlo_module { value: "attr_value" } } +# CHECK-SAME: mhlo.input_output_alias = [{alias = {kind = "may_alias", parameter_index = array, parameter_number = 0 : i64}, output_index = array}, {alias = {kind = "may_alias", parameter_index = array, parameter_number = 1 : i64}, output_index = array}] # CHECK-SAME: mhlo.is_dynamic = true is_dynamic: true # CHECK-SAME: mhlo.use_auto_spmd_partitioning = true diff --git a/xla/translate/mhlo_to_hlo/BUILD b/xla/translate/mhlo_to_hlo/BUILD index 92b7265298f6e7..3de8007804af4b 100644 --- a/xla/translate/mhlo_to_hlo/BUILD +++ b/xla/translate/mhlo_to_hlo/BUILD @@ -23,6 +23,7 @@ cc_library( "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", "//xla/mlir_hlo", "//xla/service:hlo_parser", "//xla/service:hlo_proto_cc", diff --git a/xla/translate/mhlo_to_hlo/attribute_exporter.cc b/xla/translate/mhlo_to_hlo/attribute_exporter.cc index a492861b28d831..73a5c8b994e57e 100644 --- a/xla/translate/mhlo_to_hlo/attribute_exporter.cc +++ b/xla/translate/mhlo_to_hlo/attribute_exporter.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include "mlir/Support/LLVM.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/hlo_parser.h" #include "xla/shape_util.h" @@ -185,4 +187,99 @@ std::optional ConvertSharding(llvm::StringRef sharding) { return std::nullopt; } +std::optional ConvertInputOutputAlias( + llvm::ArrayRef aliasing) { + if (aliasing.empty()) return std::nullopt; + + xla::HloInputOutputAliasProto input_output_alias_proto; + for (auto attr : aliasing) { + auto entry_attr = mlir::cast(attr); + auto alias_attr = mlir::cast(entry_attr.get("alias")); + mlir::ArrayRef output_index = + mlir::cast(entry_attr.get("output_index")) + .asArrayRef(); + mlir::ArrayRef parameter_index = + mlir::cast(alias_attr.get("parameter_index")) + .asArrayRef(); + HloInputOutputAliasProto::AliasEntryProto entry; + entry.mutable_output_shape_index()->Add(output_index.begin(), + output_index.end()); + entry.set_parameter_number( + mlir::cast(alias_attr.get("parameter_number")) + .getInt()); + entry.mutable_parameter_shape_index()->Add(parameter_index.begin(), + parameter_index.end()); + mlir::StringRef kind = + mlir::cast(alias_attr.get("kind")).getValue(); + if (kind == "may_alias") + entry.set_kind(xla::Kind::MAY_ALIAS); + else if (kind == "must_alias") + entry.set_kind(xla::Kind::MUST_ALIAS); + else + entry.set_kind(xla::Kind::UNDEFINED_ALIAS); + input_output_alias_proto.add_entries()->Swap(&entry); + } + return input_output_alias_proto; +} + +DotDimensionNumbers ConvertDotDimensionNumbers( + mlir::mhlo::DotDimensionNumbersAttr input) { + DotDimensionNumbers output; + + for (auto v : input.getLhsBatchingDimensions()) { + output.add_lhs_batch_dimensions(v); + } + + for (auto v : input.getRhsBatchingDimensions()) { + output.add_rhs_batch_dimensions(v); + } + + for (auto v : input.getLhsContractingDimensions()) { + output.add_lhs_contracting_dimensions(v); + } + + for (auto v : input.getRhsContractingDimensions()) { + output.add_rhs_contracting_dimensions(v); + } + + return output; +} + +DotDimensionNumbers ConvertDotDimensionNumbers( + absl::Span lhs_batch, absl::Span lhs_contract, + absl::Span rhs_batch, + absl::Span rhs_contract) { + DotDimensionNumbers output; + for (auto v : lhs_batch) { + output.add_lhs_batch_dimensions(v); + } + + for (auto v : rhs_batch) { + output.add_rhs_batch_dimensions(v); + } + + for (auto v : lhs_contract) { + output.add_lhs_contracting_dimensions(v); + } + + for (auto v : rhs_contract) { + output.add_rhs_contracting_dimensions(v); + } + + return output; +} + +absl::StatusOr> ConvertMlirArrayAttrToInt64Array( + const mlir::ArrayAttr& array) { + int rank = array.size(); + std::vector converted_array(rank); + for (int i = 0; i < rank; i++) { + mlir::IntegerAttr attr = mlir::dyn_cast(array[i]); + if (!attr) { + return Internal("Type Error: Expected layout integer attribute"); + } + converted_array[i] = attr.getInt(); + } + return converted_array; +} } // namespace xla diff --git a/xla/translate/mhlo_to_hlo/attribute_exporter.h b/xla/translate/mhlo_to_hlo/attribute_exporter.h index e0e0dc9821d21e..49daefe6935650 100644 --- a/xla/translate/mhlo_to_hlo/attribute_exporter.h +++ b/xla/translate/mhlo_to_hlo/attribute_exporter.h @@ -20,6 +20,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "mlir/IR/Attributes.h" +#include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/hlo.pb.h" #include "xla/shape_util.h" @@ -59,5 +60,8 @@ ConvertOutputOperandAliasing(mlir::ArrayAttr aliasArrayAttr); // Will fail if both attempts at parsing failed. std::optional ConvertSharding(mlir::StringRef sharding); +std::optional ConvertInputOutputAlias( + llvm::ArrayRef aliasing); + } // namespace xla #endif // XLA_TRANSLATE_MHLO_TO_HLO_ATTRIBUTE_EXPORTER_H_ diff --git a/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc index 623080e11fd60d..90eb1a902127bc 100644 --- a/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc +++ b/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc @@ -3736,6 +3736,13 @@ absl::Status ConvertMlirHloToHlo(mlir::ModuleOp module, *hlo_module.mutable_spmd_output_sharding() = *xla::ConvertSharding(spmd_output_sharding.getValue()); } + if (auto input_output_alias = + module->getAttrOfType("mhlo.input_output_alias")) { + if (std::optional input_output_alias_proto = + xla::ConvertInputOutputAlias(input_output_alias.getValue())) { + *hlo_module.mutable_input_output_alias() = *input_output_alias_proto; + } + } if (auto spmd_parameters_sharding = module->getAttrOfType( "mhlo.spmd_parameters_shardings")) { for (const auto& sharding : spmd_parameters_sharding.getValue()) { diff --git a/xla/translate/mhlo_to_hlo/tests/module_attributes.mlir b/xla/translate/mhlo_to_hlo/tests/module_attributes.mlir index 049456bb09e6f7..6ad08374e5d2e6 100644 --- a/xla/translate/mhlo_to_hlo/tests/module_attributes.mlir +++ b/xla/translate/mhlo_to_hlo/tests/module_attributes.mlir @@ -100,3 +100,45 @@ module @ModuleWithFrontendAttributes attributes { func.return %arg0 : tensor<1xf32> } } + + + +// ----- + +module attributes { +// CHECK: input_output_alias { +// CHECK-NEXT: entries { +// CHECK-NEXT: output_shape_index: 0 +// CHECK-NEXT: kind: MAY_ALIAS +// CHECK-NEXT: } +// CHECK-NEXT: entries { +// CHECK-NEXT: output_shape_index: 1 +// CHECK-NEXT: parameter_number: 1 +// CHECK-NEXT: kind: MAY_ALIAS +// CHECK-NEXT: } +// CHECK-NEXT: } + mhlo.input_output_alias = [ + { + alias = + { + kind = "may_alias", + parameter_index = array, + parameter_number = 0 : i64 + }, + output_index = array + }, + { + alias = + { + kind = "may_alias", + parameter_index = array, + parameter_number = 1 : i64 + }, + output_index = array + } +] +} { + func.func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32> ) -> (tensor<1xf32>, tensor<1xf32>) { + func.return %arg0, %arg1: tensor<1xf32>, tensor<1xf32> + } +} \ No newline at end of file From 06cb273bc928c99e220622c8e6a50773f109adc5 Mon Sep 17 00:00:00 2001 From: Tori Baker Date: Wed, 31 Jul 2024 01:05:00 -0700 Subject: [PATCH 310/376] Remove tesla patch - this is no longer needed as we no longer run triton on T4s as it is unsupported in openai anyway. PiperOrigin-RevId: 657900736 --- third_party/triton/temporary/series.bzl | 1 + third_party/triton/temporary/undo_tesla_gpu.patch | 13 +++++++++++++ 2 files changed, 14 insertions(+) create mode 100644 third_party/triton/temporary/undo_tesla_gpu.patch diff --git a/third_party/triton/temporary/series.bzl b/third_party/triton/temporary/series.bzl index 9d26b42a567757..f55c41ecd38cc1 100644 --- a/third_party/triton/temporary/series.bzl +++ b/third_party/triton/temporary/series.bzl @@ -15,5 +15,6 @@ those to this list. temporary_patch_list = [ "//third_party/triton/temporary:cuda11-temporary.patch", + "//third_party/triton/temporary:undo_tesla_gpu.patch", # Add new patches just above this line ] diff --git a/third_party/triton/temporary/undo_tesla_gpu.patch b/third_party/triton/temporary/undo_tesla_gpu.patch new file mode 100644 index 00000000000000..6c2d1d1d734fbc --- /dev/null +++ b/third_party/triton/temporary/undo_tesla_gpu.patch @@ -0,0 +1,13 @@ +This can be removed on the next integrate as it already exists in upstream. +diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +--- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp ++++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +@@ -21,7 +21,7 @@ namespace { + static int getMMAVersionSafe(int computeCapability, DotOp op) { + // List supported mma version in order of preference. + SmallVector versionsSupported; +- if (computeCapability < 80) { ++ if (computeCapability < 75) { + versionsSupported = {1}; + } else if (computeCapability < 90) { + versionsSupported = {2}; From 18336de0c7985fca368f15481316f046e420e77a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 31 Jul 2024 01:11:23 -0700 Subject: [PATCH 311/376] Automated Code Change PiperOrigin-RevId: 657902338 --- xla/service/algebraic_simplifier.cc | 3 +++ xla/service/call_graph.h | 2 +- xla/service/conditional_code_motion.cc | 1 + 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/xla/service/algebraic_simplifier.cc b/xla/service/algebraic_simplifier.cc index fad9dcacaa4ab9..3ee1ffabaee730 100644 --- a/xla/service/algebraic_simplifier.cc +++ b/xla/service/algebraic_simplifier.cc @@ -7359,6 +7359,7 @@ absl::Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { if (multi_output_reduce) { std::vector broadcast_inits; int64_t inputs = reduce->input_count(); + broadcast_inits.reserve(inputs); for (int64_t i = 0; i < inputs; ++i) { broadcast_inits.push_back(reduce->init_values()[i]->AddInstruction( HloInstruction::CreateBroadcast(reduce->shape().tuple_shapes(i), @@ -7404,6 +7405,7 @@ absl::Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { if (multi_output_reduce) { std::vector reshaped_args; int64_t inputs = reduce->input_count(); + reshaped_args.reserve(inputs); for (int64_t i = 0; i < inputs; ++i) { reshaped_args.push_back( reduce->AddInstruction(HloInstruction::CreateReshape( @@ -7930,6 +7932,7 @@ absl::Status AlgebraicSimplifierVisitor::HandleReduceWindow( if (ShapeUtil::IsZeroElementArray(*input_shapes[0]) || ShapeUtil::IsZeroElementArray(*output_shapes[0])) { std::vector broadcast_inits; + broadcast_inits.reserve(input_count); for (int64_t i = 0; i < input_count; ++i) { broadcast_inits.push_back( hlo->AddInstruction(HloInstruction::CreateBroadcast( diff --git a/xla/service/call_graph.h b/xla/service/call_graph.h index c6f933ef1a250d..0d15a64cafd144 100644 --- a/xla/service/call_graph.h +++ b/xla/service/call_graph.h @@ -141,7 +141,7 @@ class CallGraphNode { CallGraphNode(const CallGraphNode&) = delete; CallGraphNode& operator=(const CallGraphNode&) = delete; CallGraphNode(CallGraphNode&&) = default; - CallGraphNode& operator=(CallGraphNode&&) = default; + CallGraphNode& operator=(CallGraphNode&&) noexcept = default; private: // Only CallGraph can modify CallGraphNode. diff --git a/xla/service/conditional_code_motion.cc b/xla/service/conditional_code_motion.cc index c7ef8e609df40b..cd22c9d4ca0c0a 100644 --- a/xla/service/conditional_code_motion.cc +++ b/xla/service/conditional_code_motion.cc @@ -1005,6 +1005,7 @@ class MoveOperandIntoBranch { CHECK_NE(new_tuple, nullptr); VLOG(5) << "Cloned new tuple:" << new_tuple->parent()->ToString() << "\n"; std::vector> gte_users; + gte_users.reserve(branch_param->shape().tuple_shapes_size()); for (int64_t j = 0; j < branch_param->shape().tuple_shapes_size(); ++j) { gte_users.push_back(std::vector()); } From 01846235e8a1ff06f8af113ca8727732a2a329e9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 31 Jul 2024 01:12:46 -0700 Subject: [PATCH 312/376] Automated Code Change PiperOrigin-RevId: 657902814 --- xla/service/gpu/runtime/nccl_api.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/xla/service/gpu/runtime/nccl_api.cc b/xla/service/gpu/runtime/nccl_api.cc index 783bd2ddaddee4..f790956d8525af 100644 --- a/xla/service/gpu/runtime/nccl_api.cc +++ b/xla/service/gpu/runtime/nccl_api.cc @@ -451,6 +451,7 @@ absl::StatusOr> DefaultNcclApi::CommSplit( TF_RETURN_IF_ERROR(GroupEnd()); std::vector split_comms; + split_comms.reserve(split_comms_handles.size()); for (size_t i = 0; i < split_comms_handles.size(); ++i) { split_comms.emplace_back(Cast(split_comms_handles[i]), NcclCommDeleter{this}); From 4d967694be2378788a4554a093f043cddc2e110d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 31 Jul 2024 01:14:12 -0700 Subject: [PATCH 313/376] Automated Code Change PiperOrigin-RevId: 657903160 --- xla/hlo/experimental/auto_sharding/auto_sharding.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 273bded2d4d7c5..7fa0c4e56ca27f 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -3325,6 +3325,7 @@ void AnnotateShardingWithSimpleHeuristic( if (heuristic == "shard-largest") { std::vector lengths; + lengths.reserve(inst->shape().rank()); for (int64_t i = 0; i < inst->shape().rank(); ++i) { lengths.push_back(inst->shape().dimensions(i)); } From 8bef8364976c7b396f1fa99442306b15025c6d3c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 31 Jul 2024 01:56:19 -0700 Subject: [PATCH 314/376] Automated Code Change PiperOrigin-RevId: 657916008 --- xla/service/transfer_manager.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/xla/service/transfer_manager.cc b/xla/service/transfer_manager.cc index b5edfad998ad5e..c601a884919d1b 100644 --- a/xla/service/transfer_manager.cc +++ b/xla/service/transfer_manager.cc @@ -270,6 +270,8 @@ absl::Status TransferManager::WriteRootTupleIndexTable( device_memory.size()); std::vector elements; + elements.reserve( + ShapeUtil::TupleElementCount(device_buffer.on_device_shape())); for (int64_t i = 0; i < ShapeUtil::TupleElementCount(device_buffer.on_device_shape()); ++i) { elements.push_back(device_buffer.buffer({i})); @@ -290,6 +292,7 @@ absl::Status TransferManager::WriteRootTupleIndexTable( device_memory.size()); std::vector elements; + elements.reserve(ShapeUtil::TupleElementCount(buffer_tree.shape())); for (int64_t i = 0; i < ShapeUtil::TupleElementCount(buffer_tree.shape()); ++i) { elements.push_back(buffer_tree.element({i}).AsDeviceMemoryBase()); From e3781e9bbb0aa9295d1a79bd7fcdf226563be51d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 31 Jul 2024 03:04:40 -0700 Subject: [PATCH 315/376] Automated Code Change PiperOrigin-RevId: 657935072 --- xla/client/xla_builder.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/xla/client/xla_builder.cc b/xla/client/xla_builder.cc index c869c43e160518..0b679afce50d96 100644 --- a/xla/client/xla_builder.cc +++ b/xla/client/xla_builder.cc @@ -3854,6 +3854,7 @@ XlaOp XlaBuilder::AllToAllArray( if (is_unbounded) { std::vector new_dimensions; + new_dimensions.reserve(operand_shape->rank()); for (int64_t i = 0; i < operand_shape->rank(); ++i) { new_dimensions.push_back(GetR1DimensionSizeOrConstant(operand, i)); } From a23ee555a04ec096475d61046773b3933f2820d5 Mon Sep 17 00:00:00 2001 From: Anthony Platanios Date: Wed, 31 Jul 2024 03:14:38 -0700 Subject: [PATCH 316/376] PR #15444: Fixed some issues around compiling on Windows. Imported from GitHub PR https://github.com/openxla/xla/pull/15444 This PR fixes some issues I bumped into when trying to compile XLA on Windows. I still haven't gotten GPU support to work but I'm making progress. The CPU only version compiles fine after some of the changes in this PR. I'll point out some specific issues this PR fixes in comments. There are also TSL-specific changes that are pulled in a separate PR (#15499). Copybara import of the project: -- eacee95f41abc49a21516ee389861d84a40eca85 by eaplatanios : Fixed some issues around compiling on Windows. -- b12e4cf0d23c2690111125a651e486ec6a112e54 by eaplatanios : . -- e23ef176de72cf04555242174a19a407884f3f0e by eaplatanios : . -- bdae19b9e15c396985703bb7e88a4db6fcddc7f6 by eaplatanios : . -- 2f90e6ba564f92fafa564b104ed0ce82b7642563 by eaplatanios : . -- 57009793b74c4d7d51fb39547a70a3ec142dadab by eaplatanios : . -- a978b1f7f70d49f1426fe46b107fdcc3618e3085 by eaplatanios : . -- d7fe81dc9cf909a6a8d70e2be8cfffca4063493e by eaplatanios : . -- fc40d919619330bce596555613e425cb6267eea4 by eaplatanios : . -- 326aec3fd73a67ca3c667cfeb5c88a8ffa52eb3d by eaplatanios : . -- a7603b7e1be990ff012440c74bd2c2ecbc2b1e2f by eaplatanios : . -- edcc97a67016584c285d84ac732952c572283119 by eaplatanios : . -- cec244808a8df163f9a803db450ca2bebdda9315 by eaplatanios : . -- df3eb2215eea9076cb352378c5745e113df7cc7d by eaplatanios : . -- 8997345fd1e1aa6f55e445615460124c6e14417c by eaplatanios : . -- 219a9f1bff7fb12c3407ab2e47512560001900fe by eaplatanios : . -- 73f3cd7e0135ec05c97595f795ec318fb635bd32 by eaplatanios : . Merging this change closes #15444 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/15444 from eaplatanios:u/eaplatanios/cpp-17-fixes 73f3cd7e0135ec05c97595f795ec318fb635bd32 PiperOrigin-RevId: 657937707 --- .../profiler/gpu/cupti_buffer_events.cc | 24 ++++++------ .../profiler/gpu/cupti_buffer_events.h | 2 +- xla/hlo/evaluator/hlo_evaluator.cc | 2 + xla/hlo/evaluator/hlo_evaluator.h | 4 ++ xla/pjrt/c/pjrt_c_api_wrapper_impl.cc | 2 +- xla/pjrt/gpu/se_gpu_pjrt_compiler.cc | 14 ++++--- xla/service/cpu/runtime/conv_impl.h | 8 ++-- xla/service/cpu/runtime_conv2d.cc | 6 ++- xla/service/cpu/runtime_conv3d.cc | 6 ++- .../cpu/runtime_single_threaded_conv2d.cc | 6 ++- .../cpu/runtime_single_threaded_conv3d.cc | 6 ++- .../fusions/mlir/computation_partitioner.cc | 12 +++--- xla/service/gpu/kernels/BUILD | 38 ++++++++++++++++--- .../gpu/kernels/cutlass_gemm_adaptor.cu.h | 5 ++- .../gpu/kernels/cutlass_gemm_custom_kernel.cc | 4 ++ .../model/gpu_collective_performance_model.cc | 2 +- .../model/gpu_collective_performance_model.h | 2 + xla/service/gpu/stream_executor_util.cc | 2 +- .../shardy/mhlo_round_trip/mhlo_import.cc | 14 ++++--- xla/stream_executor/cuda/cuda_diagnostics.cc | 2 + xla/stream_executor/cuda/cuda_dnn.cc | 8 ++-- xla/tsl/framework/BUILD | 7 +++- 22 files changed, 117 insertions(+), 59 deletions(-) diff --git a/xla/backends/profiler/gpu/cupti_buffer_events.cc b/xla/backends/profiler/gpu/cupti_buffer_events.cc index ccda1b07902355..376b1809ad4b1a 100644 --- a/xla/backends/profiler/gpu/cupti_buffer_events.cc +++ b/xla/backends/profiler/gpu/cupti_buffer_events.cc @@ -186,18 +186,18 @@ void AddGraphTraceActivityEvent(CuptiEventCollectorDelegate &collector, AnnotationMap::AnnotationInfo info = collector.annotation_map.LookUp( graph_trace->deviceId, graph_trace->correlationId); collector.receive(CuptiTracerEvent{ - .type = CuptiTracerEventType::CudaGraph, - .source = CuptiTracerEventSource::Activity, - .name = absl::StrCat("CudaGraphExec:", graph_trace->graphId), - .annotation = info.annotation, - .nvtx_range = info.nvtx_range, - .start_time_ns = graph_trace->start, - .end_time_ns = graph_trace->end, - .device_id = graph_trace->deviceId, - .correlation_id = graph_trace->correlationId, - .context_id = graph_trace->contextId, - .stream_id = graph_trace->streamId, - .graph_id = graph_trace->graphId, + /* .type = */ CuptiTracerEventType::CudaGraph, + /* .source = */ CuptiTracerEventSource::Activity, + /* .name = */ absl::StrCat("CudaGraphExec:", graph_trace->graphId), + /* .annotation = */ info.annotation, + /* .nvtx_range = */ info.nvtx_range, + /* .start_time_ns = */ graph_trace->start, + /* .end_time_ns = */ graph_trace->end, + /* .device_id = */ graph_trace->deviceId, + /* .correlation_id = */ graph_trace->correlationId, + /* .context_id = */ graph_trace->contextId, + /* .stream_id = */ graph_trace->streamId, + /* .graph_id = */ graph_trace->graphId, }); } diff --git a/xla/backends/profiler/gpu/cupti_buffer_events.h b/xla/backends/profiler/gpu/cupti_buffer_events.h index ac708ed94faeda..f58dda54e623c1 100644 --- a/xla/backends/profiler/gpu/cupti_buffer_events.h +++ b/xla/backends/profiler/gpu/cupti_buffer_events.h @@ -56,7 +56,7 @@ struct MemcpyDetails { int8_t dst_mem_kind; // ID of the hardware channel on which this operation ran. - uint32_t channel_id = -1; + uint32_t channel_id = static_cast(-1); // CUpti_ChannelType of the channel above. int8_t channel_type = 0; // CUPTI_CHANNEL_TYPE_INVALID }; diff --git a/xla/hlo/evaluator/hlo_evaluator.cc b/xla/hlo/evaluator/hlo_evaluator.cc index 9b51dca7721011..761006071dd1f7 100644 --- a/xla/hlo/evaluator/hlo_evaluator.cc +++ b/xla/hlo/evaluator/hlo_evaluator.cc @@ -535,7 +535,9 @@ std::optional EvaluateWhileLoopParamInitValue( namespace internal { +#if !defined(_MSC_VER) constexpr absl::string_view kEvalErrorDetailUrl = "EvalErrorDetailUrl"; +#endif std::optional ParseEvalErrorDetail(const absl::Status& error) { auto error_detail = error.GetPayload(kEvalErrorDetailUrl); diff --git a/xla/hlo/evaluator/hlo_evaluator.h b/xla/hlo/evaluator/hlo_evaluator.h index 2f91c39c857c9c..0eab57a0d68de1 100644 --- a/xla/hlo/evaluator/hlo_evaluator.h +++ b/xla/hlo/evaluator/hlo_evaluator.h @@ -530,7 +530,11 @@ enum class EvalErrorDetail : uint32_t { kDynamicValueDependence = 0, }; +#if defined(_MSC_VER) +extern const absl::string_view kEvalErrorDetailUrl = "EvalErrorDetailUrl"; +#else extern const absl::string_view kEvalErrorDetailUrl; +#endif std::optional ParseEvalErrorDetail(const absl::Status& error); diff --git a/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc b/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc index e1ba7c832f314d..54b8dbb6514350 100644 --- a/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc +++ b/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc @@ -2129,7 +2129,7 @@ PJRT_Error* PJRT_Layouts_MemoryLayout_Serialize( PJRT_Layouts_MemoryLayout_Serialize_Args_STRUCT_SIZE, args->struct_size)); PJRT_Layouts_SerializedLayout* s_layout = new PJRT_Layouts_SerializedLayout{ - .serialized = args->layout->layout->Serialize()}; + /* .serialized = */ args->layout->layout->Serialize()}; args->serialized_layout = s_layout; args->serialized_bytes = s_layout->serialized.data(); args->serialized_bytes_size = s_layout->serialized.size(); diff --git a/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc b/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc index 22de6c126af4ab..ea9541ce8a03b1 100644 --- a/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc +++ b/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc @@ -199,13 +199,15 @@ StreamExecutorGpuCompiler::Compile(CompileOptions options, #endif } -STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(pjrt_register_se_gpu_compiler, { - PjRtRegisterCompiler( #if TENSORFLOW_USE_ROCM - RocmName(), +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(pjrt_register_se_gpu_compiler, { + PjRtRegisterCompiler(RocmName(), + std::make_unique()); +}); #else - CudaName(), -#endif - std::make_unique()); +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(pjrt_register_se_gpu_compiler, { + PjRtRegisterCompiler(CudaName(), + std::make_unique()); }); +#endif } // namespace xla diff --git a/xla/service/cpu/runtime/conv_impl.h b/xla/service/cpu/runtime/conv_impl.h index c6b9747bc0ed51..b97bc85a4edc73 100644 --- a/xla/service/cpu/runtime/conv_impl.h +++ b/xla/service/cpu/runtime/conv_impl.h @@ -41,7 +41,7 @@ void EigenConv2DImpl( Eigen::Index padding_y_after, Eigen::Index lhs_x_dilation, Eigen::Index lhs_y_dilation, Eigen::Index rhs_x_dilation, Eigen::Index rhs_y_dilation, Eigen::Index feature_group_count, - std::optional> done_callback = std::nullopt) { + std::optional> done_callback) { const Eigen::TensorMap, Eigen::Aligned> input(lhs, input_batch, input_x, input_y, input_channels); @@ -129,7 +129,7 @@ void EigenConv3DImpl( Eigen::Index lhs_z_dilation, Eigen::Index rhs_x_dilation, Eigen::Index rhs_y_dilation, Eigen::Index rhs_z_dilation, Eigen::Index feature_group_count, - std::optional> done_callback = std::nullopt) { + std::optional> done_callback) { using ConstTType = Eigen::TensorMap, Eigen::Aligned>; @@ -223,7 +223,7 @@ void EigenConv3DImpl( Eigen::Index padding_y_after, Eigen::Index lhs_x_dilation, \ Eigen::Index lhs_y_dilation, Eigen::Index rhs_x_dilation, \ Eigen::Index rhs_y_dilation, Eigen::Index feature_group_count, \ - std::optional> done_callback = std::nullopt) + std::optional> done_callback) CONV2D_EXTERN_TEMPLATE(Eigen::DefaultDevice, Eigen::half); CONV2D_EXTERN_TEMPLATE(Eigen::DefaultDevice, float); @@ -249,7 +249,7 @@ CONV2D_EXTERN_TEMPLATE(Eigen::ThreadPoolDevice, float); Eigen::Index lhs_z_dilation, Eigen::Index rhs_x_dilation, \ Eigen::Index rhs_y_dilation, Eigen::Index rhs_z_dilation, \ Eigen::Index feature_group_count, \ - std::optional> done_callback = std::nullopt) + std::optional> done_callback) CONV3D_EXTERN_TEMPLATE(Eigen::DefaultDevice, Eigen::half); CONV3D_EXTERN_TEMPLATE(Eigen::DefaultDevice, float); diff --git a/xla/service/cpu/runtime_conv2d.cc b/xla/service/cpu/runtime_conv2d.cc index 907f0f57346020..4bc0d03fe8099e 100644 --- a/xla/service/cpu/runtime_conv2d.cc +++ b/xla/service/cpu/runtime_conv2d.cc @@ -15,6 +15,8 @@ limitations under the License. #include "xla/service/cpu/runtime_conv2d.h" +#include + #define EIGEN_USE_THREADS #include "absl/base/dynamic_annotations.h" @@ -41,7 +43,7 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv2DF32( kernel_channels, kernel_filters, output_rows, output_cols, row_stride, col_stride, padding_top, padding_bottom, padding_left, padding_right, lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation, - feature_group_count); + feature_group_count, std::nullopt); } ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv2DF16( @@ -63,5 +65,5 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv2DF16( kernel_channels, kernel_filters, output_rows, output_cols, row_stride, col_stride, padding_top, padding_bottom, padding_left, padding_right, lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation, - feature_group_count); + feature_group_count, std::nullopt); } diff --git a/xla/service/cpu/runtime_conv3d.cc b/xla/service/cpu/runtime_conv3d.cc index ad86203609e1aa..7e83269e289fdd 100644 --- a/xla/service/cpu/runtime_conv3d.cc +++ b/xla/service/cpu/runtime_conv3d.cc @@ -15,6 +15,8 @@ limitations under the License. #include "xla/service/cpu/runtime_conv3d.h" +#include + #define EIGEN_USE_THREADS #include "absl/base/dynamic_annotations.h" @@ -44,7 +46,7 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv3DF32( y_stride, z_stride, padding_x_before, padding_x_after, padding_y_before, padding_y_after, padding_z_before, padding_z_after, lhs_x_dilation, lhs_y_dilation, lhs_z_dilation, rhs_x_dilation, rhs_y_dilation, - rhs_z_dilation, feature_group_count); + rhs_z_dilation, feature_group_count, std::nullopt); } ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv3DF16( @@ -69,5 +71,5 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv3DF16( y_stride, z_stride, padding_x_before, padding_x_after, padding_y_before, padding_y_after, padding_z_before, padding_z_after, lhs_x_dilation, lhs_y_dilation, lhs_z_dilation, rhs_x_dilation, rhs_y_dilation, - rhs_z_dilation, feature_group_count); + rhs_z_dilation, feature_group_count, std::nullopt); } diff --git a/xla/service/cpu/runtime_single_threaded_conv2d.cc b/xla/service/cpu/runtime_single_threaded_conv2d.cc index 999e53cc296025..a770681987400d 100644 --- a/xla/service/cpu/runtime_single_threaded_conv2d.cc +++ b/xla/service/cpu/runtime_single_threaded_conv2d.cc @@ -15,6 +15,8 @@ limitations under the License. #include "xla/service/cpu/runtime_single_threaded_conv2d.h" +#include + #include "absl/base/dynamic_annotations.h" #include "xla/service/cpu/runtime/conv_impl.h" @@ -35,7 +37,7 @@ __xla_cpu_runtime_EigenSingleThreadedConv2DF16( kernel_filters, output_rows, output_cols, row_stride, col_stride, padding_top, padding_bottom, padding_left, padding_right, lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation, - feature_group_count); + feature_group_count, std::nullopt); } ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void @@ -55,5 +57,5 @@ __xla_cpu_runtime_EigenSingleThreadedConv2DF32( kernel_filters, output_rows, output_cols, row_stride, col_stride, padding_top, padding_bottom, padding_left, padding_right, lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation, - feature_group_count); + feature_group_count, std::nullopt); } diff --git a/xla/service/cpu/runtime_single_threaded_conv3d.cc b/xla/service/cpu/runtime_single_threaded_conv3d.cc index 91dd6c87948712..08ff94d06e7e71 100644 --- a/xla/service/cpu/runtime_single_threaded_conv3d.cc +++ b/xla/service/cpu/runtime_single_threaded_conv3d.cc @@ -15,6 +15,8 @@ limitations under the License. #include "xla/service/cpu/runtime_single_threaded_conv3d.h" +#include + #include "absl/base/dynamic_annotations.h" #include "xla/service/cpu/runtime/conv_impl.h" @@ -38,7 +40,7 @@ __xla_cpu_runtime_EigenSingleThreadedConv3DF32( z_stride, padding_x_before, padding_x_after, padding_y_before, padding_y_after, padding_z_before, padding_z_after, lhs_x_dilation, lhs_y_dilation, lhs_z_dilation, rhs_x_dilation, rhs_y_dilation, - rhs_z_dilation, feature_group_count); + rhs_z_dilation, feature_group_count, std::nullopt); } ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void @@ -61,5 +63,5 @@ __xla_cpu_runtime_EigenSingleThreadedConv3DF16( z_stride, padding_x_before, padding_x_after, padding_y_before, padding_y_after, padding_z_before, padding_z_after, lhs_x_dilation, lhs_y_dilation, lhs_z_dilation, rhs_x_dilation, rhs_y_dilation, - rhs_z_dilation, feature_group_count); + rhs_z_dilation, feature_group_count, std::nullopt); } diff --git a/xla/service/gpu/fusions/mlir/computation_partitioner.cc b/xla/service/gpu/fusions/mlir/computation_partitioner.cc index c1cc0de31de574..53d8678e953074 100644 --- a/xla/service/gpu/fusions/mlir/computation_partitioner.cc +++ b/xla/service/gpu/fusions/mlir/computation_partitioner.cc @@ -300,12 +300,12 @@ PartitionedComputation::PartitionedComputation( absl::StrJoin(roots, "_", [](std::string* out, const auto* root) { absl::StrAppend(out, root->name()); }))); - subgraphs_.push_back( - Subgraph{.name = std::move(name), - .instructions = {instructions.begin(), instructions.end()}, - .roots = std::move(roots), - .index_ranges = std::move(ranges), - .root_indexing = std::move(root_indexing)}); + subgraphs_.push_back(Subgraph{ + /* .name = */ std::move(name), + /* .instructions = */ {instructions.begin(), instructions.end()}, + /* .roots = */ std::move(roots), + /* .index_ranges = */ std::move(ranges), + /* .root_indexing = */ std::move(root_indexing)}); } for (const auto& subgraph : subgraphs_) { diff --git a/xla/service/gpu/kernels/BUILD b/xla/service/gpu/kernels/BUILD index d4299916ba1b95..ee6192ac7ef4a2 100644 --- a/xla/service/gpu/kernels/BUILD +++ b/xla/service/gpu/kernels/BUILD @@ -9,6 +9,7 @@ load("//xla:xla.bzl", "xla_cc_binary") load("//xla/service/gpu:build_defs.bzl", "gpu_kernel_library") load("//xla/stream_executor:build_defs.bzl", "if_gpu_is_configured") load("//xla/tests:build_defs.bzl", "xla_test") +load("//xla/tsl:tsl.bzl", "if_windows") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -325,7 +326,10 @@ cc_library( cuda_library( name = "cutlass_gemm_adaptor", hdrs = if_cuda_is_configured(["cutlass_gemm_adaptor.cu.h"]), - copts = ["-Wno-unknown-attributes"], # __grid_constant__ is not supported by clang + copts = if_windows( + [], + ["-Wno-unknown-attributes"], + ), # __grid_constant__ is not supported by clang deps = if_cuda_is_configured([ ":cutlass_gemm", "@cutlass_archive//:cutlass", @@ -367,7 +371,13 @@ cc_library( cuda_library( name = "cutlass_gemm_kernel_bf16xbf16_to_bf16", srcs = if_cuda_is_configured(["cutlass_gemm_kernel_bf16xbf16_to_bf16.cu.cc"]), - copts = ["-Wno-unknown-attributes -mllvm -unroll-threshold=100000"], + copts = [ + "-mllvm", + "-unroll-threshold=100000", + ] + if_windows( + [], + ["-Wno-unknown-attributes"], + ), deps = if_cuda_is_configured([ ":cutlass_gemm_adaptor", "@cutlass_archive//:cutlass", @@ -378,7 +388,13 @@ cuda_library( cuda_library( name = "cutlass_gemm_kernel_bf16xbf16_to_bf16_sm80", srcs = if_cuda_is_configured(["cutlass_gemm_kernel_bf16xbf16_to_bf16_sm80.cu.cc"]), - copts = ["-Wno-unknown-attributes -mllvm -unroll-threshold=100000"], + copts = [ + "-mllvm", + "-unroll-threshold=100000", + ] + if_windows( + [], + ["-Wno-unknown-attributes"], + ), deps = if_cuda_is_configured([ ":cutlass_gemm_adaptor", "@cutlass_archive//:cutlass", @@ -389,7 +405,16 @@ cuda_library( cuda_library( name = "cutlass_gemm_kernel_bf16xbf16_to_bf16_sm90", srcs = if_cuda_is_configured(["cutlass_gemm_kernel_bf16xbf16_to_bf16_sm90.cu.cc"]), - copts = ["-Wno-ctad-maybe-unsupported -Wno-unknown-attributes -mllvm -unroll-threshold=100000"], + copts = [ + "-mllvm", + "-unroll-threshold=100000", + ] + if_windows( + [], + [ + "-Wno-ctad-maybe-unsupported", + "-Wno-unknown-attributes", + ], + ), deps = if_cuda_is_configured([ ":cutlass_gemm_adaptor", ":cutlass_gemm_epilogue", @@ -401,7 +426,10 @@ cuda_library( cuda_library( name = "cutlass_gemm_kernel_f32xf32_to_f32", srcs = if_cuda_is_configured(["cutlass_gemm_kernel_f32xf32_to_f32.cu.cc"]), - copts = ["-Wno-unknown-attributes"], + copts = if_windows( + [], + ["-Wno-unknown-attributes"], + ), deps = if_cuda_is_configured([ ":cutlass_gemm_adaptor", "@cutlass_archive//:cutlass", diff --git a/xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h b/xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h index b8171d615dcfeb..53a6ac6dc6cadf 100644 --- a/xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h +++ b/xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h @@ -199,8 +199,9 @@ namespace adaptor_3x { template static std::optional ClusterDim() { typename Traits::Kernel::DispatchPolicy::ClusterShape cluster; - return Dim3{cute::get<0>(cluster), cute::get<1>(cluster), - cute::get<2>(cluster)}; + return Dim3{static_cast(cute::get<0>(cluster)), + static_cast(cute::get<1>(cluster)), + static_cast(cute::get<2>(cluster))}; } template diff --git a/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cc b/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cc index ae39cfbe293d1d..81b2dbd5df7f13 100644 --- a/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cc +++ b/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cc @@ -101,7 +101,11 @@ KernelArgsPacking ArgsPacking(int32_t m, int32_t n, int32_t k, // object constructed in the storage. For now we ignore it, and it's textbook // definition of UB, but for CUTLASS kernels we use today it's perfectly safe. struct Params { +#if defined(_MSC_VER) + alignas(64) std::byte storage[1024]; +#else alignas(128) std::byte storage[1024]; +#endif }; return [=](const se::Kernel& kernel, const se::KernelArgs& args) -> Packed { diff --git a/xla/service/gpu/model/gpu_collective_performance_model.cc b/xla/service/gpu/model/gpu_collective_performance_model.cc index 9459d0c33a7b6a..aad3343260c945 100644 --- a/xla/service/gpu/model/gpu_collective_performance_model.cc +++ b/xla/service/gpu/model/gpu_collective_performance_model.cc @@ -136,7 +136,7 @@ float GpuPerformanceWithCollectiveModel::GetNvlinkBw( } /*static*/ bool GpuPerformanceWithCollectiveModel::InitNvml() { -#if GOOGLE_CUDA +#if GOOGLE_CUDA && (defined(PLATFORM_POSIX) || defined(PLATFORM_GOOGLE)) void* libhandle = dlopen("libnvidia-ml.so.1", RTLD_NOW); CHECK(libhandle != nullptr) << "Failed to open libnvidia-ml.so.1"; diff --git a/xla/service/gpu/model/gpu_collective_performance_model.h b/xla/service/gpu/model/gpu_collective_performance_model.h index 49fe21a2c17919..e1bcff0b5023dd 100644 --- a/xla/service/gpu/model/gpu_collective_performance_model.h +++ b/xla/service/gpu/model/gpu_collective_performance_model.h @@ -26,7 +26,9 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #if GOOGLE_CUDA +#if defined(PLATFORM_POSIX) || defined(PLATFORM_GOOGLE) #include +#endif #include "third_party/gpus/cuda/nvml/include/nvml.h" // Below is a list of function pointers to be used diff --git a/xla/service/gpu/stream_executor_util.cc b/xla/service/gpu/stream_executor_util.cc index 8d4020f859c794..cd64405d0a8ca8 100644 --- a/xla/service/gpu/stream_executor_util.cc +++ b/xla/service/gpu/stream_executor_util.cc @@ -436,7 +436,7 @@ static void InitializeTypedBuffer(se::Stream* stream, // Use a large prime number to fragment the accesses. constexpr int host_buffer_size = 10069; - static std::vector* host_buffer = [] { + static std::vector* host_buffer = [&] { auto* ret = new std::vector(host_buffer_size); // Default-seeded random numbers. std::mt19937 gen; diff --git a/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc b/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc index f30815c6416927..f72cc4a885c7b3 100644 --- a/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc +++ b/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc @@ -246,10 +246,10 @@ SmallVector getOrderedSubDimsFromIotaTileAssignment( tileDimIndex--; } subDims.push_back(SubDimInfo{ - .tileDimIndex = tileDimIndex, - .tileSubDimIndex = subDim++, - .reshapeDimIndex = iota.transpose_perm()[transPermIndex], - .size = axisSize, + /* .tileDimIndex = */ tileDimIndex, + /* .tileSubDimIndex = */ subDim++, + /* .reshapeDimIndex = */ iota.transpose_perm()[transPermIndex], + /* .size = */ axisSize, }); accTileSize *= axisSize; accDeviceSize *= axisSize; @@ -296,8 +296,10 @@ AnalyzeTileAssignmentResult analyzeTileAssignment( for (SubDimInfo subDimInfo : subDims) { mesh.push_back(subDimInfo.size); } - return AnalyzeTileAssignmentResult{.subDims = std::move(subDims), - .localMesh = std::move(mesh)}; + return AnalyzeTileAssignmentResult{ + /* .subDims = */ std::move(subDims), + /* .localMesh = */ std::move(mesh), + }; } // Collect shardings with the attr name kXlaShardingAttr in the `moduleOp`. diff --git a/xla/stream_executor/cuda/cuda_diagnostics.cc b/xla/stream_executor/cuda/cuda_diagnostics.cc index 561ac0d401e2f2..2060fb2e296ead 100644 --- a/xla/stream_executor/cuda/cuda_diagnostics.cc +++ b/xla/stream_executor/cuda/cuda_diagnostics.cc @@ -108,6 +108,8 @@ namespace gpu { #if !defined(PLATFORM_WINDOWS) static const char *kDriverVersionPath = "/proc/driver/nvidia/version"; +#else +static const char *kDriverVersionPath = "NO NVIDIA DRIVER VERSION FILE"; #endif // -- class Diagnostician diff --git a/xla/stream_executor/cuda/cuda_dnn.cc b/xla/stream_executor/cuda/cuda_dnn.cc index bbc6a6dc2cca79..46a275720d2a91 100644 --- a/xla/stream_executor/cuda/cuda_dnn.cc +++ b/xla/stream_executor/cuda/cuda_dnn.cc @@ -1749,8 +1749,8 @@ absl::Status CheckAndFetchProjectionWeights( int64_t size = dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type); #endif // CUDNN_VERSION >= 8100 - dnn::RnnDescriptor::ParamsRegion region = { - reinterpret_cast(offset), size}; + dnn::RnnDescriptor::ParamsRegion region = {static_cast(offset), + size}; weights->push_back(region); } return absl::OkStatus(); @@ -1891,8 +1891,8 @@ absl::StatusOr CudnnRnnParamsDescriptor::Create( /*nbDims=*/&n_dims, /*filterDimA=*/dims)); int64_t size = dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type); - dnn::RnnDescriptor::ParamsRegion region = { - reinterpret_cast(offset), size}; + dnn::RnnDescriptor::ParamsRegion region = {static_cast(offset), + size}; (type == 0 ? weights : biases).push_back(region); } #endif // CUDNN_VERSION >= 8100 diff --git a/xla/tsl/framework/BUILD b/xla/tsl/framework/BUILD index 1e6ae7e269bf68..455a38e6785840 100644 --- a/xla/tsl/framework/BUILD +++ b/xla/tsl/framework/BUILD @@ -16,7 +16,7 @@ load( "@tsl//tsl/platform:rules_cc.bzl", "cc_library", ) -load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl:tsl.bzl", "if_windows", "internal_visibility") load("//xla/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") package( @@ -358,7 +358,10 @@ cc_library( hdrs = [ "cancellation.h", ], - copts = ["-Wno-thread-safety-precise"], + copts = if_windows( + [], + ["-Wno-thread-safety-precise"], + ), visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/memory", From feb403c2494850af7f73b758f2c031db045125d6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 31 Jul 2024 03:59:16 -0700 Subject: [PATCH 317/376] Automated Code Change PiperOrigin-RevId: 657948234 --- xla/service/gpu/model/indexing_analysis.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/xla/service/gpu/model/indexing_analysis.cc b/xla/service/gpu/model/indexing_analysis.cc index 89124182909aca..8f81cb4d3c33b6 100644 --- a/xla/service/gpu/model/indexing_analysis.cc +++ b/xla/service/gpu/model/indexing_analysis.cc @@ -346,6 +346,7 @@ HloInstructionIndexing ComputeOutputToInputDynamicUpdateSliceOpIndexing( // operand: (d0, ... d_{N-1}) -> (d0, ... d_{N-1}) std::vector identity; + identity.reserve(rank); for (int64_t dim = 0; dim < rank; ++dim) { identity.push_back(getAffineDimExpr(dim, mlir_context)); } From 526ae606c5d87d27bf5e10e7a0c28dd459253986 Mon Sep 17 00:00:00 2001 From: Tom Natan Date: Wed, 31 Jul 2024 05:30:05 -0700 Subject: [PATCH 318/376] update Shardy commit hash PiperOrigin-RevId: 657970420 --- third_party/shardy/temporary.patch | 15 --------------- third_party/shardy/workspace.bzl | 4 ++-- 2 files changed, 2 insertions(+), 17 deletions(-) diff --git a/third_party/shardy/temporary.patch b/third_party/shardy/temporary.patch index 9bf90881410570..e69de29bb2d1d6 100644 --- a/third_party/shardy/temporary.patch +++ b/third_party/shardy/temporary.patch @@ -1,15 +0,0 @@ -diff --git i/third_party/llvm/workspace.bzl w/third_party/llvm/workspace.bzl -index 76a13a4..9345d8d 100644 ---- i/third_party/llvm/workspace.bzl -+++ w/third_party/llvm/workspace.bzl -@@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") - - def repo(name): - """Imports LLVM.""" -- LLVM_COMMIT = "51681409aeb081c8dfe241e0d8e8c71f8bf0a4f4" -- LLVM_SHA256 = "347cc44fc5bba17b2a6ac26a253803434790a2996b77e8b6fbbeee9b8a367ec8" -+ LLVM_COMMIT = "d92a484e6f5c9063d82ca79405bb3557d88ad575" -+ LLVM_SHA256 = "0e6cce920f7344248ed747443fc16c316faf398e33f6a7f9f11f41ede861f824" - - tf_http_archive( - name = name, diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index 3f82df3eee7669..200ac3f5fbd5a3 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "c87ce5b404305927c7a169b305ba0dc1c304e4ce" - SHARDY_SHA256 = "2fa411cfb31f351f2cdad997db0ccb8f9898bad3421f2a78889703bb75bd054c" + SHARDY_COMMIT = "df54e37427b0007e6527b62616ed1f66a68dda4a" + SHARDY_SHA256 = "2ebf03fd73c4578e721c539ad05b33d5fbfae6838abbb58b944e12f1eafbd9b2" tf_http_archive( name = "shardy", From 532d616a2b96373ac24585459916bc2b56136396 Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Wed, 31 Jul 2024 05:51:02 -0700 Subject: [PATCH 319/376] [XLA:GPU] Move replaceSparseMetaEncoding from Triton to OpenXLA. PiperOrigin-RevId: 657975453 --- .../triton/xla_extensions/sparse_dot.patch | 46 ------------------ .../triton/compilation_pipeline_cuda.cc | 1 + .../gpu/fusions/triton/sparse_extensions.cc | 48 +++++++++++++++++++ .../gpu/fusions/triton/sparse_extensions.h | 1 + ...r => sparse_remove_layout_conversion.mlir} | 2 +- 5 files changed, 51 insertions(+), 47 deletions(-) rename xla/service/gpu/tests/{sparse_ttg_reduce_data_duplication.mlir => sparse_remove_layout_conversion.mlir} (90%) diff --git a/third_party/triton/xla_extensions/sparse_dot.patch b/third_party/triton/xla_extensions/sparse_dot.patch index 21ed97b5afb822..a1c011dbb8beb5 100644 --- a/third_party/triton/xla_extensions/sparse_dot.patch +++ b/third_party/triton/xla_extensions/sparse_dot.patch @@ -344,52 +344,6 @@ index d74e0a224..4e45f7c4c 100644 if (auto dotEnc = dyn_cast( dot.getResult().getType().getEncoding())) { auto loadTy = cast(op->getResultTypes()[0]); -diff --git a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp -index 8c1f18e45..c39110d12 100644 ---- a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp -+++ b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp -@@ -38,6 +38,10 @@ public: - auto srcEncoding = srcType.getEncoding(); - if (isa(srcEncoding)) - return; -+ if (isa(dstType.getEncoding())) { -+ replaceSparseMetaEncoding(cvtOp); -+ return; -+ } - auto dstDotOp = - dyn_cast(dstType.getEncoding()); - if (!dstDotOp) -@@ -86,6 +90,30 @@ public: - cvtOp.erase(); - }); - } -+ -+ private: -+ void replaceSparseMetaEncoding(triton::gpu::ConvertLayoutOp cvtOp) { -+ auto srcType = cast(cvtOp.getOperand().getType()); -+ auto srcEncoding = srcType.getEncoding(); -+ auto sharedLayout = triton::gpu::SharedEncodingAttr::get( -+ cvtOp.getContext(), 8, 1, 1, triton::gpu::getOrder(srcEncoding), -+ triton::gpu::getCTALayout(srcEncoding)); -+ -+ auto dstType = cast(cvtOp.getType()); -+ auto sharedMemorySpace = -+ triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext()); -+ auto tmpType = triton::MemDescType::get( -+ dstType.getShape(), dstType.getElementType(), sharedLayout, -+ sharedMemorySpace); -+ -+ OpBuilder builder(cvtOp); -+ auto tmp = builder.create( -+ cvtOp.getLoc(), tmpType, cvtOp.getSrc()); -+ auto newConvert = builder.create( -+ cvtOp.getLoc(), dstType, tmp); -+ cvtOp.replaceAllUsesWith(newConvert.getResult()); -+ cvtOp.erase(); -+ } - }; - - } // namespace gpu diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp index fb0e7f6fd..37795c20c 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp diff --git a/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc b/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc index eeac6366bb2c75..471a91c5999e9a 100644 --- a/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc +++ b/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc @@ -93,6 +93,7 @@ absl::Status CreateTritonPipeline( pm.addPass( mt::gpu::createTritonGPUOptimizeDotOperands({ccCuda.IsAtLeastAmpere()})); pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); + pm.addPass(CreateSparseRemoveLayoutConversionPass()); pm.addPass(mt::gpu::createTritonGPUReduceDataDuplication()); pm.addPass(mt::gpu::createTritonGPUReorderInstructions()); pm.addPass(mlir::createCSEPass()); diff --git a/xla/service/gpu/fusions/triton/sparse_extensions.cc b/xla/service/gpu/fusions/triton/sparse_extensions.cc index 8b2f1aba7ee14d..bfc9d7ab2add0a 100644 --- a/xla/service/gpu/fusions/triton/sparse_extensions.cc +++ b/xla/service/gpu/fusions/triton/sparse_extensions.cc @@ -38,6 +38,7 @@ limitations under the License. #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Types.h" @@ -468,6 +469,48 @@ class SparseLocalLoadToLLVM } }; +class SparseRemoveLayoutConversionPass + : public PassWrapper> { + public: + SparseRemoveLayoutConversionPass() = default; + + StringRef getArgument() const override { + return "sparse-remove-layout-conversion"; + } + + void runOnOperation() override { + getOperation().walk([&](triton::gpu::ConvertLayoutOp op) { + ImplicitLocOpBuilder builder(op.getLoc(), op); + auto srcEncoding = + cast(op.getSrc().getType()).getEncoding(); + if (isa(srcEncoding)) { + return; + } + auto dstType = cast(op.getType()); + if (!isa(dstType.getEncoding())) { + return; + } + + auto ctaLayout = triton::gpu::getCTALayout(srcEncoding); + auto sharedLayout = builder.getAttr( + 8, 1, 1, triton::gpu::getOrder(srcEncoding), ctaLayout); + auto sharedMemorySpace = + builder.getAttr(); + auto memType = + triton::MemDescType::get(dstType.getShape(), dstType.getElementType(), + sharedLayout, sharedMemorySpace); + Value alloc = + builder.create(memType, op.getSrc()); + Value convert = builder.create(dstType, alloc); + op.replaceAllUsesWith(convert); + op.erase(); + }); + } + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SparseRemoveLayoutConversionPass) +}; + class SparseLocalLoadToLLVMPass : public PassWrapper> { public: @@ -1010,6 +1053,10 @@ std::unique_ptr xla::gpu::CreateSparseBlockedToMMAPass() { return std::make_unique(); } +std::unique_ptr xla::gpu::CreateSparseRemoveLayoutConversionPass() { + return std::make_unique(); +} + std::unique_ptr xla::gpu::CreateSparseLocalLoadToLLVMPass() { return std::make_unique(); } @@ -1025,6 +1072,7 @@ std::unique_ptr xla::gpu::CreateSparseWGMMAOpToLLVMPass() { void xla::gpu::RegisterSparsePasses() { registerPass([] { return std::make_unique(); }); registerPass(CreateSparseBlockedToMMAPass); + registerPass(CreateSparseRemoveLayoutConversionPass); registerPass(CreateSparseLocalLoadToLLVMPass); registerPass(CreateSparseDotOpToLLVMPass); registerPass(CreateSparseWGMMAOpToLLVMPass); diff --git a/xla/service/gpu/fusions/triton/sparse_extensions.h b/xla/service/gpu/fusions/triton/sparse_extensions.h index 5d48a4353ae9d6..988a63dfdb00ac 100644 --- a/xla/service/gpu/fusions/triton/sparse_extensions.h +++ b/xla/service/gpu/fusions/triton/sparse_extensions.h @@ -26,6 +26,7 @@ namespace xla::gpu { std::unique_ptr CreateAddSparseDotEncodingPass( int32_t num_warps, int32_t threads_per_warp, int32_t num_ctas); std::unique_ptr CreateSparseBlockedToMMAPass(); +std::unique_ptr CreateSparseRemoveLayoutConversionPass(); std::unique_ptr CreateSparseLocalLoadToLLVMPass(); std::unique_ptr CreateSparseDotOpToLLVMPass(); std::unique_ptr CreateSparseWGMMAOpToLLVMPass(); diff --git a/xla/service/gpu/tests/sparse_ttg_reduce_data_duplication.mlir b/xla/service/gpu/tests/sparse_remove_layout_conversion.mlir similarity index 90% rename from xla/service/gpu/tests/sparse_ttg_reduce_data_duplication.mlir rename to xla/service/gpu/tests/sparse_remove_layout_conversion.mlir index 5604a1ac5baf46..7db3874eef4047 100644 --- a/xla/service/gpu/tests/sparse_ttg_reduce_data_duplication.mlir +++ b/xla/service/gpu/tests/sparse_remove_layout_conversion.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file -tritongpu-reduce-data-duplication | FileCheck %s +// RUN: xla-opt %s --sparse-remove-layout-conversion | FileCheck %s #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> From 1eb49d036730d84d72661b389e8fb98d5f416cae Mon Sep 17 00:00:00 2001 From: Greg Olechwierowicz Date: Wed, 31 Jul 2024 06:13:36 -0700 Subject: [PATCH 320/376] [XLA:GPU] Make "DumpingWorks" test smaller. PiperOrigin-RevId: 657981063 --- xla/service/gpu/gemm_fusion_autotuner_test.cc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/xla/service/gpu/gemm_fusion_autotuner_test.cc b/xla/service/gpu/gemm_fusion_autotuner_test.cc index 32c2c96d595a44..b0d8ba6a88691b 100644 --- a/xla/service/gpu/gemm_fusion_autotuner_test.cc +++ b/xla/service/gpu/gemm_fusion_autotuner_test.cc @@ -654,18 +654,18 @@ TEST_F(GemmFusionAutotunerDumpTest, DumpingWorks) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(R"( fusion1 { - p0 = f32[3333,3333] parameter(0) - s = f32[3333,3333] sine(p0) - p1 = f32[3333,3333] parameter(1) - c = f32[3333,3333] cosine(p1) - ROOT dot = f32[3333,3333] dot(s, c), + p0 = f32[333,333] parameter(0) + s = f32[333,333] sine(p0) + p1 = f32[333,333] parameter(1) + c = f32[333,333] cosine(p1) + ROOT dot = f32[333,333] dot(s, c), lhs_contracting_dims={1}, rhs_contracting_dims={0} } ENTRY e { - p0 = f32[3333,3333] parameter(0) - p1 = f32[3333,3333] parameter(1) - ROOT rr = f32[3333,3333] fusion(p0, p1), kind=kCustom, calls=fusion1, + p0 = f32[333,333] parameter(0) + p1 = f32[333,333] parameter(1) + ROOT rr = f32[333,333] fusion(p0, p1), kind=kCustom, calls=fusion1, backend_config={"fusion_backend_config": {kind: "__triton_gemm"}} })", config)); From dff0898d2b101a60b611be6ba9302d30dffc974c Mon Sep 17 00:00:00 2001 From: Dmitri Gribenko Date: Wed, 31 Jul 2024 07:01:30 -0700 Subject: [PATCH 321/376] Integrate LLVM at llvm/llvm-project@42d641ef5cc4 Updates LLVM usage to match [42d641ef5cc4](https://github.com/llvm/llvm-project/commit/42d641ef5cc4) PiperOrigin-RevId: 657992691 --- third_party/llvm/workspace.bzl | 4 ++-- third_party/shardy/temporary.patch | 15 +++++++++++++++ third_party/tsl/third_party/llvm/workspace.bzl | 4 ++-- .../tools/mlir_interpreter/dialects/vector.cc | 2 +- 4 files changed, 20 insertions(+), 5 deletions(-) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 9345d8db8d67ef..6429d9bd82a98c 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "d92a484e6f5c9063d82ca79405bb3557d88ad575" - LLVM_SHA256 = "0e6cce920f7344248ed747443fc16c316faf398e33f6a7f9f11f41ede861f824" + LLVM_COMMIT = "42d641ef5cc4bd82f98ef9959a593ca6db66d75d" + LLVM_SHA256 = "ec368e9c3b1e1c5eb646c21da65bb54a53060b417e61f2451f3917b35d743abd" tf_http_archive( name = name, diff --git a/third_party/shardy/temporary.patch b/third_party/shardy/temporary.patch index e69de29bb2d1d6..4d99610ad94bd8 100644 --- a/third_party/shardy/temporary.patch +++ b/third_party/shardy/temporary.patch @@ -0,0 +1,15 @@ +diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl +index 9345d8d..6429d9b 100644 +--- a/third_party/llvm/workspace.bzl ++++ b/third_party/llvm/workspace.bzl +@@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") + + def repo(name): + """Imports LLVM.""" +- LLVM_COMMIT = "d92a484e6f5c9063d82ca79405bb3557d88ad575" +- LLVM_SHA256 = "0e6cce920f7344248ed747443fc16c316faf398e33f6a7f9f11f41ede861f824" ++ LLVM_COMMIT = "42d641ef5cc4bd82f98ef9959a593ca6db66d75d" ++ LLVM_SHA256 = "ec368e9c3b1e1c5eb646c21da65bb54a53060b417e61f2451f3917b35d743abd" + + tf_http_archive( + name = name, diff --git a/third_party/tsl/third_party/llvm/workspace.bzl b/third_party/tsl/third_party/llvm/workspace.bzl index 9345d8db8d67ef..6429d9bd82a98c 100644 --- a/third_party/tsl/third_party/llvm/workspace.bzl +++ b/third_party/tsl/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "d92a484e6f5c9063d82ca79405bb3557d88ad575" - LLVM_SHA256 = "0e6cce920f7344248ed747443fc16c316faf398e33f6a7f9f11f41ede861f824" + LLVM_COMMIT = "42d641ef5cc4bd82f98ef9959a593ca6db66d75d" + LLVM_SHA256 = "ec368e9c3b1e1c5eb646c21da65bb54a53060b417e61f2451f3917b35d743abd" tf_http_archive( name = name, diff --git a/xla/mlir/tools/mlir_interpreter/dialects/vector.cc b/xla/mlir/tools/mlir_interpreter/dialects/vector.cc index 7aaaf5af97215e..b0223f3e6ed532 100644 --- a/xla/mlir/tools/mlir_interpreter/dialects/vector.cc +++ b/xla/mlir/tools/mlir_interpreter/dialects/vector.cc @@ -634,7 +634,7 @@ InterpreterValue Shuffle(InterpreterState& state, vector::ShuffleOp shuffle, auto& result_view = result.View(); result_view.is_vector = true; - auto mask = ExtractVector(shuffle.getMask()); + auto mask = shuffle.getMask(); bool is_zero_dim = v0.View().Rank() == 0; int64_t size0 = is_zero_dim ? 1 : v0.View().sizes[0]; for (auto [dst_index, src_index] : llvm::enumerate(mask)) { From b8021a5935d3b2629673f0367db54c36b6fdc6b4 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 31 Jul 2024 07:20:24 -0700 Subject: [PATCH 322/376] [xla:cpu] Optimize Thunk::OkExecuteEvent MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit name old cpu/op new cpu/op delta BM_SelectAndScatterF32/128/process_time 385µs ± 2% 378µs ± 4% -1.82% BM_SelectAndScatterF32/256/process_time 1.58ms ± 2% 1.56ms ± 2% -1.77% BM_SelectAndScatterF32/512/process_time 7.24ms ± 4% 7.07ms ± 6% -2.39% PiperOrigin-RevId: 657997714 --- xla/service/cpu/runtime/thunk.cc | 10 ++++++--- xla/service/cpu/runtime/thunk.h | 26 ++++++++++++----------- xla/service/cpu/runtime/thunk_executor.cc | 10 ++++----- xla/service/cpu/runtime/thunk_test.cc | 4 ++-- 4 files changed, 28 insertions(+), 22 deletions(-) diff --git a/xla/service/cpu/runtime/thunk.cc b/xla/service/cpu/runtime/thunk.cc index 5438e60b33d844..9588b02a61a4df 100644 --- a/xla/service/cpu/runtime/thunk.cc +++ b/xla/service/cpu/runtime/thunk.cc @@ -85,6 +85,10 @@ std::string_view Thunk::KindToString(Kind kind) { return "while"; } } +Thunk::Thunk(Kind kind, Info info) + : kind_(kind), + info_(std::move(info)), + ok_event_(OkExecuteEventSingleton()) {} absl::StatusOr Thunk::CollectiveExecuteParams::Create( @@ -150,13 +154,13 @@ Thunk::CustomCallExecuteParams::CustomCallExecuteParams( allocator(allocator), ffi_execution_context(ffi_execution_context) {} -const tsl::AsyncValueOwningRef* Thunk::OkEvent() { - static tsl::AsyncValueOwningRef* owner = [] { +tsl::AsyncValueRef Thunk::OkExecuteEventSingleton() { + static tsl::AsyncValueOwningRef* singleton = [] { auto* storage = new tsl::internal::AsyncValueStorage(); return new tsl::AsyncValueOwningRef( tsl::MakeAvailableAsyncValueRef(*storage)); }(); - return owner; + return singleton->AsRef(); } Thunk::ExecuteState::ExecuteState(int64_t num_tasks) diff --git a/xla/service/cpu/runtime/thunk.h b/xla/service/cpu/runtime/thunk.h index 5bf8cfb8baf01d..0e645f247776c5 100644 --- a/xla/service/cpu/runtime/thunk.h +++ b/xla/service/cpu/runtime/thunk.h @@ -110,7 +110,7 @@ class Thunk { using Task = std::function; using TaskRunner = absl::AnyInvocable; - Thunk(Kind kind, Info info) : kind_(kind), info_(std::move(info)) {} + Thunk(Kind kind, Info info); Thunk(const Thunk&) = delete; Thunk& operator=(const Thunk&) = delete; @@ -286,18 +286,20 @@ class Thunk { // An execute event that becomes ready when all tasks are completed. using ExecuteEvent = tsl::Chain; - // Returns non-reference-counted async value ref for thunks executed in the - // caller thread to avoid reference counting overhead. - static tsl::AsyncValueRef OkExecuteEvent() { - return OkEvent()->AsRef(); - } + // Returns non-reference-counted async value ref in constructed state. + // Returned async value is a per-process singleton stored in a storage with a + // static duration, and can be safely compared using pointer equality. + static tsl::AsyncValueRef OkExecuteEventSingleton(); + + // Returns `OkExecuteEventSingleton()` cached by this thunk instance. + tsl::AsyncValueRef OkExecuteEvent() const { return ok_event_; } - static bool IsOkExecuteEvent(tsl::AsyncValuePtr event) { - return event == OkEvent()->AsPtr(); + bool IsOkExecuteEvent(const tsl::AsyncValueRef& event) const { + return event == ok_event_; } - static bool IsOkExecuteEvent(const tsl::AsyncValueRef& event) { - return IsOkExecuteEvent(event.AsPtr()); + bool IsOkExecuteEvent(tsl::AsyncValuePtr event) const { + return event == ok_event_.AsPtr(); } // Thunk execution must be asynchronous and never block the caller thread, @@ -339,10 +341,10 @@ class Thunk { } private: - static const tsl::AsyncValueOwningRef* OkEvent(); - Kind kind_; Info info_; + + tsl::AsyncValueRef ok_event_; }; std::ostream& operator<<(std::ostream& os, Thunk::Kind kind); diff --git a/xla/service/cpu/runtime/thunk_executor.cc b/xla/service/cpu/runtime/thunk_executor.cc index 26c084e7e8c5e4..56b6d405336d71 100644 --- a/xla/service/cpu/runtime/thunk_executor.cc +++ b/xla/service/cpu/runtime/thunk_executor.cc @@ -144,7 +144,7 @@ tsl::AsyncValueRef ThunkExecutor::Execute( const Thunk::ExecuteParams& params) { // Short-circuit execution of trivial thunk sequences. if (ABSL_PREDICT_FALSE(thunk_sequence_.empty())) { - return Thunk::OkExecuteEvent(); + return Thunk::OkExecuteEventSingleton(); } if (ABSL_PREDICT_FALSE(thunk_sequence_.size() == 1)) { return thunk_sequence_[0]->Execute(params); @@ -181,7 +181,7 @@ ThunkExecutor::ExecuteSequential(const Thunk::ExecuteParams& params) { auto execute_event = thunk.Execute(params); // Fast path for thunks executed inline and returned OkExecuteEvent. - if (ABSL_PREDICT_TRUE(Thunk::IsOkExecuteEvent(execute_event))) { + if (ABSL_PREDICT_TRUE(thunk.IsOkExecuteEvent(execute_event))) { continue; } @@ -207,7 +207,7 @@ ThunkExecutor::ExecuteSequential(const Thunk::ExecuteParams& params) { // If we got to the end of the sequence it means that all thunks have // succeeded. - return Thunk::OkExecuteEvent(); + return Thunk::OkExecuteEventSingleton(); } void ThunkExecutor::ResumeExecuteSequential( @@ -218,7 +218,7 @@ void ThunkExecutor::ResumeExecuteSequential( auto execute_event = thunk.Execute(params); // Fast path for thunks executed inline and returned OkExecuteEvent. - if (ABSL_PREDICT_TRUE(Thunk::IsOkExecuteEvent(execute_event))) { + if (ABSL_PREDICT_TRUE(thunk.IsOkExecuteEvent(execute_event))) { continue; } @@ -281,7 +281,7 @@ void ThunkExecutor::Execute(ExecuteState* state, Thunk& thunk = *state->executor->thunk_sequence_[id]; tsl::AsyncValueRef execute_event = ABSL_PREDICT_FALSE(state->abort.load(std::memory_order_relaxed)) - ? Thunk::OkExecuteEvent() + ? Thunk::OkExecuteEventSingleton() : thunk.Execute(params); if (ABSL_PREDICT_TRUE(execute_event.IsAvailable())) { diff --git a/xla/service/cpu/runtime/thunk_test.cc b/xla/service/cpu/runtime/thunk_test.cc index 510d2c2f44025a..3b975750be6f1d 100644 --- a/xla/service/cpu/runtime/thunk_test.cc +++ b/xla/service/cpu/runtime/thunk_test.cc @@ -34,8 +34,8 @@ class ThunkExecuteStateTestHelper : public Thunk { } }; -TEST(ThunkTest, OkExecuteEvent) { - auto event = Thunk::OkExecuteEvent(); +TEST(ThunkTest, OkExecuteEventSingleton) { + auto event = Thunk::OkExecuteEventSingleton(); ASSERT_TRUE(event.IsConcrete()); } From dc35dc52fe2614334de21fdf765a4313f00bbbe1 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 31 Jul 2024 08:30:20 -0700 Subject: [PATCH 323/376] [xla:cpu] Use iterators for executing thunks sequentially MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This saves one register and a few instructions in the hot loop. name old time/op new time/op delta BM_SelectAndScatterF32/128/process_time 377µs ± 4% 371µs ± 2% -1.73% BM_SelectAndScatterF32/256/process_time 1.55ms ± 4% 1.52ms ± 2% -1.98% BM_SelectAndScatterF32/512/process_time 6.64ms ± 4% 6.58ms ± 4% -0.93% PiperOrigin-RevId: 658017389 --- xla/service/cpu/runtime/thunk_executor.cc | 31 ++++++++++++----------- xla/service/cpu/runtime/thunk_executor.h | 5 +++- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/xla/service/cpu/runtime/thunk_executor.cc b/xla/service/cpu/runtime/thunk_executor.cc index 56b6d405336d71..155edd9bdbca66 100644 --- a/xla/service/cpu/runtime/thunk_executor.cc +++ b/xla/service/cpu/runtime/thunk_executor.cc @@ -45,6 +45,7 @@ ThunkExecutor::ThunkExecutor(ThunkSequence thunk_sequence, const ThunkExecutor::Options& options) : thunk_sequence_(std::move(thunk_sequence)), options_(options), + num_thunks_(thunk_sequence_.size()), nodes_defs_(std::move(nodes_defs)), is_sequential_(true) { for (NodeId i = 0; i < nodes_defs_.size(); ++i) { @@ -143,10 +144,10 @@ ThunkExecutor::ExecuteState::ExecuteState(ThunkExecutor* executor, tsl::AsyncValueRef ThunkExecutor::Execute( const Thunk::ExecuteParams& params) { // Short-circuit execution of trivial thunk sequences. - if (ABSL_PREDICT_FALSE(thunk_sequence_.empty())) { + if (ABSL_PREDICT_FALSE(num_thunks_ == 0)) { return Thunk::OkExecuteEventSingleton(); } - if (ABSL_PREDICT_FALSE(thunk_sequence_.size() == 1)) { + if (ABSL_PREDICT_FALSE(num_thunks_ == 1)) { return thunk_sequence_[0]->Execute(params); } @@ -176,8 +177,8 @@ tsl::AsyncValueRef ThunkExecutor::Execute( tsl::AsyncValueRef ThunkExecutor::ExecuteSequential(const Thunk::ExecuteParams& params) { - for (int64_t i = 0; i < thunk_sequence_.size(); ++i) { - Thunk& thunk = *thunk_sequence_[i]; + for (auto it = thunk_sequence_.begin(); it != thunk_sequence_.end(); ++it) { + Thunk& thunk = **it; auto execute_event = thunk.Execute(params); // Fast path for thunks executed inline and returned OkExecuteEvent. @@ -189,11 +190,11 @@ ThunkExecutor::ExecuteSequential(const Thunk::ExecuteParams& params) { // resume sequential execution starting from the next thunk. if (ABSL_PREDICT_FALSE(!execute_event.IsAvailable())) { auto event = tsl::MakeConstructedAsyncValueRef(); - execute_event.AndThen([this, ¶ms, i, event](absl::Status status) { + execute_event.AndThen([this, ¶ms, it, event](absl::Status status) { if (ABSL_PREDICT_FALSE(!status.ok())) { event.SetError(std::move(status)); } else { - ResumeExecuteSequential(i + 1, params, std::move(event)); + ResumeExecuteSequential(it + 1, params, std::move(event)); } }); return event; @@ -211,10 +212,10 @@ ThunkExecutor::ExecuteSequential(const Thunk::ExecuteParams& params) { } void ThunkExecutor::ResumeExecuteSequential( - int64_t index, const Thunk::ExecuteParams& params, + ThunkIterator it, const Thunk::ExecuteParams& params, tsl::AsyncValueRef event) { - for (int64_t i = index; i < thunk_sequence_.size(); ++i) { - Thunk& thunk = *thunk_sequence_[i]; + for (; it != thunk_sequence_.end(); ++it) { + Thunk& thunk = **it; auto execute_event = thunk.Execute(params); // Fast path for thunks executed inline and returned OkExecuteEvent. @@ -226,11 +227,11 @@ void ThunkExecutor::ResumeExecuteSequential( // resume sequential execution starting from the next thunk. if (ABSL_PREDICT_FALSE(!execute_event.IsAvailable())) { execute_event.AndThen( - [this, ¶ms, i, event = std::move(event)](absl::Status status) { + [this, ¶ms, it, event = std::move(event)](absl::Status status) { if (ABSL_PREDICT_FALSE(!status.ok())) { event.SetError(std::move(status)); } else { - ResumeExecuteSequential(i + 1, params, std::move(event)); + ResumeExecuteSequential(it + 1, params, std::move(event)); } }); return; @@ -471,11 +472,11 @@ int64_t ThunkExecutor::TransitiveReduction() { std::string ThunkExecutor::ToString() const { std::string str = absl::StrFormat( - "ThunkExecutor: #thunks=%d #source_nodes=%d #sink_nodes=%d", - thunk_sequence_.size(), source_.size(), sink_.size()); + "ThunkExecutor: #thunks=%d #source_nodes=%d #sink_nodes=%d", num_thunks_, + source_.size(), sink_.size()); // Collect names of `in_edges`. - std::vector> in_edges(thunk_sequence_.size()); + std::vector> in_edges(num_thunks_); for (const auto& node_def : nodes_defs_) { for (NodeId in_edge : node_def.in_edges) { in_edges[node_def.id].push_back(thunk_sequence_[in_edge]->info().op_name); @@ -483,7 +484,7 @@ std::string ThunkExecutor::ToString() const { } // Print thunks with a list of their dependencies; - for (NodeId i = 0; i < thunk_sequence_.size(); ++i) { + for (NodeId i = 0; i < num_thunks_; ++i) { const Thunk& thunk = *thunk_sequence_[i]; bool is_source = absl::c_find(source_, i) != source_.end(); bool is_sink = absl::c_find(sink_, i) != sink_.end(); diff --git a/xla/service/cpu/runtime/thunk_executor.h b/xla/service/cpu/runtime/thunk_executor.h index 8965a7a51652a4..67a66c422bf5c6 100644 --- a/xla/service/cpu/runtime/thunk_executor.h +++ b/xla/service/cpu/runtime/thunk_executor.h @@ -144,7 +144,8 @@ class ThunkExecutor { const Thunk::ExecuteParams& params); // Resumes sequential thunk execution starting from the given index. - void ResumeExecuteSequential(int64_t index, + using ThunkIterator = typename ThunkSequence::iterator; + void ResumeExecuteSequential(ThunkIterator it, const Thunk::ExecuteParams& params, tsl::AsyncValueRef event); @@ -173,6 +174,8 @@ class ThunkExecutor { ThunkSequence thunk_sequence_; Options options_; + int64_t num_thunks_; + std::vector nodes_defs_; std::vector source_; From 77707a8d5517afce2686b5dffdcdd30706e1f661 Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Wed, 31 Jul 2024 08:32:36 -0700 Subject: [PATCH 324/376] [NFC] Fix a few common typos. PiperOrigin-RevId: 658017967 --- xla/debug_options_flags.cc | 8 ++++---- xla/xla.proto | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc index a9d316399f78fe..28ab8fa6e390c5 100644 --- a/xla/debug_options_flags.cc +++ b/xla/debug_options_flags.cc @@ -1107,7 +1107,7 @@ void MakeDebugOptionsFlags(std::vector* flag_list, collective_op_types_to_string( debug_options->xla_gpu_disable_async_collectives()), "This disables a certain set of async collectives and turn them into" - " synchornous ones. By default, this is empty which indicates enabling" + " synchronous ones. By default, this is empty which indicates enabling" " async execution for all collectives. A sample usage is: " " --xla_gpu_disable_async_collectives=ALLREDUCE,REDUCESCATTER")); flag_list->push_back(tsl::Flag( @@ -1204,7 +1204,7 @@ void MakeDebugOptionsFlags(std::vector* flag_list, bool_setter_for(&DebugOptions::set_xla_gpu_enable_cudnn_fmha), debug_options->xla_gpu_enable_cudnn_fmha(), "Use the cuDNN Fused Attention runtime fusion when possible. Note " - "that dropout support and the developement of this feature as a whole is " + "that dropout support and the development of this feature as a whole is " "in progress. Attention with dropout may cause results to diverge with " "and without this flag turned on.")); flag_list->push_back(tsl::Flag( @@ -1240,7 +1240,7 @@ void MakeDebugOptionsFlags(std::vector* flag_list, setter_for_legacy_command_buffer_custom_call_targets, "", "Comma-separated list of custom call targets with legacy " "registry API (non FFI API), whose targets supports lowering " - "to command buffer custom command, i.e, custom call target " + "to command buffer custom command, i.e., custom call target " "supports cuda-graph capturing for CUDA devices.")); flag_list->push_back(tsl::Flag( @@ -1765,7 +1765,7 @@ void MakeDebugOptionsFlags(std::vector* flag_list, int64_setter_for(&DebugOptions::set_xla_gpu_gemm_rewrite_size_threshold), debug_options->xla_gpu_gemm_rewrite_size_threshold(), "Threshold until which elemental dot emitter is preferred for GEMMs " - "(minumum combined number of elements of both matrices " + "(minimum combined number of elements of both matrices " "in non-batch dimensions to be considered for a rewrite).")); flag_list->push_back(tsl::Flag( "xla_gpu_use_memcpy_local_p2p", diff --git a/xla/xla.proto b/xla/xla.proto index 50993c6c411ea4..b43944c6ac0369 100644 --- a/xla/xla.proto +++ b/xla/xla.proto @@ -749,7 +749,7 @@ message DebugOptions { // are counted. reserved 282; // was xla_gpu_skip_mlir_kernels - // Threshold to rewrite matmul to cuBLAS or Triton (minumum combined number of + // Threshold to rewrite matmul to cuBLAS or Triton (minimum combined number of // elements of both matrices in non-batch dimensions to be considered for a // rewrite). int64 xla_gpu_gemm_rewrite_size_threshold = 283; @@ -835,7 +835,7 @@ message DebugOptions { // Custom call targets with legacy registry API (non FFI API), // that support recording to command buffer custom command, - // i.e, custom call target supports cuda-graph capturing for CUDA devices. + // i.e., custom call target supports cuda-graph capturing for CUDA devices. // This flag is read if CUSTOM_CALL command type is recorded into // command buffer. repeated string legacy_command_buffer_custom_call_targets = 314; From 77b91ce38136b593d85cf3be97e44bfdc6d2fd6d Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Wed, 31 Jul 2024 08:46:57 -0700 Subject: [PATCH 325/376] Move UnloadKernel from StreamExecutor to GpuExecutor, as it's not used anywhere else. PiperOrigin-RevId: 658022450 --- xla/stream_executor/gpu/gpu_executor.h | 4 ++-- xla/stream_executor/mock_stream_executor.h | 1 - xla/stream_executor/stream_executor.h | 3 --- 3 files changed, 2 insertions(+), 6 deletions(-) diff --git a/xla/stream_executor/gpu/gpu_executor.h b/xla/stream_executor/gpu/gpu_executor.h index e7926f01e0349f..7e7b834ebc154f 100644 --- a/xla/stream_executor/gpu/gpu_executor.h +++ b/xla/stream_executor/gpu/gpu_executor.h @@ -128,8 +128,8 @@ class GpuExecutor : public StreamExecutorCommon { absl::StatusOr> LoadKernel( const MultiKernelLoaderSpec& spec) override; - // (supported on CUDA only) - void UnloadKernel(const Kernel* kernel) override; + // Releases any state associated with the previously loaded kernel. + void UnloadKernel(const Kernel* kernel); absl::Status LoadModule(const MultiModuleLoaderSpec& spec, ModuleHandle* module_handle) override; bool UnloadModule(ModuleHandle module_handle) override; diff --git a/xla/stream_executor/mock_stream_executor.h b/xla/stream_executor/mock_stream_executor.h index 2655d48833b0ec..d0d696cfd1147a 100644 --- a/xla/stream_executor/mock_stream_executor.h +++ b/xla/stream_executor/mock_stream_executor.h @@ -78,7 +78,6 @@ class MockStreamExecutor : public StreamExecutor { const BlockDim& block_dims, const ClusterDim& cluster_dims, const Kernel& k, const KernelArgs& args), (override)); - MOCK_METHOD(void, UnloadKernel, (const Kernel* kernel), (override)); MOCK_METHOD(DeviceMemoryBase, Allocate, (uint64_t size, int64_t memory_space), (override)); MOCK_METHOD(void, Deallocate, (DeviceMemoryBase * mem), (override)); diff --git a/xla/stream_executor/stream_executor.h b/xla/stream_executor/stream_executor.h index d4f264a3de961d..b4d83efa8ce2b8 100644 --- a/xla/stream_executor/stream_executor.h +++ b/xla/stream_executor/stream_executor.h @@ -158,9 +158,6 @@ class StreamExecutor { return absl::UnimplementedError("Not Implemented"); } - // Releases any state associated with the previously loaded kernel. - virtual void UnloadKernel(const Kernel* kernel) {} - // Synchronously allocates size bytes on the underlying platform and returns // a DeviceMemoryBase representing that allocation. In the case of failure, // nullptr is returned. From 930596cac514b86f1c6cd1c46434293117bb3ec5 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Wed, 31 Jul 2024 09:32:25 -0700 Subject: [PATCH 326/376] Replace GpuKernel::AsGpuFunctionHandle and ::gpu_function_ptr with Google-style accessor methods. PiperOrigin-RevId: 658036473 --- xla/stream_executor/cuda/cuda_executor.cc | 16 +++++++------- xla/stream_executor/gpu/gpu_command_buffer.cc | 4 ++-- xla/stream_executor/gpu/gpu_kernel.h | 21 ++++++++----------- xla/stream_executor/rocm/rocm_executor.cc | 19 +++++++++-------- 4 files changed, 29 insertions(+), 31 deletions(-) diff --git a/xla/stream_executor/cuda/cuda_executor.cc b/xla/stream_executor/cuda/cuda_executor.cc index a8118b27b83c98..3c268909b461dd 100644 --- a/xla/stream_executor/cuda/cuda_executor.cc +++ b/xla/stream_executor/cuda/cuda_executor.cc @@ -273,7 +273,7 @@ absl::StatusOr> GpuExecutor::LoadKernel( TF_ASSIGN_OR_RETURN( GpuFunctionHandle function, GpuRuntime::GetFuncBySymbol(spec.in_process_symbol().symbol())); - *cuda_kernel->gpu_function_ptr() = function; + cuda_kernel->set_gpu_function(function); } else { return absl::InternalError("No method of loading CUDA kernel provided"); @@ -283,14 +283,14 @@ absl::StatusOr> GpuExecutor::LoadKernel( // from a module, as CUDA runtime did that automatically for us. if (!spec.has_in_process_symbol()) { VLOG(2) << "getting function " << *kernel_name << " from module " << module; - TF_RETURN_IF_ERROR( - GpuDriver::GetModuleFunction(context_, module, kernel_name->c_str(), - cuda_kernel->gpu_function_ptr())); + GpuFunctionHandle function; + TF_RETURN_IF_ERROR(GpuDriver::GetModuleFunction( + context_, module, kernel_name->c_str(), &function)); + cuda_kernel->set_gpu_function(function); } // Update CUDA kernel properties after it was loaded in the CUDA context. cuda_kernel->set_name(*kernel_name); - cuda_kernel->set_gpu_context(context_); // We have to trust the kernel loader spec arity because there doesn't appear // to be a way to reflect on the number of expected arguments w/the CUDA API. @@ -482,12 +482,12 @@ absl::Status GpuExecutor::GetKernelMetadata(GpuKernel* cuda_kernel, KernelMetadata* kernel_metadata) { int value; TF_RETURN_IF_ERROR(GpuDriver::FuncGetAttribute( - CU_FUNC_ATTRIBUTE_NUM_REGS, *cuda_kernel->gpu_function_ptr(), &value)); + CU_FUNC_ATTRIBUTE_NUM_REGS, cuda_kernel->gpu_function(), &value)); kernel_metadata->set_registers_per_thread(value); TF_RETURN_IF_ERROR( GpuDriver::FuncGetAttribute(CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, - *cuda_kernel->gpu_function_ptr(), &value)); + cuda_kernel->gpu_function(), &value)); kernel_metadata->set_shared_memory_bytes(value); return absl::OkStatus(); } @@ -512,7 +512,7 @@ absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, const Kernel& kernel, const KernelArgs& args) { CUstream custream = AsGpuStreamValue(stream); const GpuKernel* cuda_kernel = AsGpuKernel(&kernel); - CUfunction cufunc = cuda_kernel->AsGpuFunctionHandle(); + CUfunction cufunc = cuda_kernel->gpu_function(); if (cuda_kernel->cache_config() != KernelCacheConfig::kNoPreference) { TF_RETURN_IF_ERROR(GpuDriver::FuncSetCacheConfig( diff --git a/xla/stream_executor/gpu/gpu_command_buffer.cc b/xla/stream_executor/gpu/gpu_command_buffer.cc index 2fdc8118bcd787..d5d9ebd1d7290f 100644 --- a/xla/stream_executor/gpu/gpu_command_buffer.cc +++ b/xla/stream_executor/gpu/gpu_command_buffer.cc @@ -340,7 +340,7 @@ absl::StatusOr GpuCommandBuffer::CreateBarrierNode( TF_RETURN_IF_ERROR(GpuDriver::GraphAddKernelNode( &barrier_handle, graph_, dependencies, "noop", - AsGpuKernel(&**noop)->AsGpuFunctionHandle(), 1, 1, 1, 1, 1, 1, 0, + AsGpuKernel(&**noop)->gpu_function(), 1, 1, 1, 1, 1, 1, 0, /*kernel_params=*/nullptr, /*extra=*/nullptr)); #else TF_RETURN_IF_ERROR( @@ -524,7 +524,7 @@ absl::Status GpuCommandBuffer::LaunchWithPackedArgs( packed_args.number_of_arguments()); const GpuKernel* gpu_kernel = AsGpuKernel(&kernel); - GpuFunctionHandle gpu_func = gpu_kernel->AsGpuFunctionHandle(); + GpuFunctionHandle gpu_func = gpu_kernel->gpu_function(); void** kernel_params = const_cast(packed_args.argument_addresses().data()); diff --git a/xla/stream_executor/gpu/gpu_kernel.h b/xla/stream_executor/gpu/gpu_kernel.h index ea027f4dac22fb..d17b974fe44b7a 100644 --- a/xla/stream_executor/gpu/gpu_kernel.h +++ b/xla/stream_executor/gpu/gpu_kernel.h @@ -39,7 +39,9 @@ namespace stream_executor::gpu { class GpuKernel : public Kernel { public: - explicit GpuKernel(GpuExecutor* gpu_executor) : gpu_executor_(gpu_executor) {} + explicit GpuKernel(GpuExecutor* gpu_executor) + : gpu_executor_(gpu_executor), + gpu_context_(gpu_executor->gpu_context()) {} // Note that the function is unloaded when the module is unloaded, and the // module that the function is contained in is owned by the GpuExecutor. @@ -51,17 +53,6 @@ class GpuKernel : public Kernel { unsigned Arity() const override { return arity_; } void set_name(std::string name) { name_ = std::move(name); } - void set_gpu_context(GpuContext* gpu_context) { gpu_context_ = gpu_context; } - - // Returns the GpuFunctionHandle value for passing to the CUDA API. - GpuFunctionHandle AsGpuFunctionHandle() const { - DCHECK(gpu_function_ != nullptr); - return const_cast(gpu_function_); - } - - // Returns the slot that the GpuFunctionHandle is stored within for this - // object, for the CUDA API which wants to load into a GpuFunctionHandle*. - GpuFunctionHandle* gpu_function_ptr() { return &gpu_function_; } // Returns the current kernel cache configuration preference as a // GpuFuncCachePreference. @@ -70,6 +61,12 @@ class GpuKernel : public Kernel { absl::StatusOr GetMaxOccupiedBlocksPerCore( ThreadDim threads, size_t dynamic_shared_memory_bytes) const override; + // Simple accessor methods. + GpuFunctionHandle gpu_function() const { return gpu_function_; } + void set_gpu_function(GpuFunctionHandle gpu_function) { + gpu_function_ = gpu_function; + } + private: GpuExecutor* gpu_executor_ = nullptr; GpuContext* gpu_context_ = nullptr; // context where kernel is loaded diff --git a/xla/stream_executor/rocm/rocm_executor.cc b/xla/stream_executor/rocm/rocm_executor.cc index d879fbbbb0aae4..cb207eabbed6e1 100644 --- a/xla/stream_executor/rocm/rocm_executor.cc +++ b/xla/stream_executor/rocm/rocm_executor.cc @@ -284,10 +284,10 @@ absl::StatusOr> GpuExecutor::LoadKernel( TF_ASSIGN_OR_RETURN( GpuFunctionHandle function, GpuRuntime::GetFuncBySymbol(spec.in_process_symbol().symbol())); - *rocm_kernel->gpu_function_ptr() = function; + rocm_kernel->set_gpu_function(function); #else - *rocm_kernel->gpu_function_ptr() = - static_cast(spec.in_process_symbol().symbol()); + rocm_kernel->set_gpu_function( + static_cast(spec.in_process_symbol().symbol())); #endif // TF_ROCM_VERSION >= 60200 } else { @@ -298,9 +298,10 @@ absl::StatusOr> GpuExecutor::LoadKernel( // from a module, as ROCm runtime did that automatically for us. if (!spec.has_in_process_symbol()) { VLOG(2) << "getting function " << *kernel_name << " from module " << module; - TF_RETURN_IF_ERROR( - GpuDriver::GetModuleFunction(context_, module, kernel_name->c_str(), - rocm_kernel->gpu_function_ptr())); + GpuFunctionHandle function; + TF_RETURN_IF_ERROR(GpuDriver::GetModuleFunction( + context_, module, kernel_name->c_str(), &function)); + rocm_kernel->set_gpu_function(function); } // We have to trust the kernel loader spec arity because there doesn't appear @@ -322,12 +323,12 @@ absl::Status GpuExecutor::GetKernelMetadata(GpuKernel* rocm_kernel, KernelMetadata* kernel_metadata) { int value = 0; TF_RETURN_IF_ERROR(GpuDriver::FuncGetAttribute( - HIP_FUNC_ATTRIBUTE_NUM_REGS, *rocm_kernel->gpu_function_ptr(), &value)); + HIP_FUNC_ATTRIBUTE_NUM_REGS, rocm_kernel->gpu_function(), &value)); kernel_metadata->set_registers_per_thread(value); TF_RETURN_IF_ERROR( GpuDriver::FuncGetAttribute(HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, - *rocm_kernel->gpu_function_ptr(), &value)); + rocm_kernel->gpu_function(), &value)); kernel_metadata->set_shared_memory_bytes(value); return absl::OkStatus(); } @@ -337,7 +338,7 @@ absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, const Kernel& kernel, const KernelArgs& args) { GpuStreamHandle hipstream = AsGpuStreamValue(stream); const GpuKernel* rocm_kernel = AsGpuKernel(&kernel); - hipFunction_t hipfunc = rocm_kernel->AsGpuFunctionHandle(); + hipFunction_t hipfunc = rocm_kernel->gpu_function(); if (rocm_kernel->cache_config() != KernelCacheConfig::kNoPreference) { TF_RETURN_IF_ERROR(GpuDriver::FuncSetCacheConfig( From 84a4923b1d1d5cc29bb8f01820356cc20edd6d16 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Wed, 31 Jul 2024 09:45:13 -0700 Subject: [PATCH 327/376] Hide the SDY dialect right before MLIR->HLO conversion in the XLA pipeline. Since Shardy is inside the middle of the XLA pipeline, after converting down to HLO, we need to run the Shardy export pipeline to preserve the SDY ops and sharding attributes for when we come back from HLO to MLIR when Shardy propagation is run. PiperOrigin-RevId: 658040672 --- xla/pjrt/BUILD | 1 + xla/pjrt/mlir_to_hlo.cc | 2 ++ xla/python/BUILD | 6 ++++++ xla/python/py_client.cc | 22 ++++++++++++++++++++ xla/service/spmd/shardy/sdy_round_trip/BUILD | 2 +- 5 files changed, 32 insertions(+), 1 deletion(-) diff --git a/xla/pjrt/BUILD b/xla/pjrt/BUILD index 1b6716882b2527..1da8c8242e0fef 100644 --- a/xla/pjrt/BUILD +++ b/xla/pjrt/BUILD @@ -613,6 +613,7 @@ cc_library( "@llvm-project//mlir:ReconcileUnrealizedCasts", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", + "@shardy//shardy/dialect/sdy/ir:dialect", "@stablehlo//:chlo_ops", "@stablehlo//:register", "@stablehlo//:stablehlo_ops", diff --git a/xla/pjrt/mlir_to_hlo.cc b/xla/pjrt/mlir_to_hlo.cc index e97c1f18b391ae..58a6b48fdc23a3 100644 --- a/xla/pjrt/mlir_to_hlo.cc +++ b/xla/pjrt/mlir_to_hlo.cc @@ -50,6 +50,7 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/Passes.h" +#include "shardy/dialect/sdy/ir/utils.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/Register.h" #include "stablehlo/dialect/Serialization.h" @@ -126,6 +127,7 @@ absl::StatusOr> ParseMlirModuleString( registry.insert(); mlir::func::registerAllExtensions(registry); mlir::mhlo::registerAllMhloDialects(registry); + mlir::sdy::loadAllRequiredDialects(&context); mlir::stablehlo::registerAllDialects(registry); context.appendDialectRegistry(registry); diff --git a/xla/python/BUILD b/xla/python/BUILD index 93ecd1a9a7dd6b..755bb0032afae1 100644 --- a/xla/python/BUILD +++ b/xla/python/BUILD @@ -369,7 +369,9 @@ cc_library( "@com_google_absl//absl/types:variant", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", "@nanobind", + "@shardy//shardy/dialect/sdy/ir:dialect", "@local_config_python//:python_headers", # buildcleaner: keep "//xla:comparison_util", "//xla:literal", @@ -411,8 +413,12 @@ cc_library( "//xla/service:custom_call_status", "//xla/service:custom_call_target_registry", "//xla/service:platform_util", + "//xla/service/spmd/shardy:constants", + "//xla/service/spmd/shardy:utils", + "//xla/service/spmd/shardy/sdy_round_trip:pipelines", "//xla/tsl/concurrency:ref_count", "//xla/tsl/framework:allocator", + "//xla/tsl/framework/mlir:status_scoped_diagnostic_handler", "//xla/tsl/python/lib/core:numpy", "@tsl//tsl/platform:casts", "@tsl//tsl/platform:errors", diff --git a/xla/python/py_client.cc b/xla/python/py_client.cc index 716c59ba9a2c5d..0e36a346f67e39 100644 --- a/xla/python/py_client.cc +++ b/xla/python/py_client.cc @@ -36,9 +36,11 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/Support/Casting.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" +#include "mlir/Pass/PassManager.h" #include "nanobind/nanobind.h" #include "nanobind/stl/optional.h" // IWYU pragma: keep #include "nanobind/stl/pair.h" // IWYU pragma: keep @@ -86,9 +88,13 @@ limitations under the License. #include "xla/python/types.h" #include "xla/service/custom_call_target_registry.h" #include "xla/service/platform_util.h" // IWYU pragma: keep +#include "xla/service/spmd/shardy/constants.h" +#include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h" +#include "xla/service/spmd/shardy/utils.h" #include "xla/shape.h" #include "xla/status_macros.h" #include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/framework/mlir/status_scoped_diagnostic_handler.h" #include "xla/util.h" #include "tsl/platform/casts.h" #include "tsl/platform/errors.h" @@ -437,6 +443,22 @@ PyClient::CompileIfrtProgram( mlir::MLIRContext context; TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, ParseMlirModuleString(mlir_module, context)); + if (options.executable_build_options.use_shardy_partitioner()) { + mlir::PassManager pm(&context); + // Since Shardy is inside the middle of the XLA pipeline, after converting + // down to HLO, we need to run the Shardy export pipeline to preserve the + // SDY ops and sharding attributes for when we come back from HLO to MLIR + // when Shardy propagation is run. + xla::sdy::addSdyRoundTripExportPipeline(pm); + // TODO(bartchr): remove setting `kPythonIntegrationComplete` in follow-up + // now that both JAX and PartIR are integrated with Shardy. + xla::sdy::addFrontendAttribute(*module, + xla::sdy::kPythonIntegrationComplete, + mlir::StringAttr::get(&context, "t")); + TF_RETURN_IF_ERROR( + tsl::StatusScopedDiagnosticHandler(&context).consumeStatus( + pm.run(*module))); + } return CompileIfrtProgram( client, std::make_unique(module.get()), MakeIfrtCompileOptions(std::move(options), std::move(host_callbacks))); diff --git a/xla/service/spmd/shardy/sdy_round_trip/BUILD b/xla/service/spmd/shardy/sdy_round_trip/BUILD index ab41ada3a7e058..1bd4ee759831ac 100644 --- a/xla/service/spmd/shardy/sdy_round_trip/BUILD +++ b/xla/service/spmd/shardy/sdy_round_trip/BUILD @@ -11,7 +11,7 @@ package_group( packages = [ "//learning/deepmind/partir/compiler/shardonnay/...", "//third_party/openxla/shardy/tools/...", - "//xla/service/spmd/shardy/...", + "//xla/...", ], ) From cd059986af62f768e870322cfd9bca92d40392d7 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 31 Jul 2024 09:48:49 -0700 Subject: [PATCH 328/376] [xla:cpu] Optimize KernelThunk for small number of arguments and results MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit name old cpu/op new cpu/op delta BM_SelectAndScatterF32/128/process_time 373µs ± 2% 337µs ± 2% -9.74% BM_SelectAndScatterF32/256/process_time 1.54ms ± 3% 1.39ms ± 4% -10.04% BM_SelectAndScatterF32/512/process_time 7.08ms ± 7% 6.42ms ± 6% -9.29% PiperOrigin-RevId: 658041870 --- xla/service/cpu/runtime/kernel_thunk.cc | 204 +++++++++++++++++------- xla/service/cpu/runtime/kernel_thunk.h | 106 +++++++++--- 2 files changed, 228 insertions(+), 82 deletions(-) diff --git a/xla/service/cpu/runtime/kernel_thunk.cc b/xla/service/cpu/runtime/kernel_thunk.cc index 88847041c3b51f..50b42c71c7a5ef 100644 --- a/xla/service/cpu/runtime/kernel_thunk.cc +++ b/xla/service/cpu/runtime/kernel_thunk.cc @@ -15,19 +15,19 @@ limitations under the License. #include "xla/service/cpu/runtime/kernel_thunk.h" -#include - #define EIGEN_USE_THREADS #include +#include #include #include #include #include +#include #include +#include "absl/base/attributes.h" #include "absl/base/optimization.h" -#include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/numeric/bits.h" #include "absl/status/status.h" @@ -51,50 +51,109 @@ limitations under the License. #include "tsl/profiler/lib/traceme.h" namespace xla::cpu { +namespace internal { -absl::StatusOr> KernelThunk::Create( - Info info, absl::Span arguments_buffers, +// Checks that all buffers are aligned to the minimum alignment. We codegen +// with the assumption that all buffers are aligned, and if they are not, we +// will crash with a segmentation fault, or worse, produce incorrect results. +static absl::Status CheckBufferAlignment( + const Thunk::Info& info, uint64_t min_alignment, + absl::Span kernel_args) { + if (min_alignment == 0) return absl::OkStatus(); + + for (int64_t i = 0; i < kernel_args.size(); ++i) { + auto ptr = reinterpret_cast(kernel_args[i].data); + if (ABSL_PREDICT_FALSE((ptr & (min_alignment - 1)) != 0)) { + return Internal( + "Host kernel %s buffer argument #%d (%p) is not aligned to a " + "required minimum alignment of %d bytes", + info.op_name, i, kernel_args[i].data, min_alignment); + } + } + + return absl::OkStatus(); +} + +// VLOGs kernel arguments resolved from the buffer allocations. +static void VlogKernelArgs( + absl::Span arguments_buffers, absl::Span results_buffers, - std::string kernel_name, se::ThreadDim thread_dim, - std::optional min_alignment) { - if (min_alignment.has_value() && !absl::has_single_bit(*min_alignment)) { - return Internal("Host kernel %s minimum alignment %d is not a power of 2", - info.op_name, *min_alignment); + absl::Span kernel_args) { + for (int64_t i = 0; i < arguments_buffers.size(); ++i) { + VLOG(3) << absl::StreamFormat(" arg #%d: %s (%p)", i, + arguments_buffers[i].ToString(), + kernel_args[i].data); } + for (int64_t i = 0; i < results_buffers.size(); ++i) { + VLOG(3) << absl::StreamFormat( + " res #%d: %s (%p)", i, results_buffers[i].ToString(), + kernel_args[arguments_buffers.size() + i].data); + } +} - return absl::WrapUnique( - new KernelThunk(std::move(info), arguments_buffers, results_buffers, - std::move(kernel_name), thread_dim, min_alignment)); +// Returns kernel buffer uses for a given arguments and results buffers. +static Thunk::BufferUses KernelBufferUses( + absl::Span arguments_buffers, + absl::Span results_buffers) { + Thunk::BufferUses buffer_uses; + for (const BufferAllocation::Slice& buffer : arguments_buffers) { + buffer_uses.emplace_back(buffer, BufferUse::kRead); + } + for (const BufferAllocation::Slice& buffer : results_buffers) { + buffer_uses.emplace_back(buffer, BufferUse::kWrite); + } + return buffer_uses; } -KernelThunk::KernelThunk( +template +KernelThunk::KernelThunk( Info info, absl::Span arguments_buffers, absl::Span results_buffers, std::string kernel_name, se::ThreadDim thread_dim, std::optional min_alignment) : Thunk(Kind::kKernel, std::move(info)), - arguments_buffers_(arguments_buffers.begin(), arguments_buffers.end()), - results_buffers_(results_buffers.begin(), results_buffers.end()), num_kernel_args_(arguments_buffers.size() + results_buffers.size()), kernel_name_(std::move(kernel_name)), thread_dim_(thread_dim), min_alignment_(min_alignment), call_once_(thread_dim_ == se::ThreadDim()), kernel_ptr_(nullptr) { + // Resize storage for arguments and results buffers if it is dynamic. + if constexpr (IsDynamic(num_arguments)) { + arguments_buffers_.resize(arguments_buffers.size()); + } + if constexpr (IsDynamic(num_results)) { + results_buffers_.resize(results_buffers.size()); + } + + // Copy buffers from the arguments and results. + for (size_t i = 0; i < arguments_buffers.size(); ++i) { + arguments_buffers_[i] = arguments_buffers[i]; + } + for (size_t i = 0; i < results_buffers.size(); ++i) { + results_buffers_[i] = results_buffers[i]; + } + + // Resize storage for kernel arguments if it is dynamic. + if constexpr (IsDynamic(num_arguments) || IsDynamic(num_results)) { + kernel_args_.resize(num_kernel_args_); + } + // Initialize kernel arguments with null pointers and known buffer sizes. // We'll use them as a template to resolve buffer addresses at run time. - kernel_args_.reserve(num_kernel_args_); - for (const BufferAllocation::Slice& buffer : arguments_buffers_) { - kernel_args_.emplace_back( - SE_HOST_KernelArg{nullptr, static_cast(buffer.size())}); + for (size_t i = 0; i < arguments_buffers.size(); ++i) { + kernel_args_[i] = SE_HOST_KernelArg{ + nullptr, static_cast(arguments_buffers_[i].size())}; } - for (const BufferAllocation::Slice& buffer : results_buffers_) { - kernel_args_.emplace_back( - SE_HOST_KernelArg{nullptr, static_cast(buffer.size())}); + for (size_t i = 0; i < results_buffers.size(); ++i) { + kernel_args_[arguments_buffers_.size() + i] = SE_HOST_KernelArg{ + nullptr, static_cast(results_buffers_[i].size())}; } } -tsl::AsyncValueRef KernelThunk::Execute( +template +ABSL_ATTRIBUTE_ALWAYS_INLINE tsl::AsyncValueRef +KernelThunk::ExecuteInternal( const ExecuteParams& params) { tsl::profiler::TraceMe trace([&] { return TraceMeEncode(); }); @@ -104,7 +163,7 @@ tsl::AsyncValueRef KernelThunk::Execute( kernel_name_, arguments_buffers_.size(), results_buffers_.size(), thread_dim_.ToString()); - absl::InlinedVector kernel_args = kernel_args_; + KernelArgs kernel_args = kernel_args_; SE_HOST_KernelArg* kernel_args_ptr = kernel_args.data(); const BufferAllocations* allocations = params.buffer_allocations; @@ -130,12 +189,13 @@ tsl::AsyncValueRef KernelThunk::Execute( } if (ABSL_PREDICT_FALSE(VLOG_IS_ON(3))) { - VlogKernelArgs(kernel_args); + VlogKernelArgs(arguments_buffers_, results_buffers_, kernel_args); } // Сheck that all resolved buffers are properly aligned. if constexpr (ShouldCheckBufferSlices()) { - TF_RETURN_IF_ERROR(CheckBufferAlignment(kernel_args)); + TF_RETURN_IF_ERROR( + CheckBufferAlignment(info(), min_alignment_.value_or(0), kernel_args)); } // TODO(ezhulenev): Kernel ptr should be loaded as a part of Thunk @@ -173,45 +233,67 @@ tsl::AsyncValueRef KernelThunk::Execute( return OkExecuteEvent(); } -absl::Status KernelThunk::CheckBufferAlignment( - absl::Span kernel_args) { - if (min_alignment_.has_value()) { - for (int64_t i = 0; i < num_kernel_args_; ++i) { - auto ptr = reinterpret_cast(kernel_args[i].data); - if (ABSL_PREDICT_FALSE((ptr & (*min_alignment_ - 1)) != 0)) { - return Internal( - "Host kernel %s buffer argument #%d (%p) is not aligned to a " - "required minimum alignment of %d bytes", - info().op_name, i, kernel_args[i].data, *min_alignment_); - } - } - } - return absl::OkStatus(); +template +Thunk::BufferUses KernelThunk::buffer_uses() const { + return KernelBufferUses(arguments_buffers_, results_buffers_); } -void KernelThunk::VlogKernelArgs( - absl::Span kernel_args) { - for (int64_t i = 0; i < arguments_buffers_.size(); ++i) { - VLOG(3) << absl::StreamFormat(" arg #%d: %s (%p)", i, - arguments_buffers_[i].ToString(), - kernel_args[i].data); - } - for (int64_t i = 0; i < results_buffers_.size(); ++i) { - VLOG(3) << absl::StreamFormat( - " res #%d: %s (%p)", i, results_buffers_[i].ToString(), - kernel_args[arguments_buffers_.size() + i].data); - } +} // namespace internal + +tsl::AsyncValueRef KernelThunk::Execute( + const Thunk::ExecuteParams& params) { + return Base::ExecuteInternal(params); } -KernelThunk::BufferUses KernelThunk::buffer_uses() const { - BufferUses buffer_uses; - for (const BufferAllocation::Slice& buffer : arguments_buffers_) { - buffer_uses.emplace_back(buffer, BufferUse::kRead); - } - for (const BufferAllocation::Slice& buffer : results_buffers_) { - buffer_uses.emplace_back(buffer, BufferUse::kWrite); +template +tsl::AsyncValueRef +SmallKernelThunk::Execute( + const Thunk::ExecuteParams& params) { + return Base::ExecuteInternal(params); +} + +absl::StatusOr> KernelThunk::Create( + Thunk::Info info, + absl::Span arguments_buffers, + absl::Span results_buffers, + std::string kernel_name, se::ThreadDim thread_dim, + std::optional min_alignment) { + if (min_alignment.has_value() && !absl::has_single_bit(*min_alignment)) { + return Internal("Host kernel %s minimum alignment %d is not a power of 2", + info.op_name, *min_alignment); } - return buffer_uses; + + auto small_kernel_thunk = [&](auto num_arguments, auto num_results) { + return absl::WrapUnique( + new SmallKernelThunk( + std::move(info), arguments_buffers, results_buffers, + std::move(kernel_name), thread_dim, min_alignment)); + }; + + static constexpr auto _0 = std::integral_constant{}; + static constexpr auto _1 = std::integral_constant{}; + static constexpr auto _2 = std::integral_constant{}; + static constexpr auto _3 = std::integral_constant{}; + static constexpr auto _4 = std::integral_constant{}; + static constexpr auto _5 = std::integral_constant{}; + static constexpr auto _6 = std::integral_constant{}; + + std::pair params(arguments_buffers.size(), + results_buffers.size()); + + // Return SmallKernelThunk specializations for the most common cases. + if (params == std::make_pair(_0(), _1())) return small_kernel_thunk(_0, _1); + if (params == std::make_pair(_1(), _1())) return small_kernel_thunk(_1, _1); + if (params == std::make_pair(_2(), _1())) return small_kernel_thunk(_2, _1); + if (params == std::make_pair(_3(), _1())) return small_kernel_thunk(_3, _1); + if (params == std::make_pair(_4(), _1())) return small_kernel_thunk(_4, _1); + if (params == std::make_pair(_5(), _1())) return small_kernel_thunk(_5, _1); + if (params == std::make_pair(_6(), _1())) return small_kernel_thunk(_6, _1); + + // Return a generic KernelThunk for dynamic numbers of arguments and results. + return absl::WrapUnique( + new KernelThunk(std::move(info), arguments_buffers, results_buffers, + std::move(kernel_name), thread_dim, min_alignment)); } } // namespace xla::cpu diff --git a/xla/service/cpu/runtime/kernel_thunk.h b/xla/service/cpu/runtime/kernel_thunk.h index 871176ba73ec5b..134602f99537b5 100644 --- a/xla/service/cpu/runtime/kernel_thunk.h +++ b/xla/service/cpu/runtime/kernel_thunk.h @@ -16,17 +16,19 @@ limitations under the License. #ifndef XLA_SERVICE_CPU_RUNTIME_KERNEL_THUNK_H_ #define XLA_SERVICE_CPU_RUNTIME_KERNEL_THUNK_H_ +#include +#include #include #include #include #include #include #include +#include #include #include "absl/base/thread_annotations.h" #include "absl/container/inlined_vector.h" -#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" @@ -39,36 +41,64 @@ limitations under the License. namespace xla::cpu { -// Launches compiled host kernel on the caller thread. -class KernelThunk final : public Thunk { - public: - static absl::StatusOr> Create( - Info info, absl::Span arguments_buffers, - absl::Span results_buffers, - std::string kernel_name, se::ThreadDim thread_dim, - std::optional min_alignment = std::nullopt); +// Forward declare thunk defined below. +class KernelThunk; - tsl::AsyncValueRef Execute(const ExecuteParams& params) final; +namespace internal { +// If the number of kernel parameters (arguments and results) is unknown at +// compile time, we use this value to indicate that the parameter is dynamic. +inline constexpr int64_t kDynamicKernelParameter = -1; + +// A base template for a KernelThunk that can be specialized for a statically +// known number of arguments and results. We go extra mile here to optimize +// host kernel dispatching on the hot execution path to minimize the XLA runtime +// overheads for the smallest HLO modules. +template +class KernelThunk : public Thunk { + public: BufferUses buffer_uses() const final; + protected: + tsl::AsyncValueRef ExecuteInternal(const ExecuteParams& params); + private: + friend class ::xla::cpu::KernelThunk; + + static constexpr bool IsDynamic(size_t n) { + return n == kDynamicKernelParameter; + } + + static constexpr size_t Size(int64_t size) { + return std::max(size, 0); + } + + // If we know the number of arguments and results at compile time, we use + // std::array with a fixed size, which allows compiler to automatically unroll + // all the loops on a hot path. + + using ArgumentsBuffers = std::conditional_t< + IsDynamic(num_arguments), std::vector, + std::array>; + + using ResultsBuffers = std::conditional_t< + IsDynamic(num_results), std::vector, + std::array>; + + using KernelArgs = std::conditional_t< + IsDynamic(num_arguments) || IsDynamic(num_results), + absl::InlinedVector, + std::array>; + KernelThunk(Info info, absl::Span arguments_buffers, absl::Span results_buffers, std::string kernel_name, se::ThreadDim thread_dim, std::optional min_alignment); - // Checks that all buffers are aligned to the minimum alignment. We codegen - // with the assumption that all buffers are aligned, and if they are not, we - // will crash with a segmentation fault, or worse, produce incorrect results. - absl::Status CheckBufferAlignment( - absl::Span kernel_args); - - void VlogKernelArgs(absl::Span kernel_args); - - std::vector arguments_buffers_; - std::vector results_buffers_; + ArgumentsBuffers arguments_buffers_; + ResultsBuffers results_buffers_; size_t num_kernel_args_; @@ -88,7 +118,41 @@ class KernelThunk final : public Thunk { // Pre-initialized kernel arguments that are updated with memory addresses // before the kernel launch. - absl::InlinedVector kernel_args_; + KernelArgs kernel_args_; +}; + +} // namespace internal + +// Kernel thunk specialization for a small kernel with a statically known number +// of arguments and results. +template +class SmallKernelThunk final + : public internal::KernelThunk { + using Base = internal::KernelThunk; + + public: + using Base::Base; + + tsl::AsyncValueRef Execute( + const Thunk::ExecuteParams& params) final; +}; + +// Kernel thunk specialization for dynamic number of arguments and results. +class KernelThunk final : public internal::KernelThunk<> { + using Base = internal::KernelThunk<>; + + public: + using Base::Base; + + static absl::StatusOr> Create( + Thunk::Info info, + absl::Span arguments_buffers, + absl::Span results_buffers, + std::string kernel_name, se::ThreadDim thread_dim, + std::optional min_alignment = std::nullopt); + + tsl::AsyncValueRef Execute( + const Thunk::ExecuteParams& params) final; }; } // namespace xla::cpu From 3107cb2efcb02879fd85be64a39a15792d21a080 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Wed, 31 Jul 2024 10:23:20 -0700 Subject: [PATCH 329/376] Adds Python bindings for synthesizing tracebacks directly from Traceback::Frame. PiperOrigin-RevId: 658054678 --- xla/python/traceback.cc | 28 +++++++++++++++++ xla/python/xla_client_test.py | 43 +++++++++++++++++++++++++++ xla/python/xla_extension/__init__.pyi | 7 +++++ 3 files changed, 78 insertions(+) diff --git a/xla/python/traceback.cc b/xla/python/traceback.cc index 5a86924002f2b0..b3cdfce0770745 100644 --- a/xla/python/traceback.cc +++ b/xla/python/traceback.cc @@ -222,6 +222,7 @@ PyType_Slot traceback_slots_[] = { void BuildTracebackSubmodule(nb::module_& m) { nb::class_(m, "Frame") + .def(nb::init()) .def_ro("file_name", &Traceback::Frame::file_name) .def_ro("function_name", &Traceback::Frame::function_name) .def_ro("function_start_line", &Traceback::Frame::function_start_line) @@ -271,6 +272,33 @@ void BuildTracebackSubmodule(nb::module_& m) { traceback.def("__str__", &Traceback::ToString); traceback.def("as_python_traceback", &Traceback::AsPythonTraceback); + traceback.def_static( + "traceback_from_frames", + [](std::vector frames) { + nb::object traceback = nb::none(); + nb::dict globals; + nb::handle traceback_type( + reinterpret_cast(&PyTraceBack_Type)); + for (const Traceback::Frame& frame : frames) { + PyCodeObject* py_code = + PyCode_NewEmpty(frame.file_name.c_str(), + frame.function_name.c_str(), frame.line_num); + PyFrameObject* py_frame = PyFrame_New(PyThreadState_Get(), py_code, + globals.ptr(), /*locals=*/ + nullptr); + Py_DECREF(py_code); + traceback = traceback_type( + /*tb_next=*/std::move(traceback), + /*tb_frame=*/ + nb::steal(reinterpret_cast(py_frame)), + /*tb_lasti=*/0, + /*tb_lineno=*/ + frame.line_num); + } + return traceback; + }, + "Creates a traceback from a list of frames."); + traceback.def_static( "code_addr2line", [](nb::handle code, int lasti) { diff --git a/xla/python/xla_client_test.py b/xla/python/xla_client_test.py index b84e094b1d841b..65d6d7f3c749d2 100644 --- a/xla/python/xla_client_test.py +++ b/xla/python/xla_client_test.py @@ -2990,6 +2990,49 @@ def testAccessingLocalsDoesNotCrash(self): for frame, _ in traceback.walk_tb(python_tb): _ = frame.f_locals # should not crash + def testTracebackFromFrames(self): + def FooFn(x): + return x + 1 + + def BarFn(y): + y = y + 1 + y = y + 2 + return y * 2 + + frame_foo = xla_client.Frame( + __file__, + FooFn.__code__.co_name, + FooFn.__code__.co_firstlineno, + FooFn.__code__.co_firstlineno + 1, + ) + frame_bar = xla_client.Frame( + __file__, + BarFn.__code__.co_name, + BarFn.__code__.co_firstlineno, + BarFn.__code__.co_firstlineno + 2, + ) + frames = [frame_foo, frame_bar] + tb = xla_client.Traceback.traceback_from_frames(frames) + + with self.subTest("WalkDoesNotError"): + for frame, _ in traceback.walk_tb(tb): + _ = frame.f_locals # should not crash + + with self.subTest("TracebackCorrectness"): + tb_string = traceback.format_tb(tb) + # The traceback should have the format: + # File , line N in BarFn + # y = y + 2 + # File , line N in FooFn + # return x + 1 + self.assertLen(tb_string, len(frames)) + bar_frame = tb_string[0].split("\n") + self.assertEndsWith(bar_frame[0], "BarFn") + self.assertEqual(bar_frame[1].strip(), "y = y + 2") + foo_frame = tb_string[1].split("\n") + self.assertEndsWith(foo_frame[0], "FooFn") + self.assertEqual(foo_frame[1].strip(), "return x + 1") + tests.append(TracebackTest) class ClientTest(ComputationTest): diff --git a/xla/python/xla_extension/__init__.pyi b/xla/python/xla_extension/__init__.pyi index a0e9008f81de68..5e2982184a3fca 100644 --- a/xla/python/xla_extension/__init__.pyi +++ b/xla/python/xla_extension/__init__.pyi @@ -751,12 +751,19 @@ class Frame: function_name: str function_line_start: int line_num: int + def __init__(self, + file_name: str, + function_name: str, + function_line_start: int, + line_num: int): ... def __repr__(self) -> str: ... class Traceback: enabled: ClassVar[bool] @staticmethod def get_traceback() -> Traceback: ... + @staticmethod + def traceback_from_frames(frames: Sequence[Frame]) -> Any: ... frames: Sequence[Frame] def __str__(self) -> str: ... def as_python_traceback(self) -> Any: ... From 2556f9f9363d3354757b1e2d073635d32099474a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 31 Jul 2024 10:29:00 -0700 Subject: [PATCH 330/376] Update comment to better track code cleanup PiperOrigin-RevId: 658056617 --- xla/python/ifrt/memory.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xla/python/ifrt/memory.h b/xla/python/ifrt/memory.h index a5f48fa6cf432c..a3117e5e3049d7 100644 --- a/xla/python/ifrt/memory.h +++ b/xla/python/ifrt/memory.h @@ -81,8 +81,8 @@ class MemoryKind { // indicated by the device, simply returns `MemoryKind` with no memory kind // chosen. // -// TODO(hyeontaek,yashkatariya): Harden `MemoryKind` creation paths so that -// every `MemoryKind` is canonicalized and does not require on-demand +// TODO(b/356623715): Harden `MemoryKind` creation paths so that every +// `MemoryKind` is canonicalized and does not require on-demand // canonicalization. MemoryKind CanonicalizeMemoryKind(MemoryKind memory_kind, Device* device); From bc35df357c168b82a7e0afaaf5588f3b612f9452 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 31 Jul 2024 10:47:53 -0700 Subject: [PATCH 331/376] [xla:cpu] Optimize ThunkExecutor::Execute part #1 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit name old cpu/op new cpu/op delta BM_SelectAndScatterF32/128/process_time 889µs ± 1% 740µs ± 3% -16.70% BM_SelectAndScatterF32/256/process_time 3.64ms ± 2% 3.00ms ± 1% -17.64% BM_SelectAndScatterF32/512/process_time 15.3ms ± 1% 13.1ms ± 3% -14.61% PiperOrigin-RevId: 658063846 --- xla/service/cpu/runtime/thunk_executor.cc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/xla/service/cpu/runtime/thunk_executor.cc b/xla/service/cpu/runtime/thunk_executor.cc index 155edd9bdbca66..f25fd6119a284d 100644 --- a/xla/service/cpu/runtime/thunk_executor.cc +++ b/xla/service/cpu/runtime/thunk_executor.cc @@ -162,6 +162,12 @@ tsl::AsyncValueRef ThunkExecutor::Execute( Execute(state.get(), params, ReadyQueue(source_.begin(), source_.end()), /*lock=*/params.session.Join()); + // If execution already completed (all kernels executed in the caller thread), + // immediately return the result to avoid wasteful reference counting below. + if (ABSL_PREDICT_TRUE(state->execute_event.IsAvailable())) { + return std::move(state->execute_event); + } + // Move execute state to the execute event callback to ensure that it is kept // alive while thunk executor has pending tasks. auto execute_event = state->execute_event; From 1459593b42caa9edc3c20decc8777e3cb8ebc0fb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 31 Jul 2024 10:50:52 -0700 Subject: [PATCH 332/376] Allow generation of partially replicated strategies for iota ops. PiperOrigin-RevId: 658065015 --- .../auto_sharding/auto_sharding_strategy.cc | 5 +-- .../auto_sharding/auto_sharding_test.cc | 34 +++++++++++++++++++ 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc index 4c6311a111f467..dd6d38c01bbe35 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc @@ -738,16 +738,13 @@ BuildStrategyAndCost( break; } case HloOpcode::kIota: { - // For an unknown reason, we do not generate partially replicated - // strategies for iota ops. This can be changed if we find that our - // search isn't exhaustive enough for certain ops. strategy_group = CreateAllStrategiesGroup( ins, ins->shape(), instruction_id, strategy_groups, cluster_env, strategy_map, option, replicated_penalty, batch_dim_map, call_graph, only_allow_divisible, /* create_replicated_strategies */ true, - /* create_partially_replicated_strategies */ false) + /* create_partially_replicated_strategies */ true) .value(); break; } diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc index cb595afaf93569..c899a0c6d8eaf2 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc @@ -346,6 +346,40 @@ ENTRY %elementwise { op::Sharding("{devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}"))); } +TEST_F(AutoShardingTest, IotaPartiallyReplicatedShardingTest) { + constexpr absl::string_view kHloString = R"( +HloModule module + +ENTRY %elementwise { + iota1 = s32[11,1026]{1,0} iota(), iota_dimension=1 + param1 = s32[11,1026]{1,0} parameter(0), sharding={devices=[1,16,16]<=[16,16]T(1,0) last_tile_dim_replicate} + copy1 = s32[11,1026]{1,0} copy(iota1) + ROOT add1 = s32[11,1026]{1,0} add(copy1, param1) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + AutoSharding( + /* option */ { + .enable = true, + .preserve_shardings = + AutoShardingOption::PreserveShardingsType::kKeepAllShardings, + .only_allow_divisible_input_output = false, + .device_mesh_shape = {16, 16}, + .device_mesh_alpha = {1.0, 1.0}, + .device_mesh_beta = {0.01, 1.0}}) + .Run(module.get())); + VLOG(10) << module->ToString(); + EXPECT_TRUE(changed); + const HloInstruction* iota = FindInstruction(module.get(), "iota1"); + ASSERT_NE(iota, nullptr); + EXPECT_THAT( + iota, op::Sharding( + "{devices=[1,16,16]<=[16,16]T(1,0) last_tile_dim_replicate}")); +} + TEST_F(AutoShardingTest, SliceMixedUserShardingTest) { constexpr absl::string_view kHloString = R"( HloModule module From cbf432f48c6fb026addf982d670205e6b1536eb3 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 31 Jul 2024 10:51:16 -0700 Subject: [PATCH 333/376] [xla:cpu] Disable OneDNN rewrites when XLA:CPU thunks runtime is enabled PiperOrigin-RevId: 658065186 --- xla/service/cpu/cpu_compiler.cc | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/xla/service/cpu/cpu_compiler.cc b/xla/service/cpu/cpu_compiler.cc index ebac5206086138..53be150db98ee0 100644 --- a/xla/service/cpu/cpu_compiler.cc +++ b/xla/service/cpu/cpu_compiler.cc @@ -521,7 +521,8 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn( // Rewrite to custom calls with target as oneDNN library calls. #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) // AOT compiled code runs in single thread. - if (!is_aot_compile) { + bool is_thunk_runtime = debug_options.xla_cpu_use_thunk_runtime(); + if (!is_aot_compile && !is_thunk_runtime) { // Placing OneDnnOpsRewriter here to match the flax patterns // TODO: Decide where would be the appropriate place for this pass to make // it more generic @@ -541,7 +542,7 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn( FloatSupport bf16_support(BF16); #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) CpuFloatSupport onednn_bf16_support(BF16); - if (!is_aot_compile) { + if (!is_aot_compile && !is_thunk_runtime) { pipeline.AddPass(&onednn_bf16_support); } else { pipeline.AddPass(&bf16_support); @@ -746,8 +747,11 @@ absl::Status CpuCompiler::RunHloPassesAfterLayoutAssn( : tsl::port::NumSchedulableCPUs(); #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) + auto& debug_options = module->config().debug_options(); + bool is_thunk_runtime = debug_options.xla_cpu_use_thunk_runtime(); + // AOT compiled code runs in single thread. - if (!is_aot_compile) { + if (!is_aot_compile && !is_thunk_runtime) { auto debug_options = module->config().debug_options(); // Run SimplifyFPConversions pass to simplify the BF16 pattern and make it // easier to match. From 1b09c08ac6d5767b5fef986696683698b66c4ef7 Mon Sep 17 00:00:00 2001 From: Zixuan Jiang Date: Wed, 31 Jul 2024 10:51:55 -0700 Subject: [PATCH 334/376] Add an overload of `InferDotOperandSharding` that takes shardings as input. The added function is helpful when we do not have the instruction. PiperOrigin-RevId: 658065441 --- xla/hlo/utils/hlo_sharding_util.cc | 105 +++++++++++++++--------- xla/hlo/utils/hlo_sharding_util.h | 8 ++ xla/hlo/utils/hlo_sharding_util_test.cc | 51 +++++++++++- xla/service/dot_as_convolution_util.cc | 7 ++ xla/service/dot_as_convolution_util.h | 6 +- xla/service/spmd/convolution_handler.cc | 41 +-------- xla/service/spmd/dot_handler.cc | 43 ++-------- 7 files changed, 145 insertions(+), 116 deletions(-) diff --git a/xla/hlo/utils/hlo_sharding_util.cc b/xla/hlo/utils/hlo_sharding_util.cc index 20213b587ec37f..a4e8c6a4a1a56a 100644 --- a/xla/hlo/utils/hlo_sharding_util.cc +++ b/xla/hlo/utils/hlo_sharding_util.cc @@ -2118,7 +2118,7 @@ std::optional TransposeShardingWithCollapsedDims( << "Sharding transpose should not move subgroup dims before data dims."; perm[src_to_tgt[i] - skipped_tgt_dims + skipped_src_dims] = i; } - auto tgt_sharding = hlo_sharding_util::TransposeSharding(source, perm); + auto tgt_sharding = TransposeSharding(source, perm); DimensionVector tgt_tiles(tgt_to_src.size(), 1); for (int64_t i = 0; i < tgt_tiles.size(); ++i) { if (tgt_to_src[i] >= 0) { @@ -2508,8 +2508,8 @@ HloSharding InferGatherScatterParallelShardingFromOperandSharding( operand_sharding.tile_assignment().dim(operand_idx); } HloSharding replicate_non_parallel_dims = - hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( - operand_sharding, operand_non_parallel_dims); + PartiallyReplicateTiledShardingOnDims(operand_sharding, + operand_non_parallel_dims); if (replicate_non_parallel_dims.IsTileMaximal()) { return replicate_non_parallel_dims; } @@ -3336,15 +3336,13 @@ std::optional ReturnImprovedShardingImpl( return std::nullopt; } int64_t sharding_tiles = from.NumTiles(); - if (hlo_sharding_util::MergeSharding(*to_improved, &from, - may_combine_partial_sharding)) { + if (MergeSharding(*to_improved, &from, may_combine_partial_sharding)) { // Override existing tiled sharding only when the new sharding is compatible // with the existing one. This avoids unexpected resharding when `sharding` // just has more tiles than existing sharding but they are not mergeable. if (!allow_aggressive_resharding && to_improved_shape.IsArray() && !to_improved->IsTileMaximal() && from.NumTiles() == sharding_tiles) { - if (!hlo_sharding_util::IsSubTilingOrEqualSharding(to_improved_shape, - from, *to_improved)) { + if (!IsSubTilingOrEqualSharding(to_improved_shape, from, *to_improved)) { VLOG(10) << "Not merging because of different device distribution"; VLOG(10) << "Instr sharding: " << to_improved->ToString(); VLOG(10) << "New sharding " << from.ToString(); @@ -3357,16 +3355,13 @@ std::optional ReturnImprovedShardingImpl( } HloSharding InferDotOperandSharding( - const HloInstruction* dot, int64_t operand_index, + const HloSharding* dot_sharding, const HloSharding* other_operand_sharding, + int64_t operand_index, const dot_as_convolution_util::DotConvolutionDimsInfo& dnums, bool consider_other_operand, bool may_combine_partial_sharding) { - CHECK(dot->opcode() == HloOpcode::kDot || - dot->opcode() == HloOpcode::kConvolution); CHECK(operand_index == 0 || operand_index == 1); CHECK(dnums.conv_spatial_dims.empty()); - auto operand = dot->operand(operand_index); - auto other = dot->operand(1 - operand_index); std::vector output_dims_to_replicate; std::vector other_operand_dims_to_replicate; for (const auto& dim : operand_index == 0 ? dnums.rhs_non_contracting_dims @@ -3391,33 +3386,47 @@ HloSharding InferDotOperandSharding( other_operand_dims_to_replicate.push_back(other_dim); } } - HloSharding output_other_dims_replicated = - hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( - dot->sharding(), output_dims_to_replicate); - std::vector output_to_operand_dims(dot->shape().rank(), -1); - std::vector operand_to_output_dims(operand->shape().rank(), -1); - for (const auto& dim : dnums.batch_dims) { - output_to_operand_dims[dim.output] = operand_index == 0 ? dim.lhs : dim.rhs; - operand_to_output_dims[operand_index == 0 ? dim.lhs : dim.rhs] = dim.output; - } - for (const auto& dim : operand_index == 0 ? dnums.lhs_non_contracting_dims - : dnums.rhs_non_contracting_dims) { - output_to_operand_dims[dim.output] = operand_index == 0 ? dim.lhs : dim.rhs; - operand_to_output_dims[operand_index == 0 ? dim.lhs : dim.rhs] = dim.output; + int64_t operand_shape_rank = + operand_index == 0 ? dnums.lhs_shape_rank : dnums.rhs_shape_rank; + int64_t other_shape_rank = + operand_index == 0 ? dnums.rhs_shape_rank : dnums.lhs_shape_rank; + + HloSharding sharding = HloSharding::Replicate(); + + if (dot_sharding != nullptr) { + HloSharding output_other_dims_replicated = + PartiallyReplicateTiledShardingOnDims(*dot_sharding, + output_dims_to_replicate); + + std::vector output_to_operand_dims(dnums.output_shape_rank, -1); + std::vector operand_to_output_dims(operand_shape_rank, -1); + for (const auto& dim : dnums.batch_dims) { + output_to_operand_dims[dim.output] = + operand_index == 0 ? dim.lhs : dim.rhs; + operand_to_output_dims[operand_index == 0 ? dim.lhs : dim.rhs] = + dim.output; + } + for (const auto& dim : operand_index == 0 + ? dnums.lhs_non_contracting_dims + : dnums.rhs_non_contracting_dims) { + output_to_operand_dims[dim.output] = + operand_index == 0 ? dim.lhs : dim.rhs; + operand_to_output_dims[operand_index == 0 ? dim.lhs : dim.rhs] = + dim.output; + } + sharding = std::move(*TransposeShardingWithCollapsedDims( + output_other_dims_replicated, output_to_operand_dims, + operand_to_output_dims)); } - auto sharding = *hlo_sharding_util::TransposeShardingWithCollapsedDims( - output_other_dims_replicated, output_to_operand_dims, - operand_to_output_dims); - if (consider_other_operand && - hlo_sharding_util::IsSpatiallyPartitioned(other)) { - auto other_operand_dims_replicated = - hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( - other->sharding(), other_operand_dims_to_replicate); + if (consider_other_operand && other_operand_sharding != nullptr && + IsSpatiallyPartitioned(*other_operand_sharding)) { + auto other_operand_dims_replicated = PartiallyReplicateTiledShardingOnDims( + *other_operand_sharding, other_operand_dims_to_replicate); - std::vector other_to_operand_dims(other->shape().rank(), -1); - std::vector operand_to_other_dims(operand->shape().rank(), -1); + std::vector other_to_operand_dims(other_shape_rank, -1); + std::vector operand_to_other_dims(operand_shape_rank, -1); for (const auto& dim : dnums.batch_dims) { other_to_operand_dims[operand_index == 0 ? dim.rhs : dim.lhs] = operand_index == 0 ? dim.lhs : dim.rhs; @@ -3430,12 +3439,11 @@ HloSharding InferDotOperandSharding( operand_to_other_dims[operand_index == 0 ? dim.lhs : dim.rhs] = operand_index == 0 ? dim.rhs : dim.lhs; } - HloSharding sharding_from_other = - *hlo_sharding_util::TransposeShardingWithCollapsedDims( - other_operand_dims_replicated, other_to_operand_dims, - operand_to_other_dims); - if (hlo_sharding_util::MergeSharding(sharding, &sharding_from_other, - may_combine_partial_sharding)) { + HloSharding sharding_from_other = *TransposeShardingWithCollapsedDims( + other_operand_dims_replicated, other_to_operand_dims, + operand_to_other_dims); + if (MergeSharding(sharding, &sharding_from_other, + may_combine_partial_sharding)) { sharding = std::move(sharding_from_other); } } @@ -3443,5 +3451,20 @@ HloSharding InferDotOperandSharding( return sharding; } +HloSharding InferDotOperandSharding( + const HloInstruction* dot, int64_t operand_index, + const dot_as_convolution_util::DotConvolutionDimsInfo& dnums, + bool consider_other_operand, bool may_combine_partial_sharding) { + CHECK(dot->opcode() == HloOpcode::kDot || + dot->opcode() == HloOpcode::kConvolution); + + const HloInstruction* other_operand = dot->operand(1 - operand_index); + return InferDotOperandSharding( + dot->has_sharding() ? &dot->sharding() : nullptr, + other_operand->has_sharding() ? &other_operand->sharding() : nullptr, + operand_index, dnums, consider_other_operand, + may_combine_partial_sharding); +} + } // namespace hlo_sharding_util } // namespace xla diff --git a/xla/hlo/utils/hlo_sharding_util.h b/xla/hlo/utils/hlo_sharding_util.h index 1df5aebf107829..335cb6b53fe46b 100644 --- a/xla/hlo/utils/hlo_sharding_util.h +++ b/xla/hlo/utils/hlo_sharding_util.h @@ -539,6 +539,14 @@ HloSharding InferDotOperandSharding( const dot_as_convolution_util::DotConvolutionDimsInfo& dnums, bool consider_other_operand, bool may_combine_partial_sharding); +// Same as above, but takes the sharding of the dot and the other operand as +// input. +HloSharding InferDotOperandSharding( + const HloSharding* dot_sharding, const HloSharding* other_operand_sharding, + int64_t operand_index, + const dot_as_convolution_util::DotConvolutionDimsInfo& dnums, + bool consider_other_operand, bool may_combine_partial_sharding); + } // namespace hlo_sharding_util } // namespace xla diff --git a/xla/hlo/utils/hlo_sharding_util_test.cc b/xla/hlo/utils/hlo_sharding_util_test.cc index 9015726ffd6c86..44ec60cca97172 100644 --- a/xla/hlo/utils/hlo_sharding_util_test.cc +++ b/xla/hlo/utils/hlo_sharding_util_test.cc @@ -1017,7 +1017,7 @@ TEST(HloShardingUtilTest, UntileShape) { using HloShardingUtilTestWithHlo = HloTestBase; -TEST_F(HloShardingUtilTestWithHlo, InferDotOperandShardingTest) { +TEST_F(HloShardingUtilTestWithHlo, InferDotOperandShardingTest1) { absl::string_view hlo_string = R"( HloModule module @@ -1061,6 +1061,55 @@ TEST_F(HloShardingUtilTestWithHlo, InferDotOperandShardingTest) { } } +TEST_F(HloShardingUtilTestWithHlo, InferDotOperandShardingTest2) { + absl::string_view hlo_string = R"( + HloModule module + + ENTRY %main.7 { + %p0 = bf16[32,64,128,512] parameter(0), sharding={devices=[8,1,1,4]<=[32]} + %p1 = bf16[32,64,256,512] parameter(1), sharding={devices=[1,1,1,2,16]<=[8,2,2]T(1,0,2) last_tile_dim_replicate} + ROOT %dot.3 = bf16[32,64,128,256] dot(%p0, %p1), lhs_batch_dims={0,1}, rhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_contracting_dims={3}, sharding={devices=[2,2,2,2,2]<=[32] last_tile_dim_replicate} + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + const HloInstruction* dot = module->entry_computation()->root_instruction(); + auto dnums = dot_as_convolution_util::ParseDotGeneralFromDot(dot); + + const HloSharding& lhs_sharding = dot->operand(0)->sharding(); + const HloSharding& rhs_sharding = dot->operand(1)->sharding(); + const HloSharding& dot_sharding = dot->sharding(); + + bool may_combine_partial_sharding = true; + for (int64_t i = 0; i < 2; ++i) { + EXPECT_EQ(InferDotOperandSharding(nullptr, nullptr, i, dnums, true, + may_combine_partial_sharding), + HloSharding::Replicate()); + } + + // If the other_operand_sharding is missing (nullptr), we only infer the + // result from the result. + for (int64_t i = 0; i < 2; ++i) { + EXPECT_EQ(InferDotOperandSharding(&dot_sharding, nullptr, i, dnums, true, + may_combine_partial_sharding), + InferDotOperandSharding(dot, i, dnums, false, + may_combine_partial_sharding)); + } + + EXPECT_EQ(InferDotOperandSharding(nullptr, &rhs_sharding, 0, dnums, true, + may_combine_partial_sharding), + rhs_sharding); + EXPECT_EQ(InferDotOperandSharding(nullptr, &lhs_sharding, 1, dnums, true, + may_combine_partial_sharding), + lhs_sharding); + + EXPECT_EQ(InferDotOperandSharding(nullptr, &rhs_sharding, 0, dnums, false, + may_combine_partial_sharding), + HloSharding::Replicate()); + EXPECT_EQ(InferDotOperandSharding(nullptr, &lhs_sharding, 1, dnums, false, + may_combine_partial_sharding), + HloSharding::Replicate()); +} + } // namespace } // namespace hlo_sharding_util } // namespace xla diff --git a/xla/service/dot_as_convolution_util.cc b/xla/service/dot_as_convolution_util.cc index e22dddcf7cee68..25d6b6a48c9d48 100644 --- a/xla/service/dot_as_convolution_util.cc +++ b/xla/service/dot_as_convolution_util.cc @@ -129,6 +129,9 @@ bool SpatialIsContracting(int64_t lhs_spatial_size, int64_t rhs_spatial_size, } } + dims.lhs_shape_rank = conv->operand(0)->shape().rank(); + dims.rhs_shape_rank = conv->operand(1)->shape().rank(); + dims.output_shape_rank = conv->shape().rank(); return dims; } @@ -224,6 +227,10 @@ DotConvolutionDimsInfo ParseDotGeneralFromDot(const HloInstruction* dot) { dnums.rhs_non_contracting_dims.back().spatial_dim = -1; } } + + dnums.lhs_shape_rank = dot->operand(0)->shape().rank(); + dnums.rhs_shape_rank = dot->operand(1)->shape().rank(); + dnums.output_shape_rank = dot->shape().rank(); return dnums; } diff --git a/xla/service/dot_as_convolution_util.h b/xla/service/dot_as_convolution_util.h index 01236f8c7ec9d1..9bed16990fc204 100644 --- a/xla/service/dot_as_convolution_util.h +++ b/xla/service/dot_as_convolution_util.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef XLA_SERVICE_DOT_AS_CONVOLUTION_UTIL_H_ #define XLA_SERVICE_DOT_AS_CONVOLUTION_UTIL_H_ +#include #include -#include #include #include "xla/hlo/ir/hlo_instruction.h" @@ -55,6 +55,10 @@ struct DotConvolutionDimsInfo { std::vector lhs_non_contracting_dims; std::vector rhs_non_contracting_dims; std::vector conv_spatial_dims; + + int64_t lhs_shape_rank; + int64_t rhs_shape_rank; + int64_t output_shape_rank; }; // Parses a convolution and returns a DotGeneralAsConvolutionDimsInfo. If it can diff --git a/xla/service/spmd/convolution_handler.cc b/xla/service/spmd/convolution_handler.cc index 3985a81c810b01..a084c2ec98fae6 100644 --- a/xla/service/spmd/convolution_handler.cc +++ b/xla/service/spmd/convolution_handler.cc @@ -1028,43 +1028,8 @@ absl::Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) { if (hlo->sharding().HasUniqueDevice()) { return DefaultAction(hlo); } - auto dims_info = dot_as_convolution_util::ParseConvolutionDimsInfo(hlo); - dot_as_convolution_util::DotConvolutionDimsInfo mapping; - for (const auto& dims : dims_info.batch_dims) { - mapping.batch_dims.emplace_back(); - mapping.batch_dims.back().lhs = dims.lhs; - mapping.batch_dims.back().rhs = dims.rhs; - mapping.batch_dims.back().output = dims.output; - mapping.batch_dims.back().spatial_dim = dims.spatial_dim; - } - for (const auto& dims : dims_info.contracting_dims) { - mapping.contracting_dims.emplace_back(); - mapping.contracting_dims.back().lhs = dims.lhs; - mapping.contracting_dims.back().rhs = dims.rhs; - mapping.contracting_dims.back().output = dims.output; - mapping.contracting_dims.back().spatial_dim = dims.spatial_dim; - } - for (const auto& dims : dims_info.lhs_non_contracting_dims) { - mapping.lhs_non_contracting_dims.emplace_back(); - mapping.lhs_non_contracting_dims.back().lhs = dims.lhs; - mapping.lhs_non_contracting_dims.back().rhs = dims.rhs; - mapping.lhs_non_contracting_dims.back().output = dims.output; - mapping.lhs_non_contracting_dims.back().spatial_dim = dims.spatial_dim; - } - for (const auto& dims : dims_info.rhs_non_contracting_dims) { - mapping.rhs_non_contracting_dims.emplace_back(); - mapping.rhs_non_contracting_dims.back().lhs = dims.lhs; - mapping.rhs_non_contracting_dims.back().rhs = dims.rhs; - mapping.rhs_non_contracting_dims.back().output = dims.output; - mapping.rhs_non_contracting_dims.back().spatial_dim = dims.spatial_dim; - } - for (const auto& dims : dims_info.conv_spatial_dims) { - mapping.conv_spatial_dims.emplace_back(); - mapping.conv_spatial_dims.back().lhs = dims.lhs; - mapping.conv_spatial_dims.back().rhs = dims.rhs; - mapping.conv_spatial_dims.back().output = dims.output; - mapping.conv_spatial_dims.back().spatial_dim = dims.spatial_dim; - } + const auto dims_info = dot_as_convolution_util::ParseConvolutionDimsInfo(hlo); + auto create_sharded_conv = [&](HloInstruction* lhs_hlo, HloInstruction* rhs_hlo, spmd::SpmdBuilder* b, @@ -1084,7 +1049,7 @@ absl::Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) { } }; - return HandleDotHelper(hlo, mapping, create_sharded_conv); + return HandleDotHelper(hlo, dims_info, create_sharded_conv); } } // namespace spmd diff --git a/xla/service/spmd/dot_handler.cc b/xla/service/spmd/dot_handler.cc index 123edb05a6b8ef..fb41d24870f405 100644 --- a/xla/service/spmd/dot_handler.cc +++ b/xla/service/spmd/dot_handler.cc @@ -68,41 +68,8 @@ using hlo_sharding_util::GroupedSharding; } // namespace absl::Status SpmdPartitioningVisitor::HandleDot(HloInstruction* hlo) { - DotConvolutionDimsInfo mapping; - const auto& dnums = hlo->dot_dimension_numbers(); - int64_t next_output_dim = 0; - for (int64_t i = 0; i < dnums.lhs_batch_dimensions_size(); ++i) { - mapping.batch_dims.emplace_back(); - mapping.batch_dims.back().lhs = dnums.lhs_batch_dimensions(i); - mapping.batch_dims.back().rhs = dnums.rhs_batch_dimensions(i); - mapping.batch_dims.back().output = next_output_dim++; - } - for (int64_t i = 0; i < dnums.lhs_contracting_dimensions_size(); ++i) { - mapping.contracting_dims.emplace_back(); - mapping.contracting_dims.back().lhs = dnums.lhs_contracting_dimensions(i); - mapping.contracting_dims.back().rhs = dnums.rhs_contracting_dimensions(i); - mapping.contracting_dims.back().output = -1; - } - for (int64_t i = 0; i < hlo->operand(0)->shape().rank(); ++i) { - if (absl::c_linear_search(dnums.lhs_batch_dimensions(), i) || - absl::c_linear_search(dnums.lhs_contracting_dimensions(), i)) { - continue; - } - mapping.lhs_non_contracting_dims.emplace_back(); - mapping.lhs_non_contracting_dims.back().lhs = i; - mapping.lhs_non_contracting_dims.back().rhs = -1; - mapping.lhs_non_contracting_dims.back().output = next_output_dim++; - } - for (int64_t i = 0; i < hlo->operand(1)->shape().rank(); ++i) { - if (absl::c_linear_search(dnums.rhs_batch_dimensions(), i) || - absl::c_linear_search(dnums.rhs_contracting_dimensions(), i)) { - continue; - } - mapping.rhs_non_contracting_dims.emplace_back(); - mapping.rhs_non_contracting_dims.back().lhs = -1; - mapping.rhs_non_contracting_dims.back().rhs = i; - mapping.rhs_non_contracting_dims.back().output = next_output_dim++; - } + DotConvolutionDimsInfo mapping = + dot_as_convolution_util::ParseDotGeneralFromDot(hlo); HloDotInstruction* dot = Cast(hlo); std::vector sparsity(dot->sparsity().begin(), @@ -3031,6 +2998,9 @@ DotConvolutionDimsInfo ConvertDimNumsWithFeatureGroupCount( const DotConvolutionDimsInfo& dims_mapping, HloInstruction* original_hlo) { const auto& dnums = original_hlo->convolution_dimension_numbers(); DotConvolutionDimsInfo new_dims_mapping; + new_dims_mapping.lhs_shape_rank = dims_mapping.lhs_shape_rank; + new_dims_mapping.rhs_shape_rank = dims_mapping.rhs_shape_rank; + new_dims_mapping.output_shape_rank = dims_mapping.output_shape_rank; new_dims_mapping.batch_dims = dims_mapping.batch_dims; new_dims_mapping.conv_spatial_dims = dims_mapping.conv_spatial_dims; // Append batch dims. @@ -3060,6 +3030,9 @@ DotConvolutionDimsInfo ConvertDimNumsWithBatchGroupCount( const DotConvolutionDimsInfo& dims_mapping, HloInstruction* original_hlo) { const auto& dnums = original_hlo->convolution_dimension_numbers(); DotConvolutionDimsInfo new_dims_mapping; + new_dims_mapping.lhs_shape_rank = dims_mapping.lhs_shape_rank; + new_dims_mapping.rhs_shape_rank = dims_mapping.rhs_shape_rank; + new_dims_mapping.output_shape_rank = dims_mapping.output_shape_rank; new_dims_mapping.batch_dims = dims_mapping.batch_dims; new_dims_mapping.conv_spatial_dims = dims_mapping.conv_spatial_dims; new_dims_mapping.contracting_dims = dims_mapping.contracting_dims; From 1f3b222248bb4ca759701831cfe37c1e8d4cb2d9 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Wed, 31 Jul 2024 12:57:15 -0700 Subject: [PATCH 335/376] Remove StreamExecutor::Memset interface, and implements its single use in-place in GpuStream::MemZero. PiperOrigin-RevId: 658111357 --- xla/backends/interpreter/executor.h | 5 ----- xla/stream_executor/cuda/cuda_executor.cc | 10 ---------- xla/stream_executor/gpu/gpu_executor.h | 3 --- xla/stream_executor/gpu/gpu_stream.cc | 5 ++++- xla/stream_executor/host/host_executor.cc | 10 ---------- xla/stream_executor/host/host_executor.h | 3 --- xla/stream_executor/mock_stream_executor.h | 4 ---- xla/stream_executor/rocm/rocm_executor.cc | 10 ---------- xla/stream_executor/stream_executor.h | 8 -------- 9 files changed, 4 insertions(+), 54 deletions(-) diff --git a/xla/backends/interpreter/executor.h b/xla/backends/interpreter/executor.h index a899bba70073fb..0aee389be7bf94 100644 --- a/xla/backends/interpreter/executor.h +++ b/xla/backends/interpreter/executor.h @@ -107,11 +107,6 @@ class XlaInterpreterExecutor : public StreamExecutorCommon { delete[] static_cast(mem); } - absl::Status Memset(Stream *stream, DeviceMemoryBase *location, - uint8_t pattern, uint64_t size) override { - return absl::InternalError("Interpreter can not memset"); - } - // No "synchronize all activity" implemented for this platform at the moment. bool SynchronizeAllActivity() override { return true; } absl::Status SynchronousMemZero(DeviceMemoryBase *location, diff --git a/xla/stream_executor/cuda/cuda_executor.cc b/xla/stream_executor/cuda/cuda_executor.cc index 3c268909b461dd..3deea1a870578f 100644 --- a/xla/stream_executor/cuda/cuda_executor.cc +++ b/xla/stream_executor/cuda/cuda_executor.cc @@ -658,16 +658,6 @@ absl::Status GpuExecutor::SynchronousMemcpy(void* host_dst, AsCudaDevicePtr(gpu_src), size); } -absl::Status GpuExecutor::Memset(Stream* stream, DeviceMemoryBase* location, - uint8_t pattern, uint64_t size) { - VLOG(2) << "enqueueing memset8 operation onto stream " << stream - << " at location " << location << " with size " << size - << " and pattern " << std::hex << pattern; - return GpuDriver::AsynchronousMemsetUint8(context_, AsCudaDevicePtr(location), - pattern, size, - AsGpuStreamValue(stream)); -} - void GpuExecutor::DeallocateStream(Stream* stream) { { absl::MutexLock lock(&mu_); diff --git a/xla/stream_executor/gpu/gpu_executor.h b/xla/stream_executor/gpu/gpu_executor.h index 7e7b834ebc154f..116120d1e0e62b 100644 --- a/xla/stream_executor/gpu/gpu_executor.h +++ b/xla/stream_executor/gpu/gpu_executor.h @@ -191,9 +191,6 @@ class GpuExecutor : public StreamExecutorCommon { const DeviceMemoryBase& gpu_src, uint64_t size) override; - absl::Status Memset(Stream* stream, DeviceMemoryBase* location, - uint8_t pattern, uint64_t size) override; - void DeallocateStream(Stream* stream) override; absl::Status BlockHostUntilDone(Stream* stream) override; diff --git a/xla/stream_executor/gpu/gpu_stream.cc b/xla/stream_executor/gpu/gpu_stream.cc index e1c943a1171d95..040e209441ac51 100644 --- a/xla/stream_executor/gpu/gpu_stream.cc +++ b/xla/stream_executor/gpu/gpu_stream.cc @@ -86,7 +86,10 @@ absl::Status GpuStream::MemZero(DeviceMemoryBase* location, uint64_t size) { size % 4 == 0) { return Memset32(location, 0x0, size); } else { - return parent_->Memset(this, location, 0x0, size); + return GpuDriver::AsynchronousMemsetUint8( + parent_->gpu_context(), + reinterpret_cast(location->opaque()), 0x0, size, + gpu_stream()); } } diff --git a/xla/stream_executor/host/host_executor.cc b/xla/stream_executor/host/host_executor.cc index 8d8eeb7e421de1..38715ce56ed3db 100644 --- a/xla/stream_executor/host/host_executor.cc +++ b/xla/stream_executor/host/host_executor.cc @@ -137,16 +137,6 @@ absl::Status HostExecutor::SynchronousMemZero(DeviceMemoryBase* location, return absl::OkStatus(); } -absl::Status HostExecutor::Memset(Stream* stream, DeviceMemoryBase* location, - uint8 pattern, uint64_t size) { - void* gpu_mem = location->opaque(); - // Enqueue the [asynchronous] memzero on the stream (HostStream) associated - // with the HostExecutor. - AsHostStream(stream)->EnqueueTask( - [gpu_mem, size, pattern]() { memset(gpu_mem, pattern, size); }); - return absl::OkStatus(); -} - absl::Status HostExecutor::SynchronousMemcpy(DeviceMemoryBase* gpu_dst, const void* host_src, uint64_t size) { diff --git a/xla/stream_executor/host/host_executor.h b/xla/stream_executor/host/host_executor.h index 478ab2778cbe9e..5f1c5d00a23463 100644 --- a/xla/stream_executor/host/host_executor.h +++ b/xla/stream_executor/host/host_executor.h @@ -88,9 +88,6 @@ class HostExecutor : public StreamExecutorCommon { delete[] static_cast(mem); } - absl::Status Memset(Stream* stream, DeviceMemoryBase* location, - uint8_t pattern, uint64_t size) override; - // No "synchronize all activity" implemented for this platform at the moment. bool SynchronizeAllActivity() override { return true; } absl::Status SynchronousMemZero(DeviceMemoryBase* location, diff --git a/xla/stream_executor/mock_stream_executor.h b/xla/stream_executor/mock_stream_executor.h index d0d696cfd1147a..03dd1115f3d6fb 100644 --- a/xla/stream_executor/mock_stream_executor.h +++ b/xla/stream_executor/mock_stream_executor.h @@ -102,10 +102,6 @@ class MockStreamExecutor : public StreamExecutor { (void* host_dst, const DeviceMemoryBase& device_src, uint64_t size), (override)); - MOCK_METHOD(absl::Status, Memset, - (Stream * stream, DeviceMemoryBase* location, uint8_t pattern, - uint64_t size), - (override)); MOCK_METHOD(void, DeallocateStream, (Stream * stream), (override)); MOCK_METHOD(absl::Status, BlockHostUntilDone, (Stream * stream), (override)); MOCK_METHOD(absl::Status, EnablePeerAccessTo, (StreamExecutor * other), diff --git a/xla/stream_executor/rocm/rocm_executor.cc b/xla/stream_executor/rocm/rocm_executor.cc index cb207eabbed6e1..366d2da583de31 100644 --- a/xla/stream_executor/rocm/rocm_executor.cc +++ b/xla/stream_executor/rocm/rocm_executor.cc @@ -490,16 +490,6 @@ absl::Status GpuExecutor::SynchronousMemcpy(void* host_dst, AsROCmDevicePtr(gpu_src), size); } -absl::Status GpuExecutor::Memset(Stream* stream, DeviceMemoryBase* location, - uint8 pattern, uint64_t size) { - VLOG(2) << "enqueueing memset8 operation onto stream " << stream - << " at location " << location << " with size " << size - << " and pattern " << std::hex << pattern; - return GpuDriver::AsynchronousMemsetUint8(context_, AsROCmDevicePtr(location), - pattern, size, - AsGpuStreamValue(stream)); -} - void GpuExecutor::DeallocateStream(Stream* stream) { { absl::MutexLock lock(&mu_); diff --git a/xla/stream_executor/stream_executor.h b/xla/stream_executor/stream_executor.h index b4d83efa8ce2b8..49929d4ce34c11 100644 --- a/xla/stream_executor/stream_executor.h +++ b/xla/stream_executor/stream_executor.h @@ -235,14 +235,6 @@ class StreamExecutor { return SynchronousMemcpy(host_dst, device_src, size); } - // Enqueues an operation onto stream to set 8-bit patterns starting at - // location, for byte count given by size. Returns whether the operation was - // successfully enqueued onto the stream. - virtual absl::Status Memset(Stream* stream, DeviceMemoryBase* location, - uint8_t pattern, uint64_t size) { - return absl::InternalError("Not implemented"); - } - // Deallocates stream resources on the underlying platform. virtual void DeallocateStream(Stream* stream) = 0; From 102003e0f6b34218caaf825bab5abf2f4bee8d54 Mon Sep 17 00:00:00 2001 From: Gregory Pataky Date: Wed, 31 Jul 2024 13:40:53 -0700 Subject: [PATCH 336/376] Add ability to dump expected and actual results for the exhaustive tests This adds the ability to dump the expected and actual results to files to make it easier to see if there are perturbat ions between compiler flags. PiperOrigin-RevId: 658126820 --- xla/tests/client_library_test_base.cc | 6 ++ xla/tests/client_library_test_base.h | 3 + xla/tests/exhaustive/BUILD | 14 +++++ .../exhaustive/exhaustive_op_test_utils.cc | 62 +++++++++++++++++++ .../exhaustive/exhaustive_op_test_utils.h | 15 ++++- xla/tests/exhaustive/exhaustive_test_main.cc | 14 +++++ .../exhaustive_unary_complex_test.cc | 2 + .../exhaustive/exhaustive_unary_f64_test.cc | 2 + 8 files changed, 117 insertions(+), 1 deletion(-) diff --git a/xla/tests/client_library_test_base.cc b/xla/tests/client_library_test_base.cc index 07f46e75715d68..743db05f93f73b 100644 --- a/xla/tests/client_library_test_base.cc +++ b/xla/tests/client_library_test_base.cc @@ -98,6 +98,12 @@ ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform) ->set_xla_hlo_evaluator_use_fast_path(true); } +std::string ClientLibraryTestBase::SuiteName() const { + return ::testing::UnitTest::GetInstance() + ->current_test_info() + ->test_suite_name(); +} + std::string ClientLibraryTestBase::TestName() const { return ::testing::UnitTest::GetInstance()->current_test_info()->name(); } diff --git a/xla/tests/client_library_test_base.h b/xla/tests/client_library_test_base.h index 8246d851052429..800f14c4c014ad 100644 --- a/xla/tests/client_library_test_base.h +++ b/xla/tests/client_library_test_base.h @@ -71,6 +71,9 @@ class ClientLibraryTestBase : public ManifestCheckingTest { ClientLibraryTestBase(se::Platform* platform, const LocalClientOptions& client_options); + // Returns the name of the suite currently being run. + std::string SuiteName() const; + // Returns the name of the test currently being run. std::string TestName() const; diff --git a/xla/tests/exhaustive/BUILD b/xla/tests/exhaustive/BUILD index 89220e327497c1..441019f09e99ac 100644 --- a/xla/tests/exhaustive/BUILD +++ b/xla/tests/exhaustive/BUILD @@ -39,6 +39,8 @@ cc_library( "//xla/client:xla_computation", "//xla/service:shaped_buffer", "//xla/tests:client_library_test_base", + "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -48,6 +50,8 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:path", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", ], @@ -84,9 +88,11 @@ xla_test( "//xla/client:xla_builder", "//xla/client/lib:constants", "//xla/client/lib:math", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:logging", "@tsl//tsl/platform:test", ], ) @@ -117,9 +123,11 @@ xla_test( "//xla/client/lib:constants", "//xla/client/lib:math", "//xla/tests:xla_internal_test_main", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:logging", "@tsl//tsl/platform:test", ], ) @@ -148,8 +156,10 @@ xla_test( "//xla/client:xla_builder", "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/log", "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:logging", "@tsl//tsl/platform:test", ], ) @@ -176,10 +186,12 @@ xla_test( "//xla:literal", "//xla/client:xla_builder", "//xla/tests:xla_internal_test_main", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", + "@tsl//tsl/platform:logging", "@tsl//tsl/platform:test", ], ) @@ -206,9 +218,11 @@ xla_test( "//xla:literal", "//xla/client:xla_builder", "//xla/tests:xla_internal_test_main", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:logging", "@tsl//tsl/platform:test", ], ) diff --git a/xla/tests/exhaustive/exhaustive_op_test_utils.cc b/xla/tests/exhaustive/exhaustive_op_test_utils.cc index 7e8f8718b454dc..6c791a860f7028 100644 --- a/xla/tests/exhaustive/exhaustive_op_test_utils.cc +++ b/xla/tests/exhaustive/exhaustive_op_test_utils.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -30,11 +31,17 @@ limitations under the License. #include "absl/meta/type_traits.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "Eigen/Core" #include "xla/literal.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/util/command_line_flags.h" #include "xla/types.h" +#include "tsl/platform/env.h" +#include "tsl/platform/file_system.h" +#include "tsl/platform/path.h" #include "tsl/platform/test.h" namespace xla { @@ -44,6 +51,17 @@ int eup_version = 0; int GetEupVersion() { return eup_version; } +bool dump_values = false; + +bool ShouldDumpValues() { return dump_values; } + +void AddExhaustiveFlags(std::vector& flag_list) { + flag_list.push_back( + tsl::Flag("dump_values", &xla::exhaustive_op_test::dump_values, + "Include to dump files of the expected and actual results " + "(default false).")); +} + bool IsSubnormalReal(xla::complex64 value) { return IsSubnormal(value.real()); } bool IsSubnormalReal(xla::complex128 value) { @@ -409,6 +427,30 @@ void ExhaustiveOpTestBase::ExpectNear( } } + // Dump file for the test. This is unused unless this->should_dump_values is + // true. + std::unique_ptr dump_file; + if (should_dump_values_) { + auto* env = tsl::Env::Default(); + + std::string cleaned_suite_name = + absl::StrReplaceAll(SuiteName(), {{"/", "__"}}); + std::string cleaned_test_name = + absl::StrReplaceAll(TestName(), {{"/", "__"}}); + std::string dump_filename = absl::StrFormat( + "%s_%s_dump.txt", cleaned_suite_name, cleaned_test_name); + + std::string outdir; + if (tsl::io::GetTestUndeclaredOutputsDir(&outdir)) { + dump_filename = tsl::io::JoinPath(outdir, dump_filename); + } + + TF_EXPECT_OK(env->NewWritableFile(dump_filename, &dump_file)); + TF_EXPECT_OK( + dump_file->Append("input values -> actual output {expected output}\n" + "-----------------------------------------------\n")); + } + NativeInputsList inputs_arr; for (int i = 0; i < N; ++i) { const Literal& literal = input_literals[i]; @@ -432,6 +474,22 @@ void ExhaustiveOpTestBase::ExpectNear( NativeT actual = result_arr[i]; NativeT expected = static_cast(CallOperation(evaluate_op, inputs_ref_ty)); + + // Dump input, actual, and expected values _before_ we do error checking to + // avoid the continues. + if (should_dump_values_) { + std::string result_string; + absl::StrAppend( + &result_string, + StringifyNum(inputs), " -> ", + StringifyNum(actual)); + absl::StrAppend(&result_string, " {", + StringifyNum(expected), + "}"); + absl::StrAppend(&result_string, "\n"); + TF_EXPECT_OK(dump_file->Append(result_string)); + } + ErrorSpec error_spec = CallErrorSpec(error_spec_gen, inputs); if (error_spec.skip_comparison) { @@ -535,6 +593,10 @@ void ExhaustiveOpTestBase::ExpectNear( PrintMismatch(&mismatches, [mismatch] { return mismatch; }); } EXPECT_EQ(mismatches, 0); + + if (should_dump_values_) { + TF_EXPECT_OK(dump_file->Close()); + } } template class ExhaustiveOpTestBase; diff --git a/xla/tests/exhaustive/exhaustive_op_test_utils.h b/xla/tests/exhaustive/exhaustive_op_test_utils.h index c6aece69a1d027..6a491be43745fc 100644 --- a/xla/tests/exhaustive/exhaustive_op_test_utils.h +++ b/xla/tests/exhaustive/exhaustive_op_test_utils.h @@ -36,6 +36,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "Eigen/Core" #include "xla/bit_cast.h" @@ -48,6 +49,7 @@ limitations under the License. #include "xla/primitive_util.h" #include "xla/service/shaped_buffer.h" #include "xla/tests/client_library_test_base.h" +#include "xla/tsl/util/command_line_flags.h" #include "xla/types.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" @@ -61,6 +63,12 @@ extern int eup_version; // Get the TPU EUP version (if it was provided). int GetEupVersion(); +// Return if the user specified dumping all tested values with their expected +// and actual results. +bool ShouldDumpValues(); + +void AddExhaustiveFlags(std::vector& flag_list); + // Determines if the real component of the complex number is subnormal (either // sign). // @@ -257,7 +265,8 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { explicit ExhaustiveOpTestBase() : ty_(T), platform_(client_->platform()->Name()), - eup_version_(xla::exhaustive_op_test::GetEupVersion()) { + eup_version_(xla::exhaustive_op_test::GetEupVersion()), + should_dump_values_(xla::exhaustive_op_test::ShouldDumpValues()) { SetFastMathDisabled(true); // Run all HLO passes. In particular, constant folding is disabled by @@ -295,6 +304,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { TF_ASSERT_OK_AND_ASSIGN(XlaComputation comp, builder.Build()); TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, RunComputationHelper(comp, input_literals)); + ExpectNear(input_literals, result_literal, evaluate_op, error_spec_gen, check_valid_range); } @@ -634,6 +644,9 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { // // XLA:GPU preserves denormal signs, but other backends don't. bool relaxed_denormal_signs_ = platform_ != "CUDA"; + + // Indicates if files of the expected and actual values should be dumped. + bool should_dump_values_ = false; }; // Represents a set of 64 bit chunks by representing the starting bit chunk, diff --git a/xla/tests/exhaustive/exhaustive_test_main.cc b/xla/tests/exhaustive/exhaustive_test_main.cc index 70588bc8e8a120..cc1bc9dd5533e8 100644 --- a/xla/tests/exhaustive/exhaustive_test_main.cc +++ b/xla/tests/exhaustive/exhaustive_test_main.cc @@ -18,9 +18,23 @@ limitations under the License. // the --benchmark_filter flag which specifies which benchmarks to run, // we will either run benchmarks or run the gtest tests in the program. +#include +#include + +#include "xla/tests/exhaustive/exhaustive_op_test_utils.h" +#include "xla/tsl/util/command_line_flags.h" +#include "tsl/platform/logging.h" #include "tsl/platform/test.h" GTEST_API_ int main(int argc, char** argv) { + std::vector flag_list; + xla::exhaustive_op_test::AddExhaustiveFlags(flag_list); + std::string usage = tsl::Flags::Usage(argv[0], flag_list); + if (!tsl::Flags::Parse(&argc, argv, flag_list)) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/xla/tests/exhaustive/exhaustive_unary_complex_test.cc b/xla/tests/exhaustive/exhaustive_unary_complex_test.cc index b7cc275bfff11e..e34d5cf2a3eb05 100644 --- a/xla/tests/exhaustive/exhaustive_unary_complex_test.cc +++ b/xla/tests/exhaustive/exhaustive_unary_complex_test.cc @@ -34,6 +34,7 @@ limitations under the License. namespace xla { namespace exhaustive_op_test { +namespace { // T is the Primitive Type of the complex number // Test parameter is a tuple containing @@ -325,5 +326,6 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn( GetFpValuesForMagnitudeExtremeNormals(40000, 2000)))); +} // namespace } // namespace exhaustive_op_test } // namespace xla diff --git a/xla/tests/exhaustive/exhaustive_unary_f64_test.cc b/xla/tests/exhaustive/exhaustive_unary_f64_test.cc index 8e81769afe8ea0..0792b017813416 100644 --- a/xla/tests/exhaustive/exhaustive_unary_f64_test.cc +++ b/xla/tests/exhaustive/exhaustive_unary_f64_test.cc @@ -34,6 +34,7 @@ limitations under the License. namespace xla { namespace exhaustive_op_test { +namespace { // Exhaustive test for unary operations for double. // @@ -146,5 +147,6 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(GetFpValuesForMagnitudeExtremeNormals( 4000000000ull, 16000000))); +} // namespace } // namespace exhaustive_op_test } // namespace xla From 8a5332de8ef5286a966cced803b2543e835c93d8 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Wed, 31 Jul 2024 13:44:19 -0700 Subject: [PATCH 337/376] Remove GpuStream::platform_specific_stream method in favor of the more widely-used ::gpu_stream method. PiperOrigin-RevId: 658127980 --- xla/stream_executor/cuda/cuda_executor.cc | 22 +++++++++++----------- xla/stream_executor/gpu/gpu_stream.h | 6 ------ xla/stream_executor/rocm/rocm_executor.cc | 16 ++++++++-------- 3 files changed, 19 insertions(+), 25 deletions(-) diff --git a/xla/stream_executor/cuda/cuda_executor.cc b/xla/stream_executor/cuda/cuda_executor.cc index 3deea1a870578f..957bfe20386fbb 100644 --- a/xla/stream_executor/cuda/cuda_executor.cc +++ b/xla/stream_executor/cuda/cuda_executor.cc @@ -665,13 +665,13 @@ void GpuExecutor::DeallocateStream(Stream* stream) { dnn_->NotifyStreamDestroyed(stream); } } - GpuStream* cuda_stream = AsGpuStream(stream); + GpuStream* gpu_stream = AsGpuStream(stream); absl::MutexLock l(&alive_gpu_streams_mu_); - alive_gpu_streams_.erase(cuda_stream->platform_specific_stream()); - if (!cuda_stream->IsIdle()) { + alive_gpu_streams_.erase(gpu_stream->gpu_stream()); + if (!gpu_stream->IsIdle()) { LOG(ERROR) << "Deallocating stream with pending work"; } - cuda_stream->Destroy(); + gpu_stream->Destroy(); } absl::Status GpuExecutor::BlockHostUntilDone(Stream* stream) { @@ -805,20 +805,20 @@ absl::StatusOr> GpuExecutor::CreateEvent() { absl::StatusOr> GpuExecutor::CreateStream( std::optional> priority) { - auto gpu_stream = std::make_unique(this); + auto stream = std::make_unique(this); if (priority.has_value()) { if (std::holds_alternative(*priority)) { - gpu_stream->SetPriority(std::get(*priority)); + stream->SetPriority(std::get(*priority)); } else { - gpu_stream->SetPriority(std::get(*priority)); + stream->SetPriority(std::get(*priority)); } } absl::MutexLock l(&alive_gpu_streams_mu_); - bool init_worked = gpu_stream->Init(); + bool init_worked = stream->Init(); if (init_worked) { - auto platform_specific_stream = gpu_stream->platform_specific_stream(); - alive_gpu_streams_[platform_specific_stream] = gpu_stream.get(); - return std::move(gpu_stream); + auto gpu_stream = stream->gpu_stream(); + alive_gpu_streams_[gpu_stream] = stream.get(); + return std::move(stream); } else { return absl::InvalidArgumentError("Failed to initialize gpu stream"); } diff --git a/xla/stream_executor/gpu/gpu_stream.h b/xla/stream_executor/gpu/gpu_stream.h index 18b77fb888481b..bbd8464c4373e5 100644 --- a/xla/stream_executor/gpu/gpu_stream.h +++ b/xla/stream_executor/gpu/gpu_stream.h @@ -58,12 +58,6 @@ class GpuStream : public StreamCommon { parent()->DeallocateStream(this); } - // Returns a pointer to a platform specific stream associated with this object - // if it exists, or nullptr otherwise. This is available via Stream public API - // as Stream::PlatformSpecificHandle, and should not be accessed directly - // outside of a StreamExecutor package. - void* platform_specific_stream() const { return gpu_stream_; } - // Explicitly initialize the CUDA resources associated with this stream. bool Init(); diff --git a/xla/stream_executor/rocm/rocm_executor.cc b/xla/stream_executor/rocm/rocm_executor.cc index 366d2da583de31..8f3d28228a9b86 100644 --- a/xla/stream_executor/rocm/rocm_executor.cc +++ b/xla/stream_executor/rocm/rocm_executor.cc @@ -499,7 +499,7 @@ void GpuExecutor::DeallocateStream(Stream* stream) { } GpuStream* rocm_stream = AsGpuStream(stream); absl::MutexLock l(&alive_gpu_streams_mu_); - alive_gpu_streams_.erase(rocm_stream->platform_specific_stream()); + alive_gpu_streams_.erase(rocm_stream->gpu_stream()); if (!rocm_stream->IsIdle()) { LOG(ERROR) << "Deallocating stream with pending work"; } @@ -642,20 +642,20 @@ absl::StatusOr> GpuExecutor::CreateEvent() { absl::StatusOr> GpuExecutor::CreateStream( std::optional> priority) { - auto gpu_stream = std::make_unique(this); + auto stream = std::make_unique(this); if (priority.has_value()) { if (std::holds_alternative(*priority)) { - gpu_stream->SetPriority(std::get(*priority)); + stream->SetPriority(std::get(*priority)); } else { - gpu_stream->SetPriority(std::get(*priority)); + stream->SetPriority(std::get(*priority)); } } absl::MutexLock l(&alive_gpu_streams_mu_); - bool init_worked = gpu_stream->Init(); + bool init_worked = stream->Init(); if (init_worked) { - auto platform_specific_stream = gpu_stream->platform_specific_stream(); - alive_gpu_streams_[platform_specific_stream] = gpu_stream.get(); - return std::move(gpu_stream); + auto gpu_stream = stream->gpu_stream(); + alive_gpu_streams_[gpu_stream] = stream.get(); + return std::move(stream); } else { return absl::InvalidArgumentError("Failed to initialize GPU stream"); } From a3bc38ac798daf0e5999dc98982d97dd1a84379f Mon Sep 17 00:00:00 2001 From: Jieying Luo Date: Wed, 31 Jul 2024 13:51:46 -0700 Subject: [PATCH 338/376] Remove the check for PJRT C API version in GetCompiledMemoryStats. PiperOrigin-RevId: 658130638 --- xla/pjrt/c/pjrt_c_api_helpers.cc | 7 ------- 1 file changed, 7 deletions(-) diff --git a/xla/pjrt/c/pjrt_c_api_helpers.cc b/xla/pjrt/c/pjrt_c_api_helpers.cc index fcfc9119e9ec79..b9508cf24950b4 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers.cc +++ b/xla/pjrt/c/pjrt_c_api_helpers.cc @@ -994,13 +994,6 @@ absl::Span DeviceDescriptions( absl::StatusOr GetCompiledMemoryStats( const PJRT_Api* api, PJRT_Executable* executable) { - // TODO(jieying): To be removed after 03/2024. - if (api->pjrt_api_version.major_version == 0 && - api->pjrt_api_version.minor_version < 40) { - return absl::UnimplementedError( - "GetCompiledMemoryStats requires a plugin with PJRT C API version >= " - "0.40"); - } PJRT_Executable_GetCompiledMemoryStats_Args args; args.struct_size = PJRT_Executable_GetCompiledMemoryStats_Args_STRUCT_SIZE; args.extension_start = nullptr; From 595c6b278c3cac3f2fd878a0ead8a3fec7048cbd Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 31 Jul 2024 14:18:30 -0700 Subject: [PATCH 339/376] [xla:cpu] Optimize ThunkExecutor::Execute part #2 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Use std::aligned_storage_t trick to avoid default-initializing Node struct on a hot path. name old cpu/op new cpu/op delta BM_SelectAndScatterF32/128/process_time 791µs ± 4% 720µs ± 2% -8.93% BM_SelectAndScatterF32/256/process_time 3.20ms ± 4% 2.96ms ± 2% -7.46% BM_SelectAndScatterF32/512/process_time 13.7ms ± 5% 12.8ms ± 2% -6.80% name old time/op new time/op delta BM_SelectAndScatterF32/128/process_time 790µs ± 5% 719µs ± 1% -9.00% BM_SelectAndScatterF32/256/process_time 3.20ms ± 3% 2.96ms ± 1% -7.58% BM_SelectAndScatterF32/512/process_time 13.2ms ± 4% 12.3ms ± 1% -6.82% PiperOrigin-RevId: 658139935 --- .bazelrc | 7 +++++++ third_party/tsl/.bazelrc | 7 +++++++ xla/service/cpu/runtime/thunk_executor.cc | 13 +++++++------ xla/service/cpu/runtime/thunk_executor.h | 14 +++++++++++++- 4 files changed, 34 insertions(+), 7 deletions(-) diff --git a/.bazelrc b/.bazelrc index b94693e05efab8..11783a8012ddb2 100644 --- a/.bazelrc +++ b/.bazelrc @@ -351,6 +351,13 @@ build:windows --features=archive_param_file build:windows --copt=/d2ReducedOptimizeHugeFunctions build:windows --host_copt=/d2ReducedOptimizeHugeFunctions +# Before VS 2017 15.8, the member "type" would non-conformingly have an +# alignment of only alignof(max_align_t). VS 2017 15.8 was fixed to handle this +# correctly, but the fix inherently changes layout and breaks binary +# compatibility (*only* for uses of aligned_storage with extended alignments). +build:windows --copt=-D_ENABLE_EXTENDED_ALIGNED_STORAGE +build:windows --host_copt=-D_ENABLE_EXTENDED_ALIGNED_STORAGE + # Enable the runfiles symlink tree on Windows. This makes it possible to build # the pip package on Windows without an intermediate data-file archive, as the # build_pip_package script in its current form (as of Aug 2023) uses the diff --git a/third_party/tsl/.bazelrc b/third_party/tsl/.bazelrc index b94693e05efab8..11783a8012ddb2 100644 --- a/third_party/tsl/.bazelrc +++ b/third_party/tsl/.bazelrc @@ -351,6 +351,13 @@ build:windows --features=archive_param_file build:windows --copt=/d2ReducedOptimizeHugeFunctions build:windows --host_copt=/d2ReducedOptimizeHugeFunctions +# Before VS 2017 15.8, the member "type" would non-conformingly have an +# alignment of only alignof(max_align_t). VS 2017 15.8 was fixed to handle this +# correctly, but the fix inherently changes layout and breaks binary +# compatibility (*only* for uses of aligned_storage with extended alignments). +build:windows --copt=-D_ENABLE_EXTENDED_ALIGNED_STORAGE +build:windows --host_copt=-D_ENABLE_EXTENDED_ALIGNED_STORAGE + # Enable the runfiles symlink tree on Windows. This makes it possible to build # the pip package on Windows without an intermediate data-file archive, as the # build_pip_package script in its current form (as of Aug 2023) uses the diff --git a/xla/service/cpu/runtime/thunk_executor.cc b/xla/service/cpu/runtime/thunk_executor.cc index f25fd6119a284d..805840ad855e93 100644 --- a/xla/service/cpu/runtime/thunk_executor.cc +++ b/xla/service/cpu/runtime/thunk_executor.cc @@ -122,6 +122,9 @@ absl::StatusOr ThunkExecutor::Create( return ThunkExecutor(std::move(thunk_sequence), std::move(defs), options); } +ThunkExecutor::ExecuteState::Node::Node(const NodeDef& node_def) + : counter(node_def.in_edges.size()), out_edges(&node_def.out_edges) {} + ThunkExecutor::ExecuteState::ExecuteState(ThunkExecutor* executor, Thunk::TaskRunner* runner) : executor(executor), @@ -133,11 +136,9 @@ ThunkExecutor::ExecuteState::ExecuteState(ThunkExecutor* executor, DCHECK(runner == nullptr || static_cast(*runner)) << "`runner` must be nullptr or a valid TaskRunner"; - Node* node = nodes.data(); + NodeStorage* node = nodes.data(); for (const NodeDef& node_def : executor->nodes_defs()) { - node->counter.store(node_def.in_edges.size(), std::memory_order_release); - node->out_edges = &node_def.out_edges; - ++node; + new (node++) Node(node_def); } } @@ -271,7 +272,7 @@ void ThunkExecutor::Execute(ExecuteState* state, for (int64_t i = 0; i < ready_queue.size(); ++i) { NodeId id = ready_queue[i]; - ExecuteState::Node& node = state->nodes[id]; + ExecuteState::Node& node = state->node(id); int64_t cnt = node.counter.load(std::memory_order_acquire); DCHECK_EQ(cnt, 0) << "Node counter must be 0"; // Crash Ok @@ -375,7 +376,7 @@ void ThunkExecutor::ProcessOutEdges( // Append ready nodes to the back of the ready queue. for (NodeId out_edge : *node.out_edges) { - ExecuteState::Node& out_node = state->nodes[out_edge]; + ExecuteState::Node& out_node = state->node(out_edge); int64_t cnt = out_node.counter.fetch_sub(1, std::memory_order_release); DCHECK_GE(cnt, 1) << "Node counter can't drop below 0"; diff --git a/xla/service/cpu/runtime/thunk_executor.h b/xla/service/cpu/runtime/thunk_executor.h index 67a66c422bf5c6..10df02c45a9383 100644 --- a/xla/service/cpu/runtime/thunk_executor.h +++ b/xla/service/cpu/runtime/thunk_executor.h @@ -22,6 +22,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/base/thread_annotations.h" @@ -113,16 +114,27 @@ class ThunkExecutor { // At run time NodeDef instantiated as a Node with an atomic counter that // drops to zero when all `in_edges` are ready. struct Node { + explicit Node(const NodeDef& node_def); + alignas(kAtomicAlignment) std::atomic counter; const std::vector* out_edges; }; + static_assert(std::is_trivially_destructible_v, + "Node must be trivially destructible"); + + // We use indirection via NodeStorage to be able to allocate uninitialized + // memory and do not pay the cost of default initializing all nodes. + using NodeStorage = std::aligned_storage_t; + ExecuteState(ThunkExecutor* executor, Thunk::TaskRunner* runner); + Node& node(NodeId id) { return *reinterpret_cast(&nodes[id]); } + ThunkExecutor* executor; Thunk::TaskRunner* runner; - absl::FixedArray nodes; + absl::FixedArray nodes; tsl::AsyncValueRef execute_event; // Once the number of pending sink nodes drops to zero, the execution is From 331d668cf959e8fab53d328ede7ff9b65bd0e1ec Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 31 Jul 2024 14:34:48 -0700 Subject: [PATCH 340/376] [xla:cpu] Switch XLA:CPU runtime to thunks interpreter With this change XLA:CPU instead of compiling one LLVM function for the whole HLO module compiles separate functions for different fusions and runs them via the interpreter-like runtime. This can change numerics because of slightly different LLVM IR and missed cross-fusion optimizations. If this breaks your tests, they likely have to relax numerical error tolerance. Another potential issue is performance regressions for while loops with large number of iterations and small computation, as instead of compiling, we run such loops in interpreter. We plan to fix it in the future. To disable thunks runtime set env variable: XLA_FLAGS=--xla_cpu_use_thunk_runtime=false. PiperOrigin-RevId: 658145258 --- xla/debug_options_flags.cc | 2 +- xla/service/cpu/tests/onednn_convolution_test.cc | 6 ++++++ xla/service/cpu/tests/onednn_layer_norm_test.cc | 8 +++++++- xla/service/cpu/tests/onednn_matmul_test.cc | 16 +++++++++++----- xla/service/cpu/tests/onednn_softmax_test.cc | 6 ++++++ 5 files changed, 31 insertions(+), 7 deletions(-) diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc index 28ab8fa6e390c5..79e716b7aca4bb 100644 --- a/xla/debug_options_flags.cc +++ b/xla/debug_options_flags.cc @@ -82,7 +82,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { #ifdef XLA_CPU_USE_ACL opts.set_xla_cpu_use_acl(true); #endif - opts.set_xla_cpu_use_thunk_runtime(false); + opts.set_xla_cpu_use_thunk_runtime(true); opts.set_xla_cpu_enable_concurrency_optimized_scheduler(false); opts.set_xla_cpu_prefer_vector_width(256); diff --git a/xla/service/cpu/tests/onednn_convolution_test.cc b/xla/service/cpu/tests/onednn_convolution_test.cc index 6428e31c1d2fbb..50c0e8f4f645e7 100644 --- a/xla/service/cpu/tests/onednn_convolution_test.cc +++ b/xla/service/cpu/tests/onednn_convolution_test.cc @@ -34,6 +34,12 @@ namespace cpu { class ConvolutionTest : public HloTestBase { protected: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_cpu_use_thunk_runtime(false); + return debug_options; + } + const char* conv_rewrite_str_ = R"( ; CHECK: custom_call_target="__onednn$convolution", ; CHECK: backend_config={ diff --git a/xla/service/cpu/tests/onednn_layer_norm_test.cc b/xla/service/cpu/tests/onednn_layer_norm_test.cc index 39913542d2d0ee..92ca5061724faf 100644 --- a/xla/service/cpu/tests/onednn_layer_norm_test.cc +++ b/xla/service/cpu/tests/onednn_layer_norm_test.cc @@ -24,6 +24,12 @@ namespace { class LayerNormTest : public HloTestBase { protected: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_cpu_use_thunk_runtime(false); + return debug_options; + } + const char* onednn_layer_norm_ = R"( ; CHECK: custom_call_target="__onednn$layernorm", @@ -95,7 +101,7 @@ TEST_F(LayerNormTest, LayerNormTest0_FP32) { common_hlo_region_ + R"( ENTRY main { Arg_0.1 = f32[84,197,768]{2,1,0} parameter(0), sharding={replicated} - + )" + common_hlo_entry_computation_block_ + R"( ROOT add.338 = f32[84,197,768]{2,1,0} add(multiply.331, subtract.337) diff --git a/xla/service/cpu/tests/onednn_matmul_test.cc b/xla/service/cpu/tests/onednn_matmul_test.cc index c31ed5c2fc6ac7..d7fb39f0d33a90 100644 --- a/xla/service/cpu/tests/onednn_matmul_test.cc +++ b/xla/service/cpu/tests/onednn_matmul_test.cc @@ -36,6 +36,12 @@ namespace cpu { class MatmulTest : public HloTestBase { protected: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_cpu_use_thunk_runtime(false); + return debug_options; + } + const char* fused_matmul_bias_ = R"( ; CHECK: custom_call_target="__onednn$matmul", ; CHECK: backend_config={ @@ -225,7 +231,7 @@ TEST_F(MatmulTest, SimpleTestF32WithBiasAddFusion1) { TEST_F(MatmulTest, SimpleTestF32WithBiasAddFusion2) { const char* matmul_module_str = R"( HloModule matmul.biasadd.test.f32 - + ENTRY matmul.biasadd.test.f32 { arg0.1 = f32[400,300] parameter(0), parameter_replication={false} reshape.2 = f32[400,300] reshape(arg0.1) @@ -1128,7 +1134,7 @@ TEST_F(MatmulTest, SIGMOIDTestF32) { const.0 = f32[32]{0} constant(5) bcast.0 = f32[32,32,4,32] broadcast(const.0), dimensions={3} add.0 = f32[32,32,4,32] add(onednn.matmul.0, bcast.0) - + const.1 = f32[] constant(1) bcast.1 = f32[32,32,4,32] broadcast(const.1), dimensions={} negate.0 = f32[32,32,4,32] negate(add.0) @@ -1149,7 +1155,7 @@ TEST_F(MatmulTest, SIGMOIDTestBF16) { } const char* matmul_module_str = R"( HloModule matmul.bias.sigmoid.test.bf16 - + ENTRY matmul.bias.sigmoid.test.bf16 { arg.0 = f32[32,32,4,16] parameter(0), parameter_replication={false} convert.0 = bf16[32,32,4,16] convert(arg.0) @@ -1180,7 +1186,7 @@ TEST_F(MatmulTest, SIGMOIDTestF16) { } const char* matmul_module_str = R"( HloModule matmul.bias.sigmoid.test.f16 - + ENTRY matmul.bias.sigmoid.test.f16 { arg.0 = f32[32,32,4,16] parameter(0), parameter_replication={false} convert.0 = f16[32,32,4,16] convert(arg.0) @@ -1230,7 +1236,7 @@ TEST_F(MatmulTest, SimpleTestBF16Gemv2) { const char* matmul_module_str = R"( HloModule matmul.test.bf16 - + ENTRY matmul.test.bf16 { arg.0 = bf16[100,300,300] parameter(0) arg.1 = bf16[300] parameter(1) diff --git a/xla/service/cpu/tests/onednn_softmax_test.cc b/xla/service/cpu/tests/onednn_softmax_test.cc index b7e43731cfb673..1fff5d88a736e5 100644 --- a/xla/service/cpu/tests/onednn_softmax_test.cc +++ b/xla/service/cpu/tests/onednn_softmax_test.cc @@ -47,6 +47,12 @@ class OneDnnSoftmaxTest : public HloTestBase, public ::testing::WithParamInterface> { protected: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_cpu_use_thunk_runtime(false); + return debug_options; + } + const char* onednn_softmax_ = R"( ; CHECK: custom_call_target="__onednn$softmax" From 4e00ababef0157a38a077f7c829025629ceb3be5 Mon Sep 17 00:00:00 2001 From: Subhankar Shah Date: Wed, 31 Jul 2024 14:37:49 -0700 Subject: [PATCH 341/376] Add function to print a compact 2D map of occupied heap memory vs time as ASCII art for easier debugging. PiperOrigin-RevId: 658146172 --- xla/service/heap_simulator/BUILD | 14 +- xla/service/heap_simulator/heap_simulator.cc | 148 +++++++++++++++++- xla/service/heap_simulator/heap_simulator.h | 51 +++++- .../heap_simulator/heap_simulator_test.cc | 50 +++++- 4 files changed, 237 insertions(+), 26 deletions(-) diff --git a/xla/service/heap_simulator/BUILD b/xla/service/heap_simulator/BUILD index e9873e948e0989..2a5e0666799ceb 100644 --- a/xla/service/heap_simulator/BUILD +++ b/xla/service/heap_simulator/BUILD @@ -42,21 +42,18 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_live_range", "//xla/service:buffer_value", - "//xla/service:buffer_value_containers", "//xla/service:hlo_alias_analysis", - "//xla/service:hlo_buffer", - "//xla/service:hlo_dataflow_analysis", - "//xla/service:hlo_ordering", "//xla/service:hlo_proto_cc", "//xla/service:hlo_value", + "//xla/service:logical_buffer", "//xla/service:time_utils", - "//xla/service:tuple_points_to_analysis", - "//xla/service/memory_space_assignment:repacking", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -71,19 +68,14 @@ xla_cc_test( deps = [ ":allocation_block", ":heap_simulator", - "//xla:literal", - "//xla:status_macros", "//xla/hlo/ir:hlo", "//xla/service:buffer_value", - "//xla/service:hlo_ordering", "//xla/service:hlo_parser", "//xla/service:hlo_value", - "//xla/service:tuple_points_to_analysis", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:test", diff --git a/xla/service/heap_simulator/heap_simulator.cc b/xla/service/heap_simulator/heap_simulator.cc index a45528e6c1ac2f..fc319e681f769a 100644 --- a/xla/service/heap_simulator/heap_simulator.cc +++ b/xla/service/heap_simulator/heap_simulator.cc @@ -34,8 +34,10 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" #include "absl/functional/any_invocable.h" #include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -46,11 +48,110 @@ limitations under the License. #include "xla/map_util.h" #include "xla/service/heap_simulator/allocation_block.h" #include "xla/service/hlo_value.h" -#include "xla/service/memory_space_assignment/repacking.h" #include "xla/service/time_utils.h" #include "xla/util.h" namespace xla { +namespace { + +constexpr int64_t kMaxMemoryMapDimensionSize = 100; + +struct AsciiMemoryMapParameters { + int64_t memory_block_size = 1; + int64_t end_of_last_occupied_chunk = -1; +}; + +// Given a set of BufferIntervalTreeNodes, returns the best memory block size(to +// visually represent all chunks in a compact fashion) and the maximum chunk end +// of all occupied chunks. The best memory block size is the greatest common +// divisor of all chunk offsets and chunk ends. These are parameters required to +// construct a compact memory map. +AsciiMemoryMapParameters GetAsciiMemoryMapParameters( + std::vector& nodes) { + CHECK(!nodes.empty()); + int64_t min_chunk_offset = std::numeric_limits::max(); + int64_t end_of_last_occupied_chunk = -1; + int64_t memory_block_size = nodes.front()->chunk.offset; + for (const BufferIntervalTreeNode* node : nodes) { + min_chunk_offset = std::min(min_chunk_offset, node->chunk.offset); + end_of_last_occupied_chunk = + std::max(end_of_last_occupied_chunk, node->chunk.chunk_end()); + memory_block_size = std::gcd(memory_block_size, node->chunk.offset); + memory_block_size = std::gcd(memory_block_size, node->chunk.chunk_end()); + } + VLOG(3) << " min_chunk_offset: " << min_chunk_offset + << " end_of_last_occupied_chunk: " << end_of_last_occupied_chunk + << " memory_block_size: " << memory_block_size; + return {memory_block_size, end_of_last_occupied_chunk}; +} + +// Returns a memory map for the given time interval [start, end]. +// The memory map is a 2D array of size [n, m], where n is the number of memory +// blocks and m is the number of time steps. Each row represents a memory block +// and each column represents a time step. The value at (i, j) indicates whether +// there is a buffer occupying the entire memory block at time j. +std::vector> GetMemoryMap( + int64_t start, int64_t end, int64_t memory_block_size, + int64_t num_memory_blocks, + std::vector& nodes) { + int64_t total_time = end - start + 1; + std::vector> memory_map( + num_memory_blocks, std::vector(total_time, false)); + for (const BufferIntervalTreeNode* node : nodes) { + for (int64_t i = node->chunk.offset / memory_block_size; + i < node->chunk.chunk_end() / memory_block_size; ++i) { + for (int64_t j = std::max(node->start - start, int64_t{0}); + j <= std::min(node->end - start, end - start); ++j) { + memory_map[i][j] = true; + } + } + } + return memory_map; +} + +// Given a list of BufferIntervalTreeNodes, returns a string representation of +// the nodes. +std::string BufferIntervalTreeNodesToString( + absl::Span nodes) { + std::string output; + for (const BufferIntervalTreeNode* node : nodes) { + absl::StrAppend(&output, node->ToString(), "\n"); + } + return output; +} + +// Returns a string representation of the memory map of occupied memory blocks +// for the given time interval [start, end]. +std::string MemoryMapToString(int64_t start, int64_t end, + int64_t memory_block_size, int64_t group_size, + std::vector>& memory_map) { + int64_t num_memory_blocks = memory_map.size(); + int64_t total_time = memory_map.front().size(); + std::string output = "\n"; + absl::StrAppend(&output, "Memory map for time: [", start, ",", end, + "], memory_block_size: ", memory_block_size, + ", group_size: ", group_size, "\n\n"); + for (int64_t i = num_memory_blocks - 1; i >= 0; --i) { + for (int64_t j = 0; j < total_time; ++j) { + if (group_size && j % group_size == 0) { + absl::StrAppend(&output, " "); + } + absl::StrAppend(&output, memory_map[i][j] ? "#" : "."); + } + absl::StrAppend(&output, " ", std::to_string((i + 1) * memory_block_size), + "\n"); + } + for (int64_t j = start; j <= end; ++j) { + if (group_size && j % group_size == 0) { + absl::StrAppend(&output, " "); + } + absl::StrAppend(&output, std::to_string(j % 10)); + } + absl::StrAppend(&output, "\n\n"); + return output; +} + +} // namespace using absl::flat_hash_map; using absl::flat_hash_set; @@ -73,6 +174,11 @@ std::string HeapSimulator::Chunk::ToString() const { return absl::StrCat("[", offset, ",", chunk_end(), ")"); } +std::string BufferIntervalTreeNode::ToString() const { + return absl::StrCat("start: ", start, " end: ", end, + " chunk: ", chunk.ToString()); +} + bool HeapSimulator::Chunk::OverlapsWith(Chunk other_chunk) const { CHECK_NE(size, 0); CHECK_NE(other_chunk.size, 0); @@ -848,6 +954,16 @@ bool BufferIntervalTree::Remove(int64_t start, int64_t end, std::vector BufferIntervalTree::ChunksOverlappingInTime( int64_t start, int64_t end) const { std::vector result; + for (const BufferIntervalTreeNode* node : + NodesOverlappingInTime(start, end)) { + result.push_back(node->chunk); + } + return result; +} + +std::vector +BufferIntervalTree::NodesOverlappingInTime(int64_t start, int64_t end) const { + std::vector result; if (root_ == nullptr) { return result; } @@ -863,7 +979,7 @@ std::vector BufferIntervalTree::ChunksOverlappingInTime( visiting_stack.push_back(top->left); } if (top->start <= end && top->end >= start) { - result.push_back(top->chunk); + result.push_back(top); } if (end < top->start) { continue; @@ -875,6 +991,34 @@ std::vector BufferIntervalTree::ChunksOverlappingInTime( return result; } +std::string BufferIntervalTree::NodesOverlappingInTimeToAsciiArt( + int64_t start, int64_t end, int64_t group_size) const { + std::vector nodes = + NodesOverlappingInTime(start, end); + if (nodes.empty()) { + return "No nodes overlapping in time. Memory is free!"; + } + auto [memory_block_size, end_of_last_occupied_chunk] = + GetAsciiMemoryMapParameters(nodes); + CHECK_GE(end_of_last_occupied_chunk, 0); + CHECK_NE(memory_block_size, 0); + int64_t total_time = end - start + 1; + int64_t num_memory_blocks = end_of_last_occupied_chunk / memory_block_size; + if (total_time > kMaxMemoryMapDimensionSize || + num_memory_blocks > kMaxMemoryMapDimensionSize) { + std::string output; + absl::StrAppend( + &output, + "\nCannot print memory usage to ASCII art. Printing nodes instead!\n\n", + BufferIntervalTreeNodesToString(nodes)); + return output; + } + std::vector> memory_map = + GetMemoryMap(start, end, memory_block_size, num_memory_blocks, nodes); + return MemoryMapToString(start, end, memory_block_size, group_size, + memory_map); +} + template std::string GlobalDecreasingSizeBestFitHeap::BufferInterval::ToString() const { diff --git a/xla/service/heap_simulator/heap_simulator.h b/xla/service/heap_simulator/heap_simulator.h index dfa62f018ae133..09e12d2aca7042 100644 --- a/xla/service/heap_simulator/heap_simulator.h +++ b/xla/service/heap_simulator/heap_simulator.h @@ -24,7 +24,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -43,16 +42,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/service/buffer_value.h" -#include "xla/service/buffer_value_containers.h" #include "xla/service/heap_simulator/allocation_block.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_alias_analysis.h" -#include "xla/service/hlo_buffer.h" -#include "xla/service/hlo_dataflow_analysis.h" -#include "xla/service/hlo_ordering.h" #include "xla/service/hlo_value.h" -#include "xla/service/memory_space_assignment/repacking.h" -#include "xla/service/tuple_points_to_analysis.h" +#include "xla/service/logical_buffer.h" namespace xla { @@ -364,6 +358,8 @@ struct BufferIntervalTreeNode { BufferIntervalTreeNode* right; // parent BufferIntervalTreeNode* parent; + + std::string ToString() const; }; // An interval tree that can query buffers overlapping in time. @@ -383,7 +379,48 @@ class BufferIntervalTree { BufferIntervalTreeNode* GetRoot() { return root_; } + // Returns a compact 2D view of memory usage over time. + // X axis is time, Y axis is memory. + // + // Say there are 3 buffers in the heap: + // - Buffer 1: memory block [0, 16), time interval [15, 25] + // - Buffer 2: memory block [16, 48), time interval [15, 19] + // - Buffer 3: memory block [32, 64), time interval [20, 22] + // + // NodesOverlappingInTimeToAsciiArt(/*start=*/18, /*end=*/23, + // /*group_size=*/3) returns: + // + // Memory map for time: [18,23], memory_block_size: 16, group_size: 3 + // + // ..# ##. 64 + // ### ##. 48 + // ##. ... 32 + // ### ### 16 + // 890 123 + // + // Explanation: + // + // The functions decides a memory block size of 16 would be most compact to + // display all the buffers. + // '#' indicates used and '.' indicates free memory. + // + // ..# ##. 64 "64" indicates memory block [48,64) + // ### ##. 48 "48" indicates memory block [32,48) + // ##. ... 32 "32" indicates memory block [16,32) + // ### ### 16 "16" indicates memory block [0,16) + // 890 123 + // + // "890 123" indicate the last digits of time instants 18, 19, 20, 21, 22, 23. + // Only the last digit is shown for compactness. + // `group_size=3` inserts spaces after every 3 columns (time instants). + // All the memory blocks beyond 64 are free for time interval [18,23]. + std::string NodesOverlappingInTimeToAsciiArt(int64_t start, int64_t end, + int64_t group_size = 0) const; + private: + std::vector NodesOverlappingInTime( + int64_t start, int64_t end) const; + BufferIntervalTreeNode* root_ = nullptr; std::list node_storage_; }; diff --git a/xla/service/heap_simulator/heap_simulator_test.cc b/xla/service/heap_simulator/heap_simulator_test.cc index 27f0261d103b66..cff0e2f3a72547 100644 --- a/xla/service/heap_simulator/heap_simulator_test.cc +++ b/xla/service/heap_simulator/heap_simulator_test.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include #include @@ -26,19 +25,14 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" -#include "absl/strings/str_join.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/literal.h" #include "xla/service/buffer_value.h" #include "xla/service/heap_simulator/allocation_block.h" -#include "xla/service/hlo_ordering.h" #include "xla/service/hlo_parser.h" #include "xla/service/hlo_value.h" -#include "xla/service/tuple_points_to_analysis.h" -#include "xla/status_macros.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/logging.h" @@ -47,6 +41,9 @@ limitations under the License. namespace xla { namespace { +using ::testing::HasSubstr; +using ::testing::StrEq; + class MinimumMemoryForSequenceTest : public HloTestBase {}; TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { @@ -2019,6 +2016,47 @@ TEST_F(IntervalTreeTest, ThreeLevelsRightLeftChunkDifferent) { ASSERT_EQ(tree.GetRoot(), nullptr); } +TEST_F(IntervalTreeTest, BufferIntervalTreeToAsciiArt) { + // Buffer 1: memory block [0, 16), time interval [15, 25] + // Buffer 2: memory block [16, 48), time interval [15, 19] + // Buffer 3: memory block [32, 64), time interval [20, 22] + BufferIntervalTree tree; + tree.Add(15, 25, HeapSimulator::Chunk::FromOffsetEnd(0, 16)); + tree.Add(15, 19, HeapSimulator::Chunk::FromOffsetEnd(16, 48)); + tree.Add(20, 22, HeapSimulator::Chunk::FromOffsetEnd(32, 64)); + std::string output = tree.NodesOverlappingInTimeToAsciiArt( + /*start=*/18, /*end=*/23, /*group_size=*/3); + EXPECT_THAT(output, HasSubstr("Memory map for time: [18,23], " + "memory_block_size: 16, group_size: 3")); + EXPECT_THAT(output, HasSubstr("..# ##. 64")); + EXPECT_THAT(output, HasSubstr("### ##. 48")); + EXPECT_THAT(output, HasSubstr("##. ... 32")); + EXPECT_THAT(output, HasSubstr("### ### 16")); + EXPECT_THAT(output, HasSubstr("890 123")); +} + +TEST_F(IntervalTreeTest, BufferIntervalTreeToAsciiArtTooLarge) { + BufferIntervalTree tree; + tree.Add(0, 4, HeapSimulator::Chunk::FromOffsetEnd(0, 128)); + tree.Add(5, 10, HeapSimulator::Chunk::FromOffsetEnd(1, 129)); + std::string output = tree.NodesOverlappingInTimeToAsciiArt( + /*start=*/0, /*end=*/10, /*group_size=*/3); + EXPECT_THAT( + output, + HasSubstr( + "Cannot print memory usage to ASCII art. Printing nodes instead!")); + EXPECT_THAT(output, HasSubstr("start: 0 end: 4 chunk: [0,128)")); + EXPECT_THAT(output, HasSubstr("start: 5 end: 10 chunk: [1,129)")); +} + +TEST_F(IntervalTreeTest, BufferIntervalTreeToAsciiArtFreeMemory) { + BufferIntervalTree tree; + tree.Add(5, 10, HeapSimulator::Chunk::FromOffsetEnd(0, 16)); + std::string output = tree.NodesOverlappingInTimeToAsciiArt( + /*start=*/0, /*end=*/4, /*group_size=*/10); + EXPECT_THAT(output, StrEq("No nodes overlapping in time. Memory is free!")); +} + class SlicedBufferIntervalTest : public ::testing::Test { public: using HeapTy = GlobalDecreasingSizeBestFitHeap; From 5c14e4aac7baf943f908865f16f8af05a2d27547 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Wed, 31 Jul 2024 15:18:41 -0700 Subject: [PATCH 342/376] Use GpuEvent in GpuStream rather than re-implementing much of its functionality in GpuStream. PiperOrigin-RevId: 658159374 --- xla/stream_executor/cuda/cuda_executor.cc | 3 ++- xla/stream_executor/gpu/BUILD | 1 + xla/stream_executor/gpu/gpu_stream.cc | 27 +++++++---------------- xla/stream_executor/gpu/gpu_stream.h | 12 +++++----- xla/stream_executor/rocm/rocm_executor.cc | 3 ++- 5 files changed, 19 insertions(+), 27 deletions(-) diff --git a/xla/stream_executor/cuda/cuda_executor.cc b/xla/stream_executor/cuda/cuda_executor.cc index 957bfe20386fbb..e43b419f1f9207 100644 --- a/xla/stream_executor/cuda/cuda_executor.cc +++ b/xla/stream_executor/cuda/cuda_executor.cc @@ -805,7 +805,8 @@ absl::StatusOr> GpuExecutor::CreateEvent() { absl::StatusOr> GpuExecutor::CreateStream( std::optional> priority) { - auto stream = std::make_unique(this); + TF_ASSIGN_OR_RETURN(auto event, CreateGpuEvent(/*allow_timing=*/false)); + auto stream = std::make_unique(this, std::move(event)); if (priority.has_value()) { if (std::holds_alternative(*priority)) { stream->SetPriority(std::get(*priority)); diff --git a/xla/stream_executor/gpu/BUILD b/xla/stream_executor/gpu/BUILD index 24def7c5e53eaf..bee21d83ac8541 100644 --- a/xla/stream_executor/gpu/BUILD +++ b/xla/stream_executor/gpu/BUILD @@ -309,6 +309,7 @@ gpu_only_cc_library( name = "gpu_stream_header", hdrs = ["gpu_stream.h"], deps = [ + ":gpu_event_header", ":gpu_executor_header", ":gpu_types_header", "//xla/stream_executor:device_memory", diff --git a/xla/stream_executor/gpu/gpu_stream.cc b/xla/stream_executor/gpu/gpu_stream.cc index 040e209441ac51..1f75dfac7ebaf1 100644 --- a/xla/stream_executor/gpu/gpu_stream.cc +++ b/xla/stream_executor/gpu/gpu_stream.cc @@ -60,9 +60,7 @@ bool GpuStream::Init() { priority)) { return false; } - return GpuDriver::InitEvent(parent_->gpu_context(), &completed_event_, - GpuDriver::EventFlags::kDisableTiming) - .ok(); + return true; } Stream::PlatformSpecificHandle GpuStream::platform_specific_handle() const { @@ -133,14 +131,12 @@ absl::Status GpuStream::Memcpy(void* host_dst, const DeviceMemoryBase& gpu_src, absl::Status GpuStream::WaitFor(Stream* other) { GpuStream* other_gpu = AsGpuStream(other); - GpuEventHandle other_completed_event = *(other_gpu->completed_event()); - TF_RETURN_IF_ERROR(GpuDriver::RecordEvent(parent_->gpu_context(), - other_completed_event, - AsGpuStreamValue(other_gpu))); - - if (GpuDriver::WaitStreamOnEvent(parent_->gpu_context(), - AsGpuStreamValue(this), - other_completed_event)) { + + GpuEvent* other_completed_event = other_gpu->completed_event(); + TF_RETURN_IF_ERROR(other_completed_event->Record(other_gpu->gpu_stream())); + + if (GpuDriver::WaitStreamOnEvent(parent_->gpu_context(), gpu_stream(), + other_completed_event->gpu_event())) { return absl::OkStatus(); } return absl::InternalError("Couldn't wait for stream."); @@ -177,14 +173,7 @@ absl::Status GpuStream::DoHostCallbackWithStatus( } void GpuStream::Destroy() { - if (completed_event_ != nullptr) { - absl::Status status = - GpuDriver::DestroyEvent(parent_->gpu_context(), &completed_event_); - if (!status.ok()) { - LOG(ERROR) << status.message(); - } - } - + completed_event_.reset(); GpuDriver::DestroyStream(parent_->gpu_context(), &gpu_stream_); } diff --git a/xla/stream_executor/gpu/gpu_stream.h b/xla/stream_executor/gpu/gpu_stream.h index bbd8464c4373e5..3b1cc3868172cb 100644 --- a/xla/stream_executor/gpu/gpu_stream.h +++ b/xla/stream_executor/gpu/gpu_stream.h @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include #include #include "absl/functional/any_invocable.h" @@ -29,6 +30,7 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/event_based_timer.h" +#include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/platform.h" @@ -46,11 +48,11 @@ class GpuExecutor; // Thread-safe post-initialization. class GpuStream : public StreamCommon { public: - explicit GpuStream(GpuExecutor* parent) + GpuStream(GpuExecutor* parent, std::unique_ptr completed_event) : StreamCommon(parent), parent_(parent), gpu_stream_(nullptr), - completed_event_(nullptr) {} + completed_event_(std::move(completed_event)) {} // Note: teardown is handled by a parent's call to DeallocateStream. ~GpuStream() override { @@ -80,7 +82,7 @@ class GpuStream : public StreamCommon { // Retrieves an event which indicates that all work enqueued into the stream // has completed. Ownership of the event is not transferred to the caller, the // event is owned by this stream. - GpuEventHandle* completed_event() { return &completed_event_; } + GpuEvent* completed_event() { return completed_event_.get(); } // Returns the GpuStreamHandle value for passing to the CUDA API. // @@ -114,9 +116,7 @@ class GpuStream : public StreamCommon { GpuExecutor* parent_; // Executor that spawned this stream. GpuStreamHandle gpu_stream_; // Wrapped CUDA stream handle. std::variant stream_priority_; - - // Event that indicates this stream has completed. - GpuEventHandle completed_event_ = nullptr; + std::unique_ptr completed_event_; }; // Helper functions to simplify extremely common flows. diff --git a/xla/stream_executor/rocm/rocm_executor.cc b/xla/stream_executor/rocm/rocm_executor.cc index 8f3d28228a9b86..c0da3c6a9bd0f4 100644 --- a/xla/stream_executor/rocm/rocm_executor.cc +++ b/xla/stream_executor/rocm/rocm_executor.cc @@ -642,7 +642,8 @@ absl::StatusOr> GpuExecutor::CreateEvent() { absl::StatusOr> GpuExecutor::CreateStream( std::optional> priority) { - auto stream = std::make_unique(this); + TF_ASSIGN_OR_RETURN(auto event, CreateGpuEvent(/*allow_timing=*/false)); + auto stream = std::make_unique(this, std::move(event)); if (priority.has_value()) { if (std::holds_alternative(*priority)) { stream->SetPriority(std::get(*priority)); From 0182d6781b09bb6846432d663bb0dc1db7893d11 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Wed, 31 Jul 2024 15:26:04 -0700 Subject: [PATCH 343/376] Export entry parameter layout tiles PiperOrigin-RevId: 658161689 --- xla/translate/hlo_to_mhlo/BUILD | 5 +- .../hlo_to_mhlo/hlo_function_importer.cc | 95 +------------------ .../hlo_to_mhlo/hlo_module_importer.cc | 44 +++++++++ xla/translate/hlo_to_mhlo/hlo_utils.h | 60 +++++++++++- .../tests/entry_computation_layout.hlo | 30 +++--- 5 files changed, 124 insertions(+), 110 deletions(-) diff --git a/xla/translate/hlo_to_mhlo/BUILD b/xla/translate/hlo_to_mhlo/BUILD index 043dbf0cc801d7..141dad7ed5618d 100644 --- a/xla/translate/hlo_to_mhlo/BUILD +++ b/xla/translate/hlo_to_mhlo/BUILD @@ -86,6 +86,7 @@ cc_library( "//xla/service:hlo_proto_cc", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:optional", @@ -127,14 +128,14 @@ cc_library( "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", "//xla/mlir/utils:type_util", "//xla/mlir_hlo", - "//xla/mlir_hlo:convert_op_folder", + "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:SparseTensorDialect", "@llvm-project//mlir:SparseTensorEnums", + "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc index 7d2a929822b1fb..64e08dfae789db 100644 --- a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc +++ b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc @@ -27,7 +27,9 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/log/check.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "llvm/ADT/APInt.h" @@ -36,6 +38,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/LogicalResult.h" #include "mlir/AsmParser/AsmParser.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Attributes.h" @@ -52,24 +55,20 @@ limitations under the License. #include "mlir/IR/ValueRange.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" #include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/ir/hlo_sharding_metadata.h" #include "xla/layout.h" -#include "xla/layout_util.h" #include "xla/literal.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/protobuf_util.h" #include "xla/service/hlo.pb.h" -#include "xla/shape_layout.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/translate/hlo_to_mhlo/attribute_importer.h" @@ -329,52 +328,6 @@ static bool IsNestedTupleInData(Type type) { return false; } -static bool HasCustomLayout(const Shape& shape) { - if (shape.IsTuple()) { - return llvm::any_of(shape.tuple_shapes(), HasCustomLayout); - } - return shape.has_layout() && !shape.layout().minor_to_major().empty() && - shape.layout() != LayoutUtil::GetDefaultLayoutForShape(shape); -} - -static std::pair GetLayoutAttribute( - mlir::Builder& b, const Shape& shape, - std::optional maybe_layout = std::nullopt) { - if (shape.IsTuple()) { - llvm::SmallVector element_attrs; - llvm::SmallVector tile_attrs; - for (const auto& tuple_shape : shape.tuple_shapes()) { - // TODO here we do not disect the layout of a tuple into sublayouts. - // Presently ShapeLayout cannot represent an explicit layout for a tuple - // type so this should never occur. However, if this function were to - // be used in another context where this assumption were to be lifted. - // users should be aware of this limitation which will use the default - // layout for tuple subshapes. - std::pair inner = - GetLayoutAttribute(b, tuple_shape); - element_attrs.push_back(inner.first); - tile_attrs.push_back(inner.second); - } - return std::make_pair((mlir::Attribute)b.getArrayAttr(element_attrs), - b.getArrayAttr(tile_attrs)); - } - - Layout layout = maybe_layout.value_or( - shape.has_layout() ? shape.layout() - : LayoutUtil::GetDefaultLayoutForShape(shape)); - - llvm::SmallVector vec_of_tiles; - for (const Tile& tile : layout.tiles()) { - llvm::SmallVector tile_vec = {tile.dimensions().begin(), - tile.dimensions().end()}; - vec_of_tiles.push_back(b.getIndexTensorAttr(tile_vec)); - } - llvm::SmallVector layout_vec = {layout.minor_to_major().begin(), - layout.minor_to_major().end()}; - return std::make_pair(b.getIndexTensorAttr(layout_vec), - b.getArrayAttr(vec_of_tiles)); -} - mlir::Attribute GetFrontendAttributes(mlir::Builder& b, const FrontendAttributes& attributes) { llvm::SmallVector attrs; @@ -608,48 +561,6 @@ absl::StatusOr HloFunctionImporter::ImportAsFunc( builder_->getStringAttr(computation.execution_thread())); } - // The MLIR CPU pipeline assumes default layouts throughout the program. At - // the boundaries, this may not be the case, so layout information needs to - // be propagated to adapt the data layouts. - if (computation.IsEntryComputation()) { - const auto& computation_layout = - computation.parent()->entry_computation_layout(); - if (computation_layout.LayoutIsSet() && - !computation_layout.result_layout().shape().IsTuple()) { - if (HasCustomLayout(computation_layout.result_layout().shape())) { - std::pair layout_attrs = - GetLayoutAttribute(*builder_, - computation_layout.result_layout().shape(), - computation_layout.result_layout().layout()); - function->setAttr("xla_entry_computation_result_layout", - layout_attrs.first); - function->setAttr("xla_entry_computation_result_tiles", - layout_attrs.second); - } - if (llvm::any_of(computation_layout.parameter_layouts(), - [](const ShapeLayout& shape) { - return HasCustomLayout(shape.shape()); - })) { - llvm::SmallVector parameter_layouts; - llvm::SmallVector parameter_tiles; - for (auto& layout : computation_layout.parameter_layouts()) { - std::pair layout_attrs = - GetLayoutAttribute( - *builder_, layout.shape(), - (layout.LayoutIsSet() && !layout.shape().IsTuple()) - ? std::optional(layout.layout()) - : std::nullopt); - parameter_layouts.push_back(layout_attrs.first); - parameter_tiles.push_back(layout_attrs.second); - } - function->setAttr("xla_entry_computation_parameter_layouts", - builder_->getArrayAttr(parameter_layouts)); - function->setAttr("xla_entry_computation_parameter_tiles", - builder_->getArrayAttr(parameter_tiles)); - } - } - } - symbol_table_.insert(function); // Add to the map right away for function calls if map is set. diff --git a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc index 76037442d52099..d7bd8404b9adaa 100644 --- a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc +++ b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc @@ -18,10 +18,12 @@ limitations under the License. #include #include #include +#include #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/types/span.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -35,9 +37,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/layout.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/shape_layout.h" #include "xla/shape_util.h" #include "xla/translate/hlo_to_mhlo/hlo_function_importer.h" +#include "xla/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/translate/hlo_to_mhlo/module_config_importer.h" #include "xla/xla.pb.h" #include "tsl/platform/errors.h" @@ -151,6 +156,45 @@ absl::Status HloModuleImporter::Import(const HloModule& hlo_module) { /*is_main*/ true, flatten_computation_args_result_) .status(); + // The MLIR CPU pipeline assumes default layouts throughout the program. At + // the boundaries, this may not be the case, so layout information needs to + // be propagated to adapt the data layouts. + if (const auto& computation_layout = hlo_module.entry_computation_layout(); + computation_layout.LayoutIsSet() && + !computation_layout.result_layout().shape().IsTuple()) { + if (HasCustomLayout(computation_layout.result_layout().shape())) { + std::pair layout_attrs = + GetLayoutAttribute(builder_, + computation_layout.result_layout().shape(), + computation_layout.result_layout().layout()); + module->setAttr("mhlo.xla_entry_computation_result_layout", + layout_attrs.first); + module->setAttr("mhlo.xla_entry_computation_result_tiles", + layout_attrs.second); + } + if (llvm::any_of(computation_layout.parameter_layouts(), + [](const ShapeLayout& shape) { + return HasCustomLayout(shape.shape()); + })) { + llvm::SmallVector parameter_layouts; + llvm::SmallVector parameter_tiles; + for (auto& layout : computation_layout.parameter_layouts()) { + std::pair layout_attrs = + GetLayoutAttribute( + builder_, layout.shape(), + (layout.LayoutIsSet() && !layout.shape().IsTuple()) + ? std::optional(layout.layout()) + : std::nullopt); + parameter_layouts.push_back(layout_attrs.first); + parameter_tiles.push_back(layout_attrs.second); + } + module->setAttr("mhlo.xla_entry_computation_parameter_layouts", + builder_.getArrayAttr(parameter_layouts)); + module->setAttr("mhlo.xla_entry_computation_parameter_tiles", + builder_.getArrayAttr(parameter_tiles)); + } + } + auto* module_entry_computation = hlo_module.entry_computation(); for (const auto* computation : hlo_module.computations()) TF_RETURN_IF_ERROR(HloFunctionImporter::ImportAsFunc( diff --git a/xla/translate/hlo_to_mhlo/hlo_utils.h b/xla/translate/hlo_to_mhlo/hlo_utils.h index 72c30be491e767..81fe60ea109f9f 100644 --- a/xla/translate/hlo_to_mhlo/hlo_utils.h +++ b/xla/translate/hlo_to_mhlo/hlo_utils.h @@ -18,17 +18,27 @@ limitations under the License. #ifndef XLA_TRANSLATE_HLO_TO_MHLO_HLO_UTILS_H_ #define XLA_TRANSLATE_HLO_TO_MHLO_HLO_UTILS_H_ +#include +#include +#include +#include + +#include "absl/status/statusor.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/SparseTensor/IR/Enums.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" -#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/layout.h" +#include "xla/layout_util.h" +#include "xla/literal.h" #include "xla/mlir/utils/type_util.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/mlir_hlo/utils/convert_op_folder.h" #include "xla/util.h" +#include "tsl/platform/statusor.h" namespace xla { @@ -159,6 +169,52 @@ static absl::StatusOr ConvertShapeToType(const Shape& shape, return ConvertTensorShapeToType(shape, builder); } +static std::pair GetLayoutAttribute( + mlir::Builder& b, const Shape& shape, + std::optional maybe_layout = std::nullopt) { + if (shape.IsTuple()) { + llvm::SmallVector element_attrs; + llvm::SmallVector tile_attrs; + for (const auto& tuple_shape : shape.tuple_shapes()) { + // TODO here we do not dissect the layout of a tuple into sublayouts. + // Presently ShapeLayout cannot represent an explicit layout for a tuple + // type so this should never occur. However, if this function were to + // be used in another context where this assumption were to be lifted. + // users should be aware of this limitation which will use the default + // layout for tuple subshapes. + std::pair inner = + GetLayoutAttribute(b, tuple_shape); + element_attrs.push_back(inner.first); + tile_attrs.push_back(inner.second); + } + return std::make_pair((mlir::Attribute)b.getArrayAttr(element_attrs), + b.getArrayAttr(tile_attrs)); + } + + Layout layout = maybe_layout.value_or( + shape.has_layout() ? shape.layout() + : LayoutUtil::GetDefaultLayoutForShape(shape)); + + llvm::SmallVector vec_of_tiles; + for (const Tile& tile : layout.tiles()) { + llvm::SmallVector tile_vec = {tile.dimensions().begin(), + tile.dimensions().end()}; + vec_of_tiles.push_back(b.getIndexTensorAttr(tile_vec)); + } + llvm::SmallVector layout_vec = {layout.minor_to_major().begin(), + layout.minor_to_major().end()}; + return std::make_pair(b.getIndexTensorAttr(layout_vec), + b.getArrayAttr(vec_of_tiles)); +} + +static bool HasCustomLayout(const Shape& shape) { + if (shape.IsTuple()) { + return llvm::any_of(shape.tuple_shapes(), HasCustomLayout); + } + return shape.has_layout() && !shape.layout().minor_to_major().empty() && + shape.layout() != LayoutUtil::GetDefaultLayoutForShape(shape); +} + } // namespace xla #endif // XLA_TRANSLATE_HLO_TO_MHLO_HLO_UTILS_H_ diff --git a/xla/translate/hlo_to_mhlo/tests/entry_computation_layout.hlo b/xla/translate/hlo_to_mhlo/tests/entry_computation_layout.hlo index 253639908966b5..a8b6707dcc6278 100644 --- a/xla/translate/hlo_to_mhlo/tests/entry_computation_layout.hlo +++ b/xla/translate/hlo_to_mhlo/tests/entry_computation_layout.hlo @@ -10,19 +10,21 @@ ENTRY entry { ROOT add = f32[2,3,4]{2,1,0} add(p0, p1) } -// CHECK: func.func @main( -// CHECK-SAME: xla_entry_computation_parameter_layouts -// CHECK-SAME: dense<[0, 1, 2]> -// CHECK-SAME: dense<[1, 2, 0]> -// CHECK-SAME: [dense<[1, 0]> -// CHECK-SAME: , dense<[0, 1]> -// CHECK-SAME: xla_entry_computation_parameter_tiles = [ -// CHECK-SAME: [] -// CHECK-SAME: [], -// CHECK-SAME: [ +// CHECK: module @entry +// CHECK-SAME: mhlo.xla_entry_computation_parameter_layouts = [ +// CHECK-SAME: dense<[0, 1, 2]> : tensor<3xindex>, +// CHECK-SAME: dense<[1, 2, 0]> : tensor<3xindex>, +// CHECK-SAME: [dense<[1, 0]> : tensor<2xindex>, +// CHECK-SAME: dense<[0, 1]> : tensor<2xindex>], +// CHECK-SAME: dense<> : tensor<0xindex>] +// CHECK-SAME: mhlo.xla_entry_computation_parameter_tiles = [ +// CHECK-SAME: [], +// CHECK-SAME: [], +// CHECK-SAME: [ // CHECK-SAME: [], // CHECK-SAME: [] -// CHECK-SAME: ], -// CHECK-SAME: [dense<128> : tensor<1xindex>] -// CHECK-SAME: ] -// CHECK-SAME: xla_entry_computation_result_layout = dense<[2, 0, 1]> +// CHECK-SAME: ], +// CHECK-SAME: [dense<128> : tensor<1xindex>] +// CHECK-SAME: ], +// CHECK-SAME: mhlo.xla_entry_computation_result_layout = dense<[2, 0, 1]> : tensor<3xindex> +// CHECK-SAME: mhlo.xla_entry_computation_result_tiles = [] From e6ff2a092cb7bd7afb233dd9d5b2159e1a55b110 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Wed, 31 Jul 2024 16:20:04 -0700 Subject: [PATCH 344/376] Use Platform::ExecutorForDevice rather than GetExecutor, as device ordinal is all that's being searched for. PiperOrigin-RevId: 658179612 --- xla/pjrt/pjrt_stream_executor_client_test.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/xla/pjrt/pjrt_stream_executor_client_test.cc b/xla/pjrt/pjrt_stream_executor_client_test.cc index 19f1c150ef232b..2fa381df57290a 100644 --- a/xla/pjrt/pjrt_stream_executor_client_test.cc +++ b/xla/pjrt/pjrt_stream_executor_client_test.cc @@ -46,10 +46,8 @@ absl::StatusOr> GetClient() { LocalClient* local_client = xla::ClientLibrary::LocalClientOrDie(); TF_ASSIGN_OR_RETURN(se::Platform * platform, PlatformUtil::GetPlatform("Host")); - se::StreamExecutorConfig config; - config.ordinal = 0; TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, - platform->GetExecutor(config)); + platform->ExecutorForDevice(0)); auto device_state = std::make_unique( executor, local_client, LocalDeviceState::kSynchronous, /*max_inflight_computations=*/32, From 696bc339117d832ef66fac960d159daa1243a87f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 31 Jul 2024 16:53:31 -0700 Subject: [PATCH 345/376] Prioritize delaying async-start that has async depth of 0, before looking at "kLessStall", with an additional flag. PiperOrigin-RevId: 658189179 --- xla/service/latency_hiding_scheduler.cc | 49 ++++++++++++++++++------- xla/service/latency_hiding_scheduler.h | 2 +- 2 files changed, 36 insertions(+), 15 deletions(-) diff --git a/xla/service/latency_hiding_scheduler.cc b/xla/service/latency_hiding_scheduler.cc index a17bc63b6f8804..5f7757bcd2056a 100644 --- a/xla/service/latency_hiding_scheduler.cc +++ b/xla/service/latency_hiding_scheduler.cc @@ -802,7 +802,8 @@ class ReadySetLt { return *value; } } - // Otherwise pick a node that increases the pressure from the list. + // Otherwise pick a node that increases the pressure the least from the + // list. if (auto value = DefaultSchedulerCore::ChooseBestCandidate( a_increase.first < b_increase.first, a, b_increase.first < a_increase.first, b, @@ -880,6 +881,36 @@ class ReadySetLt { } } + auto async_depth_0_candidate = + [this](DefaultSchedulerCore::ScheduleCandidate& a, + DefaultSchedulerCore::ScheduleCandidate& b) + -> std::optional { + // If an instruction releasing a resource is not resource constrained and + // has an async depth of 0, delay it as much as possible to avoid + // potential cost model inefficiencies. For example, if a pair of + // async-start and async-done have no dependencies on other ops inside a + // loop, the async-start will be pushed to the beginning of the loop. + if (auto value = DefaultSchedulerCore::ChooseBestCandidate( + /*first_cond=*/!(a.node->DoesReleaseAnyResource() && + a.node->GetAsyncDepth() == 0 && + !IsResourceConstrained(a)), + a, + /*second_cond=*/ + !(b.node->DoesReleaseAnyResource() && + b.node->GetAsyncDepth() == 0 && !IsResourceConstrained(b)), + b, "kStartAtZeroDepth")) { + return value; + } + return std::nullopt; + }; + + if (sched_state_.config.aggressive_scheduling_policies && + sched_state_.config.prioritize_async_depth_over_stall) { + if (auto value = async_depth_0_candidate(a, b)) { + return *value; + } + } + const ApproximateLatencyEstimator::TimeCost a_ready_interval = std::max(a.node->GetReadyTime() - sched_state_.current_time, 0.0); const ApproximateLatencyEstimator::TimeCost b_ready_interval = @@ -906,19 +937,9 @@ class ReadySetLt { return *value; } } - if (sched_state_.config.aggressive_scheduling_policies) { - // If an instruction releasing a resource is not resource constrained and - // has an async depth of 0, delay it as much as possible to avoid - // potential cost model inefficiencies. - if (auto value = DefaultSchedulerCore::ChooseBestCandidate( - /*first_cond=*/!(a.node->DoesReleaseAnyResource() && - a.node->GetAsyncDepth() == 0 && - !IsResourceConstrained(a)), - a, - /*second_cond=*/ - !(b.node->DoesReleaseAnyResource() && - b.node->GetAsyncDepth() == 0 && !IsResourceConstrained(b)), - b, "kStartAtZeroDepth")) { + if (sched_state_.config.aggressive_scheduling_policies && + !sched_state_.config.prioritize_async_depth_over_stall) { + if (auto value = async_depth_0_candidate(a, b)) { return *value; } } diff --git a/xla/service/latency_hiding_scheduler.h b/xla/service/latency_hiding_scheduler.h index 76ce8b307f7184..ebe1cf0c6bcc8c 100644 --- a/xla/service/latency_hiding_scheduler.h +++ b/xla/service/latency_hiding_scheduler.h @@ -16,7 +16,6 @@ limitations under the License. #ifndef XLA_SERVICE_LATENCY_HIDING_SCHEDULER_H_ #define XLA_SERVICE_LATENCY_HIDING_SCHEDULER_H_ -#include #include #include #include @@ -132,6 +131,7 @@ struct SchedulerConfig { bool force_send_recv_to_use_same_resource = false; bool use_real_cost_model = false; bool aggressive_scheduling_policies = false; + bool prioritize_async_depth_over_stall = false; bool enable_release_start_policy = false; bool resource_sharing = false; bool resource_serializing = false; From 5e694e1fe142151223333216f10b7dbef022dac7 Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Wed, 31 Jul 2024 17:11:16 -0700 Subject: [PATCH 346/376] [XLA:GPU] Add variant to pipeline parallelism tests that breaks the direct data dependency This variant breaks the direct data dependency between the previous iteration's compute and the collective permute. PiperOrigin-RevId: 658194076 --- .../collective_pipeline_parallelism_test.cc | 194 ++++++++++++++++++ 1 file changed, 194 insertions(+) diff --git a/xla/tests/collective_pipeline_parallelism_test.cc b/xla/tests/collective_pipeline_parallelism_test.cc index 48641e2c17cc52..bfcf5e14adf5d1 100644 --- a/xla/tests/collective_pipeline_parallelism_test.cc +++ b/xla/tests/collective_pipeline_parallelism_test.cc @@ -743,5 +743,199 @@ XLA_TEST_F(CollectivePipelineParallelismTest, ErrorSpec{1e-5, 1e-5})); } +// Naive implementation of pipeline parallelism, which breaks the direct data +// dependency between the collective permute and the previous iteration's +// compute. +// - 4 devices +// - 4 microbatches +// - 2 circular repeat +// - no disabled collectives +// - no collective pipelining +// +// Every stage of the pipeline is a single linear layer. +XLA_TEST_F(CollectivePipelineParallelismTest, + NaiveWoDirectBufferDependencyDFSMicrobatch5CircularRepeat2Replica4) { + const absl::string_view kModuleStr = R"( + HloModule test + + get_circ_buffer_index { + offset = u32[] parameter(0) + index = u32[] parameter(1) + size = u32[] parameter(2) + t0 = u32[] add(offset, index) + t1 = u32[] divide(t0, size) + t2 = u32[] multiply(t1, size) + ROOT t4 = u32[] subtract(t0, t2) + } + + read_buffer { + buffer = f32[5,16] parameter(0) + offset = u32[] parameter(1) + index = u32[] parameter(2) + c0 = u32[] constant(0) + c5 = u32[] constant(5) + index_ = u32[] add(index, offset) + index__ = u32[] remainder(index_, c5) + slice = f32[1,16] dynamic-slice(buffer, index__, c0), + dynamic_slice_sizes={1,16} + ROOT slice_ = f32[16] reshape(slice) + } + + update_buffer { + buffer = f32[5,16] parameter(0) + update = f32[16] parameter(1) + offset = u32[] parameter(2) + index = u32[] parameter(3) + c0 = u32[] constant(0) + c5 = u32[] constant(5) + index_ = u32[] add(index, offset) + index__ = u32[] remainder(index_, c5) + update_ = f32[1,16] reshape(update) + ROOT buffer_ = f32[5,16] dynamic-update-slice(buffer, update_, index__, c0) + } + + is_input_replica { + replica_id = u32[] replica-id() + c0 = u32[] constant(0) + ROOT predicate = pred[] compare(replica_id, c0), direction=EQ + } + + is_output_replica { + replica_id = u32[] replica-id() + c3 = u32[] constant(3) + ROOT predicate = pred[] compare(replica_id, c3), direction=EQ + } + + is_read_input { + is_input_replica = pred[] call(), to_apply=is_input_replica + i = u32[] parameter(0) + c5 = u32[] constant(5) + is_input_iteration = pred[] compare(i, c5), direction=LT + ROOT is_read_input = pred[] and(is_input_replica, is_input_iteration) + } + + while_condition { + tuple = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[]) + parameter(0) + i = u32[] get-tuple-element(tuple), index=5 + n = u32[] constant(13) + ROOT predicate = pred[] compare(i, n), direction=LT + } + + while_body { + tuple = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[]) + parameter(0) + weights = f32[16,16] get-tuple-element(tuple), index=0 + input = f32[5,16] get-tuple-element(tuple), index=1 + output = f32[5,16] get-tuple-element(tuple), index=2 + buffer = f32[5,16] get-tuple-element(tuple), index=3 + prev_iteration_compute_out = f32[16] get-tuple-element(tuple), index=4 + i = u32[] get-tuple-element(tuple), index=5 + + c0 = u32[] constant(0) + c1 = u32[] constant(1) + c2 = u32[] constant(2) + c3 = u32[] constant(3) + c4 = u32[] constant(4) + c5 = u32[] constant(5) + + input_idx = u32[] call(c0, i, c5), to_apply=get_circ_buffer_index + input_slice = f32[1,16] dynamic-slice(input, input_idx, c0), + dynamic_slice_sizes={1,16} + input_slice_ = f32[16] reshape(input_slice) + + buffer_slice = f32[16] call(buffer, c3, i), to_apply=read_buffer + + buffer_ = f32[5,16] call(buffer, prev_iteration_compute_out, c4, i), + to_apply=update_buffer + + // Depends on the non-updated buffer of the previous iteration and, + // therefore, does not depend on the previous iteration's compute. + is_output_replica = pred[] call(), to_apply=is_output_replica + next_stage_slice = select(is_output_replica, buffer_slice, + prev_iteration_compute_out) + + + prev_stage_slice = f32[16] collective-permute(next_stage_slice), + source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}} + + is_read_input = pred[] call(i), to_apply=is_read_input + compute_in = f32[16] select(is_read_input, input_slice_, prev_stage_slice) + + compute_out = f32[16] dot(weights, compute_in), lhs_contracting_dims={1}, + rhs_contracting_dims={0} + + output_ = f32[5,16] call(output, compute_out, c2, i), to_apply=update_buffer + + i_ = add(i, c1) + + ROOT tuple_ = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[]) + tuple(weights, input, output_, buffer_, compute_out, i_) + } + + ENTRY main { + weights = f32[16,16] parameter(0) + input = f32[5,16] parameter(1) + + cf0 = f32[] constant(0) + output = f32[5,16] broadcast(cf0), dimensions={} + buffer = f32[5,16] broadcast(cf0), dimensions={} + prev_iteration_compute_out = f32[16] broadcast(cf0), dimensions={} + c0 = u32[] constant(0) + + tuple = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[]) + tuple(weights, input, output, buffer, prev_iteration_compute_out, c0) + tuple_ = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[]) + while(tuple), condition=while_condition, body=while_body + + ROOT output_ = f32[5,16] get-tuple-element(tuple_), index=2 + } + )"; + + const int64_t kNumReplicas = 4; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + + // This pipeline consists of a total of 8 layers (2 per replica), each of + // which is a single linear layer. We assign the weights to the replicas such + // that the layers scale the input data by 1.0, 2.0, 3.0 and 4.0 in the first + // and second cycle. The combined effect is to scale the input data by 576.0 + // (24.0 * 24.0). + const int64_t kInputSize = 16; + Literal weights_r0 = LiteralUtil::MakeScalarMatrixR2(kInputSize, 1.0); + Literal weights_r1 = LiteralUtil::MakeScalarMatrixR2(kInputSize, 2.0); + Literal weights_r2 = LiteralUtil::MakeScalarMatrixR2(kInputSize, 3.0); + Literal weights_r3 = LiteralUtil::MakeScalarMatrixR2(kInputSize, 4.0); + + // Only the first replica holds the input to the pipeline in this naive + // implementation. The remaining replicas get zero/dummy input. + const int64_t kMicrobatches = 5; + Literal real_input = + LiteralUtil::CreateFingerprintMatixR2(kMicrobatches, kInputSize); + Literal fake_input = + LiteralUtil::CreateFull({kMicrobatches, kInputSize}, 0.0); + + // Check pipeline output for last replica. + // The combined effect of the pipeline is to scale the input data by 576.0 + // (24.0 * 24.0). + const float kExpectedFactor = 1.0 * 2.0 * 3.0 * 4.0 * 1.0 * 2.0 * 3.0 * 4.0; + Literal expected_output = LiteralUtil::CreateFingerprintMatixR2( + kMicrobatches, kInputSize, /*scale=*/kExpectedFactor); + std::vector> args = {{&weights_r0, &real_input}, + {&weights_r1, &fake_input}, + {&weights_r2, &fake_input}, + {&weights_r3, &fake_input}}; + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), args, kNumReplicas, + /*run_hlo_passes=*/true)); + EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected_output, results[3], + ErrorSpec{1e-5, 1e-5})); +} + } // namespace } // namespace xla From f024fa039b3e508875ba3439740340cb67436ac5 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Wed, 31 Jul 2024 17:12:56 -0700 Subject: [PATCH 347/376] Simplify GpuStream's interface. 1. Remove IsIdle() method, which was only called during destruction in favor of calling the underlying code during the actual destructor. 2. Remove Destroy(), which was only called in the DeallocateStream which was only called in the destructor. Instead, just call the code in the destructor. 3. Make Init() return an absl::Status. 4. Just pass the priority std::optional to the constructor rather than have discrete SetPriority methods taking either of the choices. PiperOrigin-RevId: 658194499 --- xla/stream_executor/cuda/cuda_executor.cc | 25 ++++---------------- xla/stream_executor/gpu/gpu_stream.cc | 19 ++++++++------- xla/stream_executor/gpu/gpu_stream.h | 28 ++++++++--------------- xla/stream_executor/rocm/rocm_executor.cc | 25 ++++---------------- 4 files changed, 31 insertions(+), 66 deletions(-) diff --git a/xla/stream_executor/cuda/cuda_executor.cc b/xla/stream_executor/cuda/cuda_executor.cc index e43b419f1f9207..b624709c22b2f8 100644 --- a/xla/stream_executor/cuda/cuda_executor.cc +++ b/xla/stream_executor/cuda/cuda_executor.cc @@ -668,10 +668,6 @@ void GpuExecutor::DeallocateStream(Stream* stream) { GpuStream* gpu_stream = AsGpuStream(stream); absl::MutexLock l(&alive_gpu_streams_mu_); alive_gpu_streams_.erase(gpu_stream->gpu_stream()); - if (!gpu_stream->IsIdle()) { - LOG(ERROR) << "Deallocating stream with pending work"; - } - gpu_stream->Destroy(); } absl::Status GpuExecutor::BlockHostUntilDone(Stream* stream) { @@ -806,23 +802,12 @@ absl::StatusOr> GpuExecutor::CreateEvent() { absl::StatusOr> GpuExecutor::CreateStream( std::optional> priority) { TF_ASSIGN_OR_RETURN(auto event, CreateGpuEvent(/*allow_timing=*/false)); - auto stream = std::make_unique(this, std::move(event)); - if (priority.has_value()) { - if (std::holds_alternative(*priority)) { - stream->SetPriority(std::get(*priority)); - } else { - stream->SetPriority(std::get(*priority)); - } - } + auto stream = std::make_unique(this, std::move(event), priority); absl::MutexLock l(&alive_gpu_streams_mu_); - bool init_worked = stream->Init(); - if (init_worked) { - auto gpu_stream = stream->gpu_stream(); - alive_gpu_streams_[gpu_stream] = stream.get(); - return std::move(stream); - } else { - return absl::InvalidArgumentError("Failed to initialize gpu stream"); - } + TF_RETURN_IF_ERROR(stream->Init()); + auto gpu_stream = stream->gpu_stream(); + alive_gpu_streams_[gpu_stream] = stream.get(); + return std::move(stream); } absl::StatusOr> GpuExecutor::CreateCommandBuffer( diff --git a/xla/stream_executor/gpu/gpu_stream.cc b/xla/stream_executor/gpu/gpu_stream.cc index 1f75dfac7ebaf1..706826553e4363 100644 --- a/xla/stream_executor/gpu/gpu_stream.cc +++ b/xla/stream_executor/gpu/gpu_stream.cc @@ -48,7 +48,7 @@ void InternalHostCallback(void* data) { } } // namespace -bool GpuStream::Init() { +absl::Status GpuStream::Init() { int priority = [&]() { if (std::holds_alternative(stream_priority_)) { return std::get(stream_priority_); @@ -58,9 +58,9 @@ bool GpuStream::Init() { }(); if (!GpuDriver::CreateStream(parent_->gpu_context(), &gpu_stream_, priority)) { - return false; + return absl::InternalError("Failed to CreateStream"); } - return true; + return absl::OkStatus(); } Stream::PlatformSpecificHandle GpuStream::platform_specific_handle() const { @@ -172,15 +172,18 @@ absl::Status GpuStream::DoHostCallbackWithStatus( return absl::InternalError("Failed to host callback."); } -void GpuStream::Destroy() { +GpuStream::~GpuStream() { + BlockHostUntilDone().IgnoreError(); + parent()->DeallocateStream(this); + + if (!GpuDriver::IsStreamIdle(parent_->gpu_context(), gpu_stream_)) { + LOG(ERROR) << "Deallocating stream with pending work"; + } + completed_event_.reset(); GpuDriver::DestroyStream(parent_->gpu_context(), &gpu_stream_); } -bool GpuStream::IsIdle() const { - return GpuDriver::IsStreamIdle(parent_->gpu_context(), gpu_stream_); -} - void GpuStream::set_name(absl::string_view name) { name_ = name; tsl::profiler::NameStream( diff --git a/xla/stream_executor/gpu/gpu_stream.h b/xla/stream_executor/gpu/gpu_stream.h index 3b1cc3868172cb..4cf21ca82207ed 100644 --- a/xla/stream_executor/gpu/gpu_stream.h +++ b/xla/stream_executor/gpu/gpu_stream.h @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include #include #include @@ -48,37 +49,28 @@ class GpuExecutor; // Thread-safe post-initialization. class GpuStream : public StreamCommon { public: - GpuStream(GpuExecutor* parent, std::unique_ptr completed_event) + GpuStream(GpuExecutor* parent, std::unique_ptr completed_event, + std::optional> priority) : StreamCommon(parent), parent_(parent), gpu_stream_(nullptr), - completed_event_(std::move(completed_event)) {} + completed_event_(std::move(completed_event)) { + if (priority.has_value()) { + stream_priority_ = priority.value(); + } + } // Note: teardown is handled by a parent's call to DeallocateStream. - ~GpuStream() override { - BlockHostUntilDone().IgnoreError(); - parent()->DeallocateStream(this); - } + ~GpuStream() override; // Explicitly initialize the CUDA resources associated with this stream. - bool Init(); - - // Sets the priority of this stream. - void SetPriority(StreamPriority priority) { stream_priority_ = priority; } - void SetPriority(int priority) { stream_priority_ = priority; } + absl::Status Init(); std::variant priority() const override { return stream_priority_; } PlatformSpecificHandle platform_specific_handle() const override; - // Explicitly destroy the CUDA resources associated with this stream, used by - // StreamExecutor::DeallocateStream(). - void Destroy(); - - // Returns true if no work is pending or executing on the stream. - bool IsIdle() const; - // Retrieves an event which indicates that all work enqueued into the stream // has completed. Ownership of the event is not transferred to the caller, the // event is owned by this stream. diff --git a/xla/stream_executor/rocm/rocm_executor.cc b/xla/stream_executor/rocm/rocm_executor.cc index c0da3c6a9bd0f4..cb096c6485e10d 100644 --- a/xla/stream_executor/rocm/rocm_executor.cc +++ b/xla/stream_executor/rocm/rocm_executor.cc @@ -500,10 +500,6 @@ void GpuExecutor::DeallocateStream(Stream* stream) { GpuStream* rocm_stream = AsGpuStream(stream); absl::MutexLock l(&alive_gpu_streams_mu_); alive_gpu_streams_.erase(rocm_stream->gpu_stream()); - if (!rocm_stream->IsIdle()) { - LOG(ERROR) << "Deallocating stream with pending work"; - } - rocm_stream->Destroy(); } absl::Status GpuExecutor::BlockHostUntilDone(Stream* stream) { @@ -643,23 +639,12 @@ absl::StatusOr> GpuExecutor::CreateEvent() { absl::StatusOr> GpuExecutor::CreateStream( std::optional> priority) { TF_ASSIGN_OR_RETURN(auto event, CreateGpuEvent(/*allow_timing=*/false)); - auto stream = std::make_unique(this, std::move(event)); - if (priority.has_value()) { - if (std::holds_alternative(*priority)) { - stream->SetPriority(std::get(*priority)); - } else { - stream->SetPriority(std::get(*priority)); - } - } + auto stream = std::make_unique(this, std::move(event), priority); absl::MutexLock l(&alive_gpu_streams_mu_); - bool init_worked = stream->Init(); - if (init_worked) { - auto gpu_stream = stream->gpu_stream(); - alive_gpu_streams_[gpu_stream] = stream.get(); - return std::move(stream); - } else { - return absl::InvalidArgumentError("Failed to initialize GPU stream"); - } + TF_RETURN_IF_ERROR(stream->Init()); + auto gpu_stream = stream->gpu_stream(); + alive_gpu_streams_[gpu_stream] = stream.get(); + return std::move(stream); } absl::StatusOr> GpuExecutor::CreateCommandBuffer( From f9011c1dbee841dbbff73e189614111db692973f Mon Sep 17 00:00:00 2001 From: Zixuan Jiang Date: Wed, 31 Jul 2024 17:47:08 -0700 Subject: [PATCH 348/376] Do not directly replicate the operand when another operand and the result have matched sharding on non-contracting dims. ### An simple example Let us take the following dot as an example. ``` lhs = bf16[16384,2048] parameter(0), sharding={devices=[16,8]<=[128]} rhs = bf16[16384,256] parameter(1), sharding={devices=[128,1]<=[128]} ROOT dot = bf16[2048,256] dot(lhs, rhs), lhs_contracting_dims={0}, rhs_contracting_dims={0}, sharding={devices=[8,1,16]<=[16,8]T(1,0) last_tile_dim_replicate} ``` A good solution is to reshard rhs into `sharding={devices=[16,1,8]<=[128] last_tile_dim_replicate}`. In this way, all three tensors have matched shardings. The partitioner can convert it into dot operation and an all-reduce. Before this cl, the partitioner is suboptimal. It reshards the rhs twice ``` {devices=[128,1]<=[128]} -> replicated -> {devices=[16,1,8]<=[128] last_tile_dim_replicate} ``` It is redundant to fully rematerialize rhs. This cl reshards `rhs` directly from `{devices=[128,1]<=[128]}` to `{devices=[16,1,8]<=[128] last_tile_dim_replicate}`. ### Mechanism Let us use notation `C = dot(A, B)`. When we find that A and C has matched sharding axes along non-contracting dimensions, we intend to remove these matched axes from B. Before this cl, if the removal fails, we replicate B directly. This cl attempts to reshard B to the expected sharding, which is no worse than the last resort (replicated sharding). PiperOrigin-RevId: 658203651 --- xla/service/spmd/dot_handler.cc | 62 +++++++++++++++-------- xla/service/spmd/spmd_partitioner_test.cc | 49 ++++++++++++++++++ 2 files changed, 89 insertions(+), 22 deletions(-) diff --git a/xla/service/spmd/dot_handler.cc b/xla/service/spmd/dot_handler.cc index fb41d24870f405..87db7f723d6432 100644 --- a/xla/service/spmd/dot_handler.cc +++ b/xla/service/spmd/dot_handler.cc @@ -2542,19 +2542,39 @@ absl::StatusOr PartitionDotGroupOnNonContractingImpl( matching.sharding() != UngroupSharding(matching_grouped)) { return nullptr; } + + auto try_sharding_for_other_operand = [&](const HloSharding& sharding) { + PartitionedHlo other_reshard = other.Reshard(sharding); + std::optional grouped_sharding = + GetNonContractingPartitionGroupedShardingForOtherOperand( + lhs_matching, output_base_shape, other_reshard.hlo()->shape(), + other_contracting_partitions, other_non_contracting_partitions, + matching_contracting_partitions, + output_other_non_contracting_partitions, other_reshard.sharding(), + output_sharding, partitioned_non_contracting_dims, + lhs_matching ? dims_mapping.rhs_non_contracting_dims + : dims_mapping.lhs_non_contracting_dims, + dims_mapping.contracting_dims); + if (grouped_sharding) { + other = other_reshard; + } + return grouped_sharding; + }; std::optional other_grouped = - GetNonContractingPartitionGroupedShardingForOtherOperand( - lhs_matching, output_base_shape, other.hlo()->shape(), - other_contracting_partitions, other_non_contracting_partitions, - matching_contracting_partitions, - output_other_non_contracting_partitions, other.sharding(), - output_sharding, partitioned_non_contracting_dims, - lhs_matching ? dims_mapping.rhs_non_contracting_dims - : dims_mapping.lhs_non_contracting_dims, - dims_mapping.contracting_dims); - if (!other_grouped) { - other = other.Replicate(); + try_sharding_for_other_operand(other.sharding()); + if (!other_grouped && !other.sharding().IsReplicated()) { + const HloSharding expected_other_sharding = + hlo_sharding_util::InferDotOperandSharding( + &output_sharding, &matching.sharding(), lhs_matching ? 1 : 0, + dims_mapping, true, true); + // Try the expected sharding since it is no worse than the last resort + // (replicated sharding). + other_grouped = try_sharding_for_other_operand(expected_other_sharding); + if (!other_grouped) { + other = other.Replicate(); + } } + matching = matching.Reshard(UngroupSharding(matching_grouped)); auto per_group_partitioner_state = CreatePerGroupPartitioningState( matching.state(), matching_grouped.device_groups, b); @@ -2573,7 +2593,7 @@ absl::StatusOr PartitionDotGroupOnNonContractingImpl( partially_replicated_other = other.hlo(); top_level_sharding_to_reset.emplace_back(other.hlo(), other.sharding()); partially_replicated_other->set_sharding(other_grouped->sharding); - } else if (!other.sharding().IsReplicated()) { + } else if (other_grouped && !other.sharding().IsReplicated()) { HloSharding target_sharding = UngroupSharding(*other_grouped); GroupedSharding target_group_sharding = hlo_sharding_util::GroupShardingOnDims(target_sharding, @@ -2597,18 +2617,16 @@ absl::StatusOr PartitionDotGroupOnNonContractingImpl( partially_replicated_other, partially_replicated_other->sharding()); partially_replicated_other->set_sharding(other_grouped->sharding); } + auto other_p = PartitionedHlo(partially_replicated_other, other.base_shape(), per_group_partitioner_state); - TF_ASSIGN_OR_RETURN( - auto dot, - PartitionDot(lhs_matching ? matching_p : other_p, - lhs_matching ? other_p : matching_p, - GetPerGroupBaseShape(output_grouped, output_base_shape), - output_grouped.sharding, dims_mapping, - num_partitions / matching_grouped.device_groups.size(), - create_sharded_dot, conv_window, module, original_hlo, - options, b, windowed_dot_general_loops, visitor)); - return dot; + return PartitionDot(lhs_matching ? matching_p : other_p, + lhs_matching ? other_p : matching_p, + GetPerGroupBaseShape(output_grouped, output_base_shape), + output_grouped.sharding, dims_mapping, + num_partitions / matching_grouped.device_groups.size(), + create_sharded_dot, conv_window, module, original_hlo, + options, b, windowed_dot_general_loops, visitor); } std::pair diff --git a/xla/service/spmd/spmd_partitioner_test.cc b/xla/service/spmd/spmd_partitioner_test.cc index da872c2f334b84..e24e32f7899866 100644 --- a/xla/service/spmd/spmd_partitioner_test.cc +++ b/xla/service/spmd/spmd_partitioner_test.cc @@ -9114,6 +9114,55 @@ ENTRY %main.7 { EXPECT_THAT(root, tuple); } +TEST_P(SpmdPartitioningTest, PartiallyReplicateRHS) { + const char* const hlo_string = R"( +HloModule module +ENTRY main { + lhs = bf16[16384,2048] parameter(0), sharding={devices=[16,8]<=[128]} + rhs = bf16[16384,256] parameter(1), sharding={devices=[128,1]<=[128]} + ROOT dot = bf16[2048,256] dot(lhs, rhs), lhs_contracting_dims={0}, rhs_contracting_dims={0}, sharding={devices=[8,1,16]<=[16,8]T(1,0) last_tile_dim_replicate} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, PartitionComputation(hlo_string, /*num_devices=*/128)); + VLOG(1) << module->ToString(); + + const auto lhs = AllOf(op::Shape("bf16[1024,256]"), op::Parameter(0)); + const auto rhs = AllOf(op::Shape("bf16[1024,256]"), + op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(), op::Parameter(1), _, _))); + auto dot = AllOf(op::Shape("bf16[256,256]"), op::Dot(lhs, rhs)); + const auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::AllReduce(dot)); +} + +TEST_P(SpmdPartitioningTest, AllToAllAndPartialReplicateRHS) { + const char* const hlo_string = R"( +HloModule module +ENTRY main { + lhs = bf16[64,64] parameter(0), sharding={devices=[2,2,2]<=[8] last_tile_dim_replicate} + rhs = bf16[64,64,64] parameter(1), sharding={devices=[1,2,4]<=[2,2,2]T(2,1,0)} + ROOT dot = bf16[64,64,64] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={2}, sharding={devices=[2,2,1,2]<=[2,2,2]T(0,2,1) last_tile_dim_replicate} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + + const auto lhs = AllOf(op::Shape("bf16[32,32]"), op::Parameter(0)); + const auto all_to_all_p1 = AllOf( + op::Shape("bf16[32,64,16]"), + op::Reshape(op::Transpose(op::AllToAll(op::Reshape(op::Parameter(1)))))); + const auto rhs = AllOf(op::Shape("bf16[32,64,32]"), + op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(), all_to_all_p1, _, _, _))); + auto dot = AllOf(op::Shape("bf16[32,32,64]"), op::Dot(lhs, rhs)); + const auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::AllReduce(dot)); +} + TEST_P(SpmdPartitioningTest, ElementwiseTest_SubgroupSharding_TileToReplicate) { absl::string_view hlo_string = R"( HloModule module From 256ae491a3f559aafa0011e76bdeee0cd4641ac8 Mon Sep 17 00:00:00 2001 From: Ce Zheng Date: Wed, 31 Jul 2024 17:48:48 -0700 Subject: [PATCH 349/376] [XLA] Avoid using OpSharding in hlo_parser. PiperOrigin-RevId: 658204128 --- xla/hlo/ir/hlo_sharding.h | 5 ++ xla/service/BUILD | 2 + xla/service/hlo_parser.cc | 151 +++++++++++++++++++------------------- 3 files changed, 82 insertions(+), 76 deletions(-) diff --git a/xla/hlo/ir/hlo_sharding.h b/xla/hlo/ir/hlo_sharding.h index a15d3b33e4c44f..5a7c49e9265899 100644 --- a/xla/hlo/ir/hlo_sharding.h +++ b/xla/hlo/ir/hlo_sharding.h @@ -138,6 +138,11 @@ class HloSharding { static HloSharding Tuple(const Shape& tuple_shape, absl::Span shardings); + // Creates a new sharding for a flat tuple type. + static HloSharding FlatTuple(std::vector sub_shardings) { + return HloSharding(std::move(sub_shardings)); + } + // Creates a new sharding for a tuple type, with a single input sharding // repeated on each leaf. static HloSharding SingleTuple(const Shape& tuple_shape, diff --git a/xla/service/BUILD b/xla/service/BUILD index 8fa7f829b73b42..43df5c1c5b9ce9 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -6918,6 +6918,7 @@ cc_library( ":hlo_proto_cc", ":name_uniquer", ":shape_inference", + "//xla:array", "//xla:comparison_util", "//xla:literal", "//xla:literal_util", @@ -6927,6 +6928,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/ir:tile_assignment", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", diff --git a/xla/service/hlo_parser.cc b/xla/service/hlo_parser.cc index 2c9a480983afa5..8f097ca4179232 100644 --- a/xla/service/hlo_parser.cc +++ b/xla/service/hlo_parser.cc @@ -46,6 +46,7 @@ limitations under the License. #include "absl/strings/strip.h" #include "absl/types/span.h" #include "Eigen/Core" +#include "xla/array.h" #include "xla/comparison_util.h" #include "xla/hlo/ir/collective_device_list.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -58,6 +59,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/ir/hlo_sharding_metadata.h" +#include "xla/hlo/ir/tile_assignment.h" #include "xla/layout.h" #include "xla/layout_util.h" #include "xla/literal.h" @@ -490,12 +492,11 @@ class HloParserImpl : public HloParser { bool ParseWindow(Window* window, bool expect_outer_curlies); bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums); bool ParsePaddingConfig(PaddingConfig* padding); - bool ParseMetadata(OpMetadata* metadata); - bool ParseSingleOrListMetadata( - tsl::protobuf::RepeatedPtrField* metadata); + bool ParseMetadata(OpMetadata& metadata); + bool ParseSingleOrListMetadata(std::vector& metadata); bool ParseOpShardingType(OpSharding::Type* type); bool ParseListShardingType(std::vector* types); - bool ParseSharding(OpSharding* sharding); + bool ParseSharding(std::optional& sharding); bool ParseCollectiveDeviceList(CollectiveDeviceList* device_list); bool ParseFrontendAttributes(FrontendAttributes* frontend_attributes); bool ParseStatisticsViz(StatisticsViz* statistics_viz); @@ -503,7 +504,8 @@ class HloParserImpl : public HloParser { std::vector& iota_reshape_dims, std::vector& iota_transpose_perm, std::vector* devices); - bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed); + bool ParseSingleSharding(std::optional& sharding, + bool lbrace_pre_lexed); bool ParseParameterReplication(ParameterReplication* parameter_replication); bool ParseBooleanListOrSingleBoolean(BoolList* boolean_list); bool ParseReplicaGroupsOnly(std::vector* replica_groups); @@ -1362,7 +1364,7 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, // Add optional attributes. These are added to any HloInstruction type if // present. absl::flat_hash_map attrs; - optional sharding; + optional sharding; optional frontend_attributes; optional statistics_viz; attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding}; @@ -1423,9 +1425,8 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, // TODO(b/257495070): Eliminate tuple sharding normalization in HLO parser. // Allow existing HLO text with invalid sharding on tuple shapes by // normalizing tuple sharding. - HloSharding hlo_sharding = HloSharding::FromProto(sharding.value()).value(); - hlo_sharding = hlo_sharding.NormalizeTupleSharding(instruction->shape()); - instruction->set_sharding(std::move(hlo_sharding)); + instruction->set_sharding( + sharding->NormalizeTupleSharding(instruction->shape())); } if (parameter_replication) { int leaf_count = ShapeUtil::GetLeafCount(instruction->shape()); @@ -3369,7 +3370,7 @@ bool HloParserImpl::ParseCollectiveDeviceList( // ::= '{' (single_sharding | tuple_sharding) '}' // // tuple_sharding ::= single_sharding* (',' single_sharding)* -bool HloParserImpl::ParseSharding(OpSharding* sharding) { +bool HloParserImpl::ParseSharding(std::optional& sharding) { // A single sharding starts with '{' and is not followed by '{'. // A tuple sharding starts with '{' and is followed by '{', or is '{''}' for // an empty tuple. @@ -3385,15 +3386,18 @@ bool HloParserImpl::ParseSharding(OpSharding* sharding) { // Tuple sharding. // Allow empty tuple shardings. + std::vector tuple_shardings; if (lexer_.GetKind() != TokKind::kRbrace) { do { - if (!ParseSingleSharding(sharding->add_tuple_shardings(), + std::optional tuple_sharding; + if (!ParseSingleSharding(tuple_sharding, /*lbrace_pre_lexed=*/false)) { return false; } + tuple_shardings.push_back(std::move(*tuple_sharding)); } while (EatIfPresent(TokKind::kComma)); } - sharding->set_type(OpSharding::TUPLE); + sharding = HloSharding::FlatTuple(std::move(tuple_shardings)); return ParseToken(TokKind::kRbrace, "expected '}' to end sharding attribute"); } @@ -3575,7 +3579,7 @@ bool HloParserImpl::ParseTileAssignment( // metadata ::= single_metadata | // ('{' [single_metadata (',' single_metadata)*] '}') // last_tile_dims ::= sharding_type_list -bool HloParserImpl::ParseSingleSharding(OpSharding* sharding, +bool HloParserImpl::ParseSingleSharding(std::optional& sharding, bool lbrace_pre_lexed) { if (!lbrace_pre_lexed && !ParseToken(TokKind::kLbrace, @@ -3598,6 +3602,7 @@ bool HloParserImpl::ParseSingleSharding(OpSharding* sharding, std::vector iota_reshape_dims; std::vector iota_transpose_perm; std::vector subgroup_types; + std::vector metadata; while (lexer_.GetKind() != TokKind::kRbrace) { switch (lexer_.GetKind()) { case TokKind::kw_maximal: @@ -3632,7 +3637,7 @@ bool HloParserImpl::ParseSingleSharding(OpSharding* sharding, } } else if (lexer_.GetStrVal() == "metadata") { lexer_.Lex(); - if (!ParseSingleOrListMetadata(sharding->mutable_metadata())) { + if (!ParseSingleOrListMetadata(metadata)) { return false; } } else if (lexer_.GetStrVal() == "last_tile_dims") { @@ -3680,26 +3685,25 @@ bool HloParserImpl::ParseSingleSharding(OpSharding* sharding, return Error(loc, "replicated shardings should not have any devices assigned"); } - sharding->set_type(OpSharding::REPLICATED); + sharding = HloSharding::Replicate(metadata); } else if (maximal) { if (devices.size() != 1) { return Error(loc, "maximal shardings should have exactly one device assigned"); } - sharding->set_type(OpSharding::MAXIMAL); - sharding->add_tile_assignment_devices(devices[0]); + sharding = HloSharding::AssignDevice(devices[0], metadata); } else if (manual) { if (!devices.empty()) { return Error(loc, "manual shardings should not have any devices assigned"); } - sharding->set_type(OpSharding::MANUAL); + sharding = HloSharding::Manual(metadata); } else if (unknown) { if (!devices.empty()) { return Error(loc, "unknown shardings should not have any devices assigned"); } - sharding->set_type(OpSharding::UNKNOWN); + sharding = HloSharding::Unknown(metadata); } else { if (tile_assignment_dimensions.empty()) { return Error( @@ -3707,10 +3711,6 @@ bool HloParserImpl::ParseSingleSharding(OpSharding* sharding, "non-maximal shardings must have a tile assignment list including " "dimensions"); } - sharding->set_type(OpSharding::OTHER); - for (int64_t dim : tile_assignment_dimensions) { - sharding->add_tile_assignment_dimensions(dim); - } if (iota_transpose_perm.size() != iota_reshape_dims.size()) { return Error(loc, absl::StrFormat( @@ -3718,44 +3718,41 @@ bool HloParserImpl::ParseSingleSharding(OpSharding* sharding, "iota_reshape_dims : expected %lld, saw %lld.", iota_reshape_dims.size(), iota_transpose_perm.size())); } + if (last_tile_dim_replicate) { + CHECK(subgroup_types.empty()); + subgroup_types.push_back(OpSharding::REPLICATED); + } if (!iota_reshape_dims.empty()) { CHECK(devices.empty()); - absl::c_copy(iota_reshape_dims, - tsl::protobuf::RepeatedFieldBackInserter( - sharding->mutable_iota_reshape_dims())); - absl::c_copy(iota_transpose_perm, - tsl::protobuf::RepeatedFieldBackInserter( - sharding->mutable_iota_transpose_perm())); + sharding = + subgroup_types.empty() + ? HloSharding::IotaTile(tile_assignment_dimensions, + iota_reshape_dims, iota_transpose_perm, + metadata) + : HloSharding::Subgroup( + TileAssignment(tile_assignment_dimensions, + iota_reshape_dims, iota_transpose_perm), + subgroup_types, metadata); } else { if (devices.size() <= 1) { return Error( loc, "non-maximal shardings must have more than one device assigned"); } - for (int64_t device : devices) { - sharding->add_tile_assignment_devices(device); - } - } - - if (last_tile_dims) { - for (OpSharding::Type type : subgroup_types) { - sharding->add_last_tile_dims(type); - } - } else { - sharding->set_replicate_on_last_tile_dim(last_tile_dim_replicate); + auto tiles = std::make_shared>(tile_assignment_dimensions); + absl::c_copy(devices, tiles->begin()); + sharding = + subgroup_types.empty() + ? HloSharding::Tile(TileAssignment(std::move(tiles)), metadata) + : HloSharding::Subgroup(TileAssignment(std::move(tiles)), + subgroup_types, metadata); } } if (shard_as || shard_like) { - sharding->set_is_shard_group(true); - sharding->set_shard_group_id(shard_group_id); - if (shard_as) { - sharding->set_shard_group_type(OpSharding::AS); - } else { - sharding->set_shard_group_type(OpSharding::LIKE); - } - } else { - sharding->set_is_shard_group(false); + sharding = sharding->SetShardGroup( + shard_as ? HloSharding::ShardAs(shard_group_id) + : HloSharding::ShardLike(shard_group_id)); } lexer_.Lex(); @@ -3853,8 +3850,8 @@ bool HloParserImpl::ParseReplicaGroupsOnly( bool HloParserImpl::ParseDomain(DomainData* domain) { absl::flat_hash_map attrs; optional kind; - optional entry_sharding; - optional exit_sharding; + optional entry_sharding; + optional exit_sharding; attrs["kind"] = {/*required=*/true, AttrTy::kString, &kind}; attrs["entry"] = {/*required=*/true, AttrTy::kSharding, &entry_sharding}; attrs["exit"] = {/*required=*/true, AttrTy::kSharding, &exit_sharding}; @@ -3862,10 +3859,10 @@ bool HloParserImpl::ParseDomain(DomainData* domain) { return false; } if (*kind == ShardingMetadata::KindName()) { - auto entry_sharding_ptr = std::make_unique( - HloSharding::FromProto(*entry_sharding).value()); - auto exit_sharding_ptr = std::make_unique( - HloSharding::FromProto(*exit_sharding).value()); + auto entry_sharding_ptr = + std::make_unique(std::move(*entry_sharding)); + auto exit_sharding_ptr = + std::make_unique(std::move(*exit_sharding)); domain->entry_metadata = std::make_unique(std::move(entry_sharding_ptr)); domain->exit_metadata = @@ -4823,11 +4820,12 @@ bool HloParserImpl::ParseAttributeHelper( return true; } case AttrTy::kSharding: { - OpSharding sharding; - if (!ParseSharding(&sharding)) { + std::optional sharding; + if (!ParseSharding(sharding)) { return false; } - static_cast*>(attr_out_ptr)->emplace(sharding); + static_cast*>(attr_out_ptr) + ->emplace(std::move(*sharding)); return true; } case AttrTy::kCollectiveDeviceList: { @@ -4958,10 +4956,11 @@ bool HloParserImpl::ParseAttributeHelper( } case AttrTy::kMetadata: { OpMetadata result; - if (!ParseMetadata(&result)) { + if (!ParseMetadata(result)) { return false; } - static_cast*>(attr_out_ptr)->emplace(result); + static_cast*>(attr_out_ptr) + ->emplace(std::move(result)); return true; } case AttrTy::kDistribution: { @@ -6304,7 +6303,7 @@ bool HloParserImpl::ParseOriginalValue( } // '{' metadata_string '}' -bool HloParserImpl::ParseMetadata(OpMetadata* metadata) { +bool HloParserImpl::ParseMetadata(OpMetadata& metadata) { absl::flat_hash_map attrs; optional op_type; optional op_name; @@ -6330,42 +6329,42 @@ bool HloParserImpl::ParseMetadata(OpMetadata* metadata) { return false; } if (op_type) { - metadata->set_op_type(*op_type); + metadata.set_op_type(*op_type); } if (op_name) { - metadata->set_op_name(*op_name); + metadata.set_op_name(*op_name); } if (source_file) { - metadata->set_source_file(*source_file); + metadata.set_source_file(*source_file); } if (source_line) { - metadata->set_source_line(*source_line); + metadata.set_source_line(*source_line); } if (profile_type) { for (const auto& type : *profile_type) { if (!ProfileType_IsValid(type)) { return false; } - metadata->add_profile_type(static_cast(type)); + metadata.add_profile_type(static_cast(type)); } } if (deduplicated_name) { - metadata->set_deduplicated_name(*deduplicated_name); + metadata.set_deduplicated_name(*deduplicated_name); } if (preserve_layout) { - metadata->set_preserve_layout(*preserve_layout); + metadata.set_preserve_layout(*preserve_layout); } else { - metadata->set_preserve_layout(false); + metadata.set_preserve_layout(false); } if (scheduling_name) { - metadata->set_scheduling_name(*scheduling_name); + metadata.set_scheduling_name(*scheduling_name); } return true; } // ::= single_metadata | ('{' [single_metadata (',' single_metadata)*] '}') bool HloParserImpl::ParseSingleOrListMetadata( - tsl::protobuf::RepeatedPtrField* metadata) { + std::vector& metadata) { if (lexer_.GetKind() == TokKind::kLbrace && lexer_.LookAhead() == TokKind::kLbrace) { if (!ParseToken(TokKind::kLbrace, "expected '{' to start metadata list")) { @@ -6374,7 +6373,7 @@ bool HloParserImpl::ParseSingleOrListMetadata( if (lexer_.GetKind() != TokKind::kRbrace) { do { - if (!ParseMetadata(metadata->Add())) { + if (!ParseMetadata(metadata.emplace_back())) { return false; } } while (EatIfPresent(TokKind::kComma)); @@ -6383,7 +6382,7 @@ bool HloParserImpl::ParseSingleOrListMetadata( return ParseToken(TokKind::kRbrace, "expected '}' to end metadata list"); } - return ParseMetadata(metadata->Add()); + return ParseMetadata(metadata.emplace_back()); } bool HloParserImpl::ParseOpShardingType(OpSharding::Type* type) { @@ -6758,14 +6757,14 @@ absl::StatusOr HloParserImpl::ParseLayoutOnly() { absl::StatusOr HloParserImpl::ParseShardingOnly() { lexer_.Lex(); - OpSharding op_sharding; - if (!ParseSharding(&op_sharding)) { + std::optional sharding; + if (!ParseSharding(sharding)) { return InvalidArgument("Syntax error:\n%s", GetError()); } if (lexer_.GetKind() != TokKind::kEof) { return InvalidArgument("Syntax error:\nExtra content after sharding"); } - return HloSharding::FromProto(op_sharding); + return std::move(*sharding); } absl::StatusOr From f6755fa00c4bcde7d814ec48924cc726c1cf103b Mon Sep 17 00:00:00 2001 From: Zixuan Jiang Date: Wed, 31 Jul 2024 18:30:32 -0700 Subject: [PATCH 350/376] Cleanup. Remove unused argument. PiperOrigin-RevId: 658215100 --- xla/service/sharding_propagation.cc | 7 +++---- xla/service/sharding_propagation.h | 3 +-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/xla/service/sharding_propagation.cc b/xla/service/sharding_propagation.cc index 0a8e3cf14a42f4..16a01524e959e8 100644 --- a/xla/service/sharding_propagation.cc +++ b/xla/service/sharding_propagation.cc @@ -2037,8 +2037,7 @@ bool InferDynamicUpdateSliceShardingFromOperand0( } bool ShardingPropagation::InferShardingFromShardGroup( - HloInstruction* instruction, const ComputationMap& computation_map, - int64_t aggressiveness, + HloInstruction* instruction, int64_t aggressiveness, const absl::flat_hash_set& shard_group) { if (!CanPropagateThroughAtAggressiveLevel(*instruction, aggressiveness)) { return false; @@ -3134,8 +3133,8 @@ absl::StatusOr ShardingPropagation::Run( continue; } already_inferred_from_shard_group.insert(instruction); - if (InferShardingFromShardGroup(instruction, computation_map, - aggressiveness, shard_group)) { + if (InferShardingFromShardGroup(instruction, aggressiveness, + shard_group)) { ++inferred_from_shard_group_counter; any_changed = true; VLOG(2) << "Add sharding (shard group): " diff --git a/xla/service/sharding_propagation.h b/xla/service/sharding_propagation.h index 66be9e7e501e32..22cb7af042545d 100644 --- a/xla/service/sharding_propagation.h +++ b/xla/service/sharding_propagation.h @@ -140,8 +140,7 @@ class ShardingPropagation : public HloModulePass { private: bool InferShardingFromShardGroup( - HloInstruction* instruction, const ComputationMap& computation_map, - int64_t aggressiveness, + HloInstruction* instruction, int64_t aggressiveness, const absl::flat_hash_set& shard_group); bool InferShardingFromOperands( HloInstruction* instruction, const ComputationMap& computation_map, From 9f400365ff16b4736d6f0ee29b761f2bc04428e7 Mon Sep 17 00:00:00 2001 From: Farzin Houshmand Date: Wed, 31 Jul 2024 20:09:05 -0700 Subject: [PATCH 351/376] [XLA:UNSTACKER] Add another case of unstacking The following case is now supported: fusion(stacked, loop_iteration_var) computation { p0 = parameter(0) p1 = parameter(1) ROOT slice = dynamic_slice(p0, p1, zero, ...) } Moreover, the unstacking pass now accepts a lambda that decides whether to unfuse the slicing instruction within the unstacking computation or not. This feature PiperOrigin-RevId: 658239712 --- xla/service/hlo_unstacker.cc | 103 +++++++++++++++++++++++------ xla/service/hlo_unstacker.h | 9 ++- xla/service/hlo_unstacker_test.cc | 106 ++++++++++++++++++++++++++++++ 3 files changed, 198 insertions(+), 20 deletions(-) diff --git a/xla/service/hlo_unstacker.cc b/xla/service/hlo_unstacker.cc index 024a41b0c48417..c6b0971f4f3312 100644 --- a/xla/service/hlo_unstacker.cc +++ b/xla/service/hlo_unstacker.cc @@ -54,6 +54,7 @@ namespace { // TODO: b/352400145 - Unify the patterns, handlers and their type into a class // or struct. enum class PatternType { + DSFusionNoBitcastPattern, DSFusionPattern, NestedDSFusionPattern, Other, @@ -61,6 +62,8 @@ enum class PatternType { static std::string PatternTypeToString(PatternType pattern_type) { switch (pattern_type) { + case PatternType::DSFusionNoBitcastPattern: + return "DSFusionNoBitcastPattern"; case PatternType::DSFusionPattern: return "DSFusionPattern"; case PatternType::NestedDSFusionPattern: @@ -97,7 +100,8 @@ struct PatternInfo { // information for unstacking that is fixed across different unstacker // instastances. struct UnstackerMetadata { - static absl::StatusOr Create(HloModule* module) { + static absl::StatusOr Create( + HloModule* module, std::function unfuse_slice) { UnstackerMetadata metadata; TF_ASSIGN_OR_RETURN( bool prepared, @@ -111,6 +115,7 @@ struct UnstackerMetadata { metadata.unrollable_loop_bodies[instr->while_body()] = while_loop_config; metadata.bodies[instr->while_body()] = instr; } + metadata.unfuse_slice = unfuse_slice; return metadata; } absl::flat_hash_map unrollable_loop_bodies; @@ -123,6 +128,7 @@ struct UnstackerMetadata { const UnstackerMetadata&, const HloInstruction*, int64_t)>, std::function>> custom_handlers; + std::function unfuse_slice; }; // Performs the two-step unstacking. Each instance of this class is responsible @@ -198,7 +204,7 @@ class UnstackerTransformer { return {}; } - const UnstackerMetadata& GetMetadata() { return metadata_; } + const UnstackerMetadata& GetMetadata() const { return metadata_; } std::vector& GetUnstackedInstructions() { return unstacked_instrs_; @@ -440,9 +446,18 @@ void UnstackWhileInput(const UnstackerTransformer& unstacker, // later prefetched using async-slice by MSA. For other patterns, we // resort to the original unstacking computation until we find benefit in // doing otherwise. + HloInstruction* slice = nullptr; if (unstacker.GetPatternType() == PatternType::DSFusionPattern || - unstacker.GetPatternType() == PatternType::NestedDSFusionPattern) { - HloInstruction* dynamic_slice = root_instr->mutable_operand(0); + unstacker.GetPatternType() == PatternType::NestedDSFusionPattern || + unstacker.GetPatternType() == PatternType::DSFusionNoBitcastPattern) { + HloInstruction* dynamic_slice = nullptr; + if (unstacker.GetPatternType() == PatternType::DSFusionPattern || + unstacker.GetPatternType() == PatternType::NestedDSFusionPattern) { + dynamic_slice = root_instr->mutable_operand(0); + } else if (unstacker.GetPatternType() == + PatternType::DSFusionNoBitcastPattern) { + dynamic_slice = root_instr; + } std::vector new_start_indices; new_start_indices.reserve(dynamic_slice->shape().rank()); std::vector new_limit_indices; @@ -458,25 +473,22 @@ void UnstackWhileInput(const UnstackerTransformer& unstacker, dynamic_slice->mutable_operand(0)->shape().dimensions(j)); new_strides.push_back(1); } - HloInstruction* slice = - while_instr->AddInstruction(HloInstruction::CreateSlice( - dynamic_slice->shape(), old_while_input, new_start_indices, - new_limit_indices, new_strides)); - - slices.push_back(slice); - } else { + slice = while_instr->AddInstruction(HloInstruction::CreateSlice( + dynamic_slice->shape(), old_while_input, new_start_indices, + new_limit_indices, new_strides)); + } + if (slice == nullptr || !unstacker.GetMetadata().unfuse_slice(slice)) { std::vector operands = { old_while_input, while_instr->AddInstruction(MakeScalarConstantWithShape( unstacking_computation->parameter_instruction(1)->shape(), i))}; - HloInstruction* slice = - while_instr->AddInstruction(HloInstruction::CreateFusion( - slice_shape, HloInstruction::FusionKind::kLoop, operands, - while_instr->GetModule()->AddEmbeddedComputation( - unstacking_computation->Clone()), - "hoisted")); - slices.push_back(slice); + slice = while_instr->AddInstruction(HloInstruction::CreateFusion( + slice_shape, HloInstruction::FusionKind::kLoop, operands, + while_instr->GetModule()->AddEmbeddedComputation( + unstacking_computation->Clone()), + "hoisted")); } + slices.push_back(slice); } } HloInstruction* new_operand_element = @@ -788,6 +800,56 @@ absl::Status UnstackDSFusionPattern( bitcast_fusion); } +// This function recognizes fusions with the following pattern: +// fusion(stacked, f(loop_iteration_var)) +// computation { +// p0 = parameter(0) +// p1 = parameter(1) +// ROOT slice = dynamic_slice(p0, p1, zero, ...) +// } +// where f is a function of loop_iteration_var. It indicates that the slicing +// offset is effectively static after unrolling. +std::optional GetDSFusionNoBitcastPattern( + const UnstackerMetadata& metadata, const HloInstruction* instr, + int64_t stacked_operand_idx) { + VLOG(3) << "Checking DSFusionNoBitcast"; + HloInstruction* shape_covering_instr = + GetMostMajorEffectivelyStaticDynamicSliceInFusion(metadata, instr, 2, + stacked_operand_idx); + if (shape_covering_instr == nullptr) { + return std::nullopt; + } + if (instr->fused_instructions_computation()->root_instruction() != + shape_covering_instr) { + return std::nullopt; + } + PatternInfo pattern_info; + pattern_info.type = PatternType::DSFusionNoBitcastPattern; + pattern_info.instr = instr; + const Shape& slice_shape = shape_covering_instr->shape(); + const int64_t num_layers = instr->operand(0)->shape().dimensions(0); + pattern_info.unstacked_shape = + MakeUnstackedShapeFromSlice(slice_shape, num_layers); + pattern_info.unstacking_computation = instr->fused_instructions_computation(); + pattern_info.unstacked_instrs.push_back(instr); + return pattern_info; +} + +absl::Status UnstackDSFusionNoBitcastPattern( + HloInstruction* mutable_dynamic_slicing_fusion, const Shape& slice_shape) { + HloComputation* parent_loop = mutable_dynamic_slicing_fusion->parent(); + + HloInstruction* stacked = mutable_dynamic_slicing_fusion->mutable_operand(0); + HloInstruction* offset = mutable_dynamic_slicing_fusion->mutable_operand(1); + + HloInstruction* new_operand = + parent_loop->AddInstruction(HloInstruction::CreateCustomCall( + slice_shape, {stacked, offset}, "DynamicGte")); + + return mutable_dynamic_slicing_fusion->ReplaceAllUsesWithDifferentShape( + new_operand); +} + // This function recognizes fusions with the following pattern: // fusion(stacked, update, loop_iteration_var) // computation { @@ -1290,7 +1352,8 @@ absl::Status UnstackReduceFusionPattern(HloInstruction* mutable_reduce_fusion, absl::StatusOr HloUnstacker::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { - TF_ASSIGN_OR_RETURN(auto metadata, UnstackerMetadata::Create(module)); + TF_ASSIGN_OR_RETURN(auto metadata, + UnstackerMetadata::Create(module, unfuse_slice_)); // The order of the patterns below is important, as it determines the order // in which the unstacking custom handlers are called. For example, applying // GetDSAndDUSPattern after GetDSFusionPattern would result in patterns of @@ -1310,6 +1373,8 @@ absl::StatusOr HloUnstacker::Run( std::make_pair(GetReduceFusionPattern, UnstackReduceFusionPattern)); metadata.custom_handlers.push_back( std::make_pair(GetNestedDSFusionPattern, UnstackNestedDSFusionPattern)); + metadata.custom_handlers.push_back(std::make_pair( + GetDSFusionNoBitcastPattern, UnstackDSFusionNoBitcastPattern)); std::vector entry_loops; for (HloInstruction* instr : diff --git a/xla/service/hlo_unstacker.h b/xla/service/hlo_unstacker.h index eaa74ffc003468..222a1e511e6d47 100644 --- a/xla/service/hlo_unstacker.h +++ b/xla/service/hlo_unstacker.h @@ -18,6 +18,8 @@ limitations under the License. #include +#include + #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -79,13 +81,18 @@ class HloUnstacker : public HloModulePass { public: ~HloUnstacker() override = default; - explicit HloUnstacker() = default; + explicit HloUnstacker(std::function unfuse_slice = + [](HloInstruction* instr) { return true; }) + : unfuse_slice_(unfuse_slice) {} absl::string_view name() const override { return "hlo_unstacker"; } using HloPassInterface::Run; absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; + + private: + std::function unfuse_slice_; }; } // namespace xla diff --git a/xla/service/hlo_unstacker_test.cc b/xla/service/hlo_unstacker_test.cc index 84724550052dc1..37a9843b85adf7 100644 --- a/xla/service/hlo_unstacker_test.cc +++ b/xla/service/hlo_unstacker_test.cc @@ -151,6 +151,112 @@ TEST_F(UnstackerTest, UnstackLoopSingleFusionUser2) { std::nullopt)); } +TEST_F(UnstackerTest, UnstackLoopSingleFusionUserNoBitcast) { + std::string hlo_string = R"( + HloModule SimpleLoop + %fused_computation.slice (param_0.51117: s8[3,128,128], p1: s32[]) -> s8[1,128,128] { + %param_0.51117 = s8[3,128,128] parameter(0) + p1 = s32[] parameter(1) + %constant.85694 = s32[] constant(0) + ROOT %dynamic-slice.22040 = s8[1,128,128] dynamic-slice(s8[3,128,128] %param_0.51117, p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128} + } + + %while.body (wide_param: (s32[], bf16[8,128], s8[3,128,128])) -> (s32[], bf16[8,128], s8[3,128,128]) { + wide_p = (s32[], bf16[8,128], s8[3,128,128]) parameter(0) + i = s32[] get-tuple-element(wide_p), index=0 + p0 = bf16[8,128] get-tuple-element(wide_p), index=1 + p1 = s8[3,128,128] get-tuple-element(wide_p), index=2 + one = s32[] constant(1) + inc = s32[] add(i, one) + %fusion.67830 = s8[1,128,128] fusion(s8[3,128,128] p1, i), kind=kLoop, calls=%fused_computation.slice + bitcast.102 = s8[128,128] bitcast(s8[1,128,128] %fusion.67830) + conv = bf16[8,128] convolution(bf16[8,128] p0, s8[128,128] bitcast.102), dim_labels=bf_io->bf + ROOT out = (s32[], bf16[8,128], s8[3,128,128]) tuple(inc, conv, p1) + } + + %while.cond (wide_param: (s32[], bf16[8,128], s8[3,128,128])) -> pred[] { + wide_p = (s32[], bf16[8,128], s8[3,128,128]) parameter(0) + i = s32[] get-tuple-element(wide_p), index=0 + %constant.12857 = s32[] constant(3) + ROOT %compare.1921 = pred[]{:T(512)} compare(s32[] i, s32[] %constant.12857), direction=LT + } + + ENTRY main { + p0 = s8[3,128,128] parameter(0) + p1 = bf16[8,128] parameter(1) + init = s32[] constant(0) + while.input = (s32[], bf16[8,128], s8[3,128,128]) tuple(init, p1, p0) + while.out = (s32[], bf16[8,128], s8[3,128,128]) while(while.input), condition=%while.cond , body=%while.body + while_use = s8[3,128,128] get-tuple-element(while.out), index=2 + ROOT out = bf16[8,128] get-tuple-element(while.out), index=1 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + auto original = module->Clone(); + TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get())); + std::cout << module->ToString() << std::endl; + EXPECT_TRUE(unstacked); + // Check for the creation of slice instructions. + EXPECT_EQ(GetSliceCountInEntry(module.get()), 3); + EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), + std::nullopt, false)); +} + +TEST_F(UnstackerTest, UnstackLoopSingleFusionUserNoBitcastKeepFused) { + std::string hlo_string = R"( + HloModule SimpleLoop + %fused_computation.slice (param_0.51117: s8[3,128,128], p1: s32[]) -> s8[1,128,128] { + %param_0.51117 = s8[3,128,128] parameter(0) + p1 = s32[] parameter(1) + %constant.85694 = s32[] constant(0) + ROOT %dynamic-slice.22040 = s8[1,128,128] dynamic-slice(s8[3,128,128] %param_0.51117, p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128} + } + + %while.body (wide_param: (s32[], bf16[8,128], s8[3,128,128])) -> (s32[], bf16[8,128], s8[3,128,128]) { + wide_p = (s32[], bf16[8,128], s8[3,128,128]) parameter(0) + i = s32[] get-tuple-element(wide_p), index=0 + p0 = bf16[8,128] get-tuple-element(wide_p), index=1 + p1 = s8[3,128,128] get-tuple-element(wide_p), index=2 + one = s32[] constant(1) + inc = s32[] add(i, one) + %fusion.67830 = s8[1,128,128] fusion(s8[3,128,128] p1, i), kind=kLoop, calls=%fused_computation.slice + bitcast.102 = s8[128,128] bitcast(s8[1,128,128] %fusion.67830) + conv = bf16[8,128] convolution(bf16[8,128] p0, s8[128,128] bitcast.102), dim_labels=bf_io->bf + ROOT out = (s32[], bf16[8,128], s8[3,128,128]) tuple(inc, conv, p1) + } + + %while.cond (wide_param: (s32[], bf16[8,128], s8[3,128,128])) -> pred[] { + wide_p = (s32[], bf16[8,128], s8[3,128,128]) parameter(0) + i = s32[] get-tuple-element(wide_p), index=0 + %constant.12857 = s32[] constant(3) + ROOT %compare.1921 = pred[]{:T(512)} compare(s32[] i, s32[] %constant.12857), direction=LT + } + + ENTRY main { + p0 = s8[3,128,128] parameter(0) + p1 = bf16[8,128] parameter(1) + init = s32[] constant(0) + while.input = (s32[], bf16[8,128], s8[3,128,128]) tuple(init, p1, p0) + while.out = (s32[], bf16[8,128], s8[3,128,128]) while(while.input), condition=%while.cond , body=%while.body + while_use = s8[3,128,128] get-tuple-element(while.out), index=2 + ROOT out = bf16[8,128] get-tuple-element(while.out), index=1 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + auto original = module->Clone(); + auto unfuse = [](HloInstruction* instruction) { return false; }; + TF_ASSERT_OK_AND_ASSIGN(bool unstacked, + HloUnstacker(unfuse).Run(module.get())); + std::cout << module->ToString() << std::endl; + EXPECT_TRUE(unstacked); + // Check for the creation of slice instructions. + EXPECT_EQ(GetSliceCountInEntry(module.get()), 0); + EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), + std::nullopt, false)); +} + TEST_F(UnstackerTest, UnstackLoopSingleFusionUserDifferentLayout) { std::string hlo_string = R"( HloModule SimpleLoop From d7bb04e2e972b060b409ea51ad9521f74c9e9f81 Mon Sep 17 00:00:00 2001 From: Changhui Lin Date: Wed, 31 Jul 2024 20:51:04 -0700 Subject: [PATCH 352/376] Update to use the local hardware ID to get device description. PiperOrigin-RevId: 658248021 --- xla/pjrt/gpu/se_gpu_pjrt_client.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/xla/pjrt/gpu/se_gpu_pjrt_client.cc index cf60a8f6072c03..27d3f18dbc72cb 100644 --- a/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -1013,7 +1013,8 @@ absl::StatusOr BuildDistributedDevices( ordinal_and_device.second->executor()->GetPlatform(); TF_ASSIGN_OR_RETURN( std::unique_ptr desc, - platform->DescriptionForDevice(ordinal_and_device.first)); + platform->DescriptionForDevice( + ordinal_and_device.second->local_hardware_id().value())); DeviceProto* device_proto = local_topology.add_devices(); device_proto->set_local_device_ordinal(ordinal_and_device.first); device_proto->set_name(desc->name()); From 0ee47bd5bd2ba3cfa9aebc300b99fd5aca7ecd76 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 31 Jul 2024 22:34:17 -0700 Subject: [PATCH 353/376] Automated Code Change PiperOrigin-RevId: 658269432 --- xla/python/profiler/internal/python_hooks.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xla/python/profiler/internal/python_hooks.h b/xla/python/profiler/internal/python_hooks.h index a9b502ef3b2e46..29e6b83dac1962 100644 --- a/xla/python/profiler/internal/python_hooks.h +++ b/xla/python/profiler/internal/python_hooks.h @@ -77,7 +77,7 @@ struct PythonTraceEntry { Py_XDECREF(m_module); } - PythonTraceEntry(PythonTraceEntry&& other) { + PythonTraceEntry(PythonTraceEntry&& other) noexcept { start_time_ns = other.start_time_ns; end_time_ns = other.end_time_ns; co_firstlineno = other.co_firstlineno; From 09cfa285024c61140e2c530ef10175fab83aec5f Mon Sep 17 00:00:00 2001 From: Changhui Lin Date: Wed, 31 Jul 2024 22:51:53 -0700 Subject: [PATCH 354/376] Use the `device_ordinal` from the `run_options` if it is provided. This is the ordinal of the logical devices (e.g., virtual GPUs). PiperOrigin-RevId: 658273190 --- xla/service/gpu/gpu_executable.cc | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/xla/service/gpu/gpu_executable.cc b/xla/service/gpu/gpu_executable.cc index bf9774711fcfd6..25fe6510f0baa8 100644 --- a/xla/service/gpu/gpu_executable.cc +++ b/xla/service/gpu/gpu_executable.cc @@ -841,9 +841,14 @@ absl::StatusOr GpuExecutable::ExecuteAsyncOnStreamImpl( TF_ASSIGN_OR_RETURN(globals, ResolveConstantGlobals(run_options->stream())); } - auto device_ordinal = executor->device_ordinal(); + // Use the `device_ordinal` from the `run_options` if it is provided. This is + // the ordinal of the logical devices (e.g., virtual GPUs). If it is not + // provided, the ordinals of the logical and physical devices are the same. + const int device_ordinal = run_options->device_ordinal() != -1 + ? run_options->device_ordinal() + : executor->device_ordinal(); ExecutionOutput result(/*on_device_shape=*/output_shape_, memory_allocator, - device_ordinal); + device_ordinal, executor->device_ordinal()); TF_ASSIGN_OR_RETURN( BufferAllocations buffer_allocations, From fedb90a29e84a7e15dd0680417d5d37c72442c76 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 31 Jul 2024 23:57:55 -0700 Subject: [PATCH 355/376] Automated Code Change PiperOrigin-RevId: 658288373 --- xla/service/spmd/dot_handler.cc | 2 +- xla/service/spmd/spmd_partitioner.cc | 5 +++++ xla/service/spmd/spmd_partitioner_util.h | 1 + 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/xla/service/spmd/dot_handler.cc b/xla/service/spmd/dot_handler.cc index 87db7f723d6432..16ace2aa39e86c 100644 --- a/xla/service/spmd/dot_handler.cc +++ b/xla/service/spmd/dot_handler.cc @@ -1899,7 +1899,7 @@ absl::StatusOr PartitionBaseCase( has_reshape_operand(lhs) ? lhs.hlo()->operand(0) : lhs.hlo(); auto rhs_operand = has_reshape_operand(rhs) ? rhs.hlo()->operand(0) : rhs.hlo(); - for (auto loop : *windowed_dot_general_loops) { + for (const auto& loop : *windowed_dot_general_loops) { if (loop.while_loop->while_body()->name().find( "windowed_dot_general_body_ag") == 0) { auto cm_lhs = loop.while_loop->operand(0)->operand(0); diff --git a/xla/service/spmd/spmd_partitioner.cc b/xla/service/spmd/spmd_partitioner.cc index c3fc8b1ab31c0a..303e79c29352a1 100644 --- a/xla/service/spmd/spmd_partitioner.cc +++ b/xla/service/spmd/spmd_partitioner.cc @@ -3316,6 +3316,7 @@ absl::Status SpmdPartitioningVisitor::HandleSingleDevice( auto param = true_b.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/0, operand_shape, "true_branch_param")); std::vector new_operands; + new_operands.reserve(operands.size()); for (int64_t i = 0; i < operands.size(); ++i) { new_operands.push_back(true_b.AddInstruction( HloInstruction::CreateGetTupleElement(*operand_shapes[i], param, i))); @@ -4129,6 +4130,7 @@ absl::Status SpmdPartitioningVisitor::HandleOutfeed(HloInstruction* hlo) { if (hlo->sharding().IsManual()) { auto clone_from_original = [&](const HloSharding& shared_sharding) { std::vector new_operands; + new_operands.reserve(hlo->operand_count()); for (int64_t i = 0; i < hlo->operand_count(); ++i) { new_operands.push_back( GetPartitionedHlo(hlo->operand(i)).Reshard(shared_sharding).hlo()); @@ -4310,6 +4312,7 @@ absl::Status SpmdPartitioningVisitor::HandleRng(HloInstruction* hlo) { } auto clone_from_original = [&](const HloSharding& shared_sharding) { std::vector new_operands; + new_operands.reserve(hlo->operand_count()); for (int64_t i = 0; i < hlo->operand_count(); ++i) { new_operands.push_back( GetPartitionedHlo(hlo->operand(i)).Reshard(shared_sharding).hlo()); @@ -4340,6 +4343,7 @@ absl::Status SpmdPartitioningVisitor::HandleRng(HloInstruction* hlo) { TF_RET_CHECK(!hlo->sharding().IsTileMaximal()); // Replicate the operands and run partitioned Rng on all devices. std::vector new_operands; + new_operands.reserve(hlo->operand_count()); for (int64_t i = 0; i < hlo->operand_count(); ++i) { new_operands.push_back(GetPartitionedHlo(hlo->operand(i)) .Reshard(HloSharding::Replicate()) @@ -4659,6 +4663,7 @@ absl::Status SpmdPartitioningVisitor::HandleSelectAndScatter( absl::Status SpmdPartitioningVisitor::HandleTuple(HloInstruction* hlo) { std::vector new_operands; + new_operands.reserve(hlo->operand_count()); for (int64_t i = 0; i < hlo->operand_count(); ++i) { new_operands.push_back( GetPartitionedHlo(hlo->operand(i)) diff --git a/xla/service/spmd/spmd_partitioner_util.h b/xla/service/spmd/spmd_partitioner_util.h index 65b5d0134b4e39..a982c3edf1e8db 100644 --- a/xla/service/spmd/spmd_partitioner_util.h +++ b/xla/service/spmd/spmd_partitioner_util.h @@ -84,6 +84,7 @@ HloInstruction* CreateConstantBase(const Shape& shape, Literal value, T* b, PrimitiveType)) { if (shape.IsTuple()) { std::vector elements; + elements.reserve(ShapeUtil::TupleElementCount(shape)); for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { elements.push_back( CreateConstantBase(ShapeUtil::GetTupleElementShape(shape, i), From 210e1f5be1ea057bb8f4d3cffd3667ec608d6bc7 Mon Sep 17 00:00:00 2001 From: Tori Baker Date: Thu, 1 Aug 2024 00:45:18 -0700 Subject: [PATCH 356/376] Add IndexingMapAttr to ApplyIndexingOp PiperOrigin-RevId: 658299867 --- .../gpu/fusions/concatenate_mlir_test.cc | 14 +- ...in_place_dynamic_update_slice_mlir_test.cc | 16 +- xla/service/gpu/fusions/loop_mlir_test.cc | 8 +- .../mlir/elemental_hlo_to_mlir_test.cc | 109 +++++------ .../gpu/fusions/mlir/ir/xla_gpu_ops.cc | 178 ++++++++---------- xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h | 6 +- .../gpu/fusions/mlir/ir/xla_gpu_ops.td | 11 +- .../gpu/fusions/mlir/tests/canonicalize.mlir | 95 +++++----- .../fusions/mlir/tests/flatten_tensors.mlir | 36 ++-- .../gpu/fusions/mlir/tests/invalid.mlir | 29 ++- .../gpu/fusions/mlir/tests/lower_tensors.mlir | 6 +- xla/service/gpu/fusions/mlir/tests/ops.mlir | 58 ++++-- .../fusions/mlir/tests/optimize_loops.mlir | 31 +-- .../fusions/mlir/tests/simplify_affine.mlir | 17 +- .../fusions/mlir/tests/simplify_arith.mlir | 16 +- .../mlir/tests/vectorize_loads_stores.mlir | 46 +++-- .../gpu/fusions/reduction_mlir_test.cc | 8 +- xla/service/gpu/fusions/scatter_mlir_test.cc | 4 +- .../triton_fusion_emitter_device_test.cc | 10 +- 19 files changed, 376 insertions(+), 322 deletions(-) diff --git a/xla/service/gpu/fusions/concatenate_mlir_test.cc b/xla/service/gpu/fusions/concatenate_mlir_test.cc index 92aff949ace0b6..9323cf7dca8214 100644 --- a/xla/service/gpu/fusions/concatenate_mlir_test.cc +++ b/xla/service/gpu/fusions/concatenate_mlir_test.cc @@ -102,9 +102,9 @@ TEST_F(MlirConcatenateFusionTest, StandAloneConcatenate) { } )"; TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK-DAG: #[[MAP_1:.*]] = affine_map<(d0, d1) -> (d1 * 128 + d0)> - // CHECK-DAG: #[[MAP_2:.*]] = affine_map<(d0, d1) -> (d1 * 128 + d0 + 200)> - // CHECK-DAG: #[[MAP_3:.*]] = affine_map<(d0, d1) -> (d1 * 128 + d0 + 600)> + // CHECK-DAG: #[[MAP_1:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d1 * 128 + d0) + // CHECK-DAG: #[[MAP_2:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d1 * 128 + d0 + 200) + // CHECK-DAG: #[[MAP_3:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d1 * 128 + d0 + 600) // CHECK-LABEL: fused_computation // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9]*]]: {{[^,]*}}, @@ -152,7 +152,7 @@ TEST_F(MlirConcatenateFusionTest, PrologueEpilogue) { } )"; TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0 + 64)> + // CHECK: #[[MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 + 64) // CHECK-LABEL: fused_computation // CHECK-DAG: %[[C_63:.*]] = arith.constant 63 @@ -254,9 +254,9 @@ TEST_F(MlirConcatenateFusionTest, Vectorization) { } )"; TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK-DAG: affine_map<(d0, d1) -> (d1 * 128 + d0)> - // CHECK-DAG: affine_map<(d0, d1)[s0] -> (d0 * 2 + d1 * 256 + s0)> - // CHECK-DAG: affine_map<(d0, d1)[s0] -> (d0 * 2 + d1 * 256 + s0 + 640002)> + // CHECK-DAG: #xla_gpu.indexing_map<(d0, d1) -> (d1 * 128 + d0) + // CHECK-DAG: #xla_gpu.indexing_map<(d0, d1)[s0] -> (d0 * 2 + d1 * 256 + s0) + // CHECK-DAG: #xla_gpu.indexing_map<(d0, d1)[s0] -> (d0 * 2 + d1 * 256 + s0 + 640002) // CHECK-LABEL: fused_computation // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index diff --git a/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc b/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc index b0da3ef5c04532..2456f2ee63cc4d 100644 --- a/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc +++ b/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc @@ -100,8 +100,8 @@ TEST_F(MlirInPlaceDynamicUpdateSliceFusionTest, SimpleDUS) { } )"; TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK-DAG: #[[MAP_1:.*]] = affine_map<(d0) -> (d0 floordiv 6)> - // CHECK-DAG: #[[MAP_2:.*]] = affine_map<(d0) -> (d0 mod 6)> + // CHECK-DAG: #[[MAP_1:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 floordiv 6), domain: d0 in [0, 29] + // CHECK-DAG: #[[MAP_2:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 mod 6), domain: d0 in [0, 29] // CHECK: func.func @fused_computation // CHECK-SAME: %arg0: tensor<20x30xf32> // CHECK-SAME: %arg1: tensor<5x6xf32> @@ -112,8 +112,8 @@ TEST_F(MlirInPlaceDynamicUpdateSliceFusionTest, SimpleDUS) { // CHECK-DAG: %[[C_15:.*]] = arith.constant 15 // CHECK-DAG: %[[C_0:.*]] = arith.constant 0 // CHECK: %[[THREAD_ID:.*]] = gpu.thread_id x - // CHECK: %[[INPUT_INDEX_0:.*]] = xla_gpu.apply_indexing #[[MAP_1]](%[[THREAD_ID]] in [0, 29]) - // CHECK: %[[INPUT_INDEX_1:.*]] = xla_gpu.apply_indexing #[[MAP_2]](%[[THREAD_ID]] in [0, 29]) + // CHECK: %[[INPUT_INDEX_0:.*]] = xla_gpu.apply_indexing #[[MAP_1]](%[[THREAD_ID]]) + // CHECK: %[[INPUT_INDEX_1:.*]] = xla_gpu.apply_indexing #[[MAP_2]](%[[THREAD_ID]]) // CHECK: %[[I0:.*]] = xla_gpu.pure_call @fused_computation_i0 // CHECK: %[[I1:.*]] = xla_gpu.pure_call @fused_computation_i1 // CHECK: %[[IDX0:.*]] = arith.index_cast %[[I0]] @@ -151,8 +151,8 @@ TEST_F(MlirInPlaceDynamicUpdateSliceFusionTest, OutOfBoundDUS) { } )"; TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK-DAG: #[[MAP_1:.*]] = affine_map<(d0) -> (d0 floordiv 3)> - // CHECK-DAG: #[[MAP_2:.*]] = affine_map<(d0) -> (d0 mod 3)> + // CHECK-DAG: #[[MAP_1:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 floordiv 3), domain: d0 in [0, 5] + // CHECK-DAG: #[[MAP_2:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 mod 3), domain: d0 in [0, 5] // CHECK: func.func @fused_computation // CHECK-SAME: %arg0: tensor<7x8xf32> // CHECK-SAME: %arg1: tensor<2x3xf32> @@ -162,8 +162,8 @@ TEST_F(MlirInPlaceDynamicUpdateSliceFusionTest, OutOfBoundDUS) { // CHECK-DAG: %[[C_5:.*]] = arith.constant 5 // CHECK-DAG: %[[C_0:.*]] = arith.constant 0 // CHECK: %[[THREAD_ID:.*]] = gpu.thread_id x - // CHECK: %[[INPUT_INDEX_0:.*]] = xla_gpu.apply_indexing #[[MAP_1]](%[[THREAD_ID]] in [0, 5]) - // CHECK: %[[INPUT_INDEX_1:.*]] = xla_gpu.apply_indexing #[[MAP_2]](%[[THREAD_ID]] in [0, 5]) + // CHECK: %[[INPUT_INDEX_0:.*]] = xla_gpu.apply_indexing #[[MAP_1]](%[[THREAD_ID]]) + // CHECK: %[[INPUT_INDEX_1:.*]] = xla_gpu.apply_indexing #[[MAP_2]](%[[THREAD_ID]]) // CHECK: %[[I0:.*]] = xla_gpu.pure_call @fused_computation_i0 // CHECK: %[[I1:.*]] = xla_gpu.pure_call @fused_computation_i1 // CHECK: %[[IDX0:.*]] = arith.index_cast %[[I0]] diff --git a/xla/service/gpu/fusions/loop_mlir_test.cc b/xla/service/gpu/fusions/loop_mlir_test.cc index 357ef652985b43..93f374838898aa 100644 --- a/xla/service/gpu/fusions/loop_mlir_test.cc +++ b/xla/service/gpu/fusions/loop_mlir_test.cc @@ -196,10 +196,10 @@ TEST_F(MlirLoopFusionTest, Constant_Broadcast) { } )"; TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1 * 1024 + d0)> - // CHECK: #[[MAP1:.*]] = affine_map<(d0, d1) -> ((d1 * 1024 + d0) floordiv 768)> - // CHECK: #[[MAP2:.*]] = affine_map<(d0, d1) -> (((d1 * 1024 + d0) floordiv 48) mod 16)> - // CHECK: #[[MAP3:.*]] = affine_map<(d0, d1) -> ((d1 * 1024 + d0) mod 48)> + // CHECK-DAG: #[[MAP0:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d1 * 1024 + d0) + // CHECK-DAG: #[[MAP1:.*]] = #xla_gpu.indexing_map<(d0, d1) -> ((d1 * 1024 + d0) floordiv 768) + // CHECK-DAG: #[[MAP2:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (((d1 * 1024 + d0) floordiv 48) mod 16) + // CHECK-DAG: #[[MAP3:.*]] = #xla_gpu.indexing_map<(d0, d1) -> ((d1 * 1024 + d0) mod 48) // CHECK: func.func @fused_computation(%[[ARG0:.*]]: tensor<2x16x48xbf16> // CHECK: %[[UPPER_BOUND:.*]] = arith.constant 1535 : index // CHECK: %[[THREAD_ID:.*]] = gpu.thread_id diff --git a/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc b/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc index d7bbbb0bd34c3c..c4a92e6c33379c 100644 --- a/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc +++ b/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc @@ -235,10 +235,10 @@ TEST_F(ElementalHloToMlirTest, ReduceWindow) { // CHECK: %[[INIT:.*]] = tensor.extract %[[ARG1]][] // CHECK: %[[RET:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C7]] // CHECK-SAME: step %[[C1]] iter_args(%[[ACC:.*]] = %[[INIT]]) - // CHECK: %[[J0:.*]] = xla_gpu.apply_indexing affine_map<(d0) -> (d0 * 4)> - // CHECK-SAME: (%[[Y]] in [0, 2]) - // CHECK: %[[J1:.*]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> (d0 + s0 - 3)> - // CHECK-SAME: (%[[Z]] in [0, 7])[%[[I]] in [0, 6]] + // CHECK: %[[J0:.*]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0) -> (d0 * 4), domain: d0 in [0, 2]>(%[[Y]]) + // CHECK: %[[J1:.*]] = xla_gpu.apply_indexing + // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0 - 3), + // CHECK-SAME: d0 in [0, 7], s0 in [0, 6]>(%[[Z]])[%[[I]]] // CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]] // CHECK-SAME: [%[[X]], %[[J0]], %[[J1]]] // CHECK: %[[UPD:.*]] = func.call @add_sum(%[[ACC]], @@ -285,8 +285,8 @@ TEST_F(ElementalHloToMlirTest, ReduceWindowWithRescaling) { // If symbol rescaling wasn't working we would have a // `s0 floordiv ` in the map: // CHECK: %[[K:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: affine_map<(d0)[s0] -> (d0 * 2 + s0)> - // CHECK-SAME: (%[[X]] in [0, 18])[%[[I]] in [0, 3]] + // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0), + // CHECK-SAME: d0 in [0, 18], s0 in [0, 3]>(%[[X]])[%[[I]]] // CHECK: tensor.extract %[[ARG0]][%[[K]], %[[Y]], %[[Z]]] )")); @@ -433,7 +433,7 @@ TEST_F(ElementalHloToMlirTest, Pad) { // CHECK-DAG: %[[C4:.*]] = arith.constant 4 // CHECK-DAG: %[[C7:.*]] = arith.constant 7 // CHECK: %[[CONSTRAINT_VAL:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <(d0) -> ((d0 - 1) mod 2)>(%[[X]] in [1, 7]) + // CHECK-SAME: <(d0) -> ((d0 - 1) mod 2), domain: d0 in [1, 7]>(%[[X]]) // CHECK: %[[CONSTRAINT:.*]] = arith.cmpi eq, %[[CONSTRAINT_VAL]], %[[C0]] // CHECK-DAG: %[[X_L:.*]] = arith.cmpi sge, %[[X]], %[[C1]] // CHECK-DAG: %[[X_H:.*]] = arith.cmpi sle, %[[X]], %[[C7]] @@ -445,11 +445,9 @@ TEST_F(ElementalHloToMlirTest, Pad) { // CHECK: %[[FROM_INPUT:.*]] = arith.andi %[[X_AND_CONSTRAINT]], %[[Y_BOUNDS]] // CHECK: %[[RET:.*]] = scf.if %[[FROM_INPUT]] // CHECK: %[[IN0:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <(d0) -> ((d0 - 1) floordiv 2)> - // CHECK-SAME: (%[[X]] in [1, 7]) + // CHECK-SAME: <(d0) -> ((d0 - 1) floordiv 2), domain: d0 in [1, 7]>(%[[X]]) // CHECK: %[[IN1:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <(d0) -> (d0 - 4)> - // CHECK-SAME: (%[[Y]] in [4, 7]) + // CHECK-SAME: <(d0) -> (d0 - 4), domain: d0 in [4, 7]>(%[[Y]]) // CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]][%[[IN0]], %[[IN1]]] // CHECK: scf.yield %[[VAL]] // CHECK: } else { @@ -477,7 +475,7 @@ TEST_F(ElementalHloToMlirTest, PadUnsigned) { // CHECK-DAG: %[[C4:.*]] = arith.constant 4 // CHECK-DAG: %[[C7:.*]] = arith.constant 7 // CHECK: %[[CONSTRAINT_VAL:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <(d0) -> ((d0 - 1) mod 2)>(%[[X]] in [1, 7]) + // CHECK-SAME: <(d0) -> ((d0 - 1) mod 2), domain: d0 in [1, 7]>(%[[X]]) // CHECK: %[[CONSTRAINT:.*]] = arith.cmpi eq, %[[CONSTRAINT_VAL]], %[[C0]] // CHECK-DAG: %[[X_L:.*]] = arith.cmpi sge, %[[X]], %[[C1]] // CHECK-DAG: %[[X_H:.*]] = arith.cmpi sle, %[[X]], %[[C7]] @@ -489,11 +487,9 @@ TEST_F(ElementalHloToMlirTest, PadUnsigned) { // CHECK: %[[FROM_INPUT:.*]] = arith.andi %[[X_AND_CONSTRAINT]], %[[Y_BOUNDS]] // CHECK: %[[RET:.*]] = scf.if %[[FROM_INPUT]] // CHECK: %[[IN0:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <(d0) -> ((d0 - 1) floordiv 2)> - // CHECK-SAME: (%[[X]] in [1, 7]) + // CHECK-SAME: <(d0) -> ((d0 - 1) floordiv 2), domain: d0 in [1, 7]>(%[[X]]) // CHECK: %[[IN1:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <(d0) -> (d0 - 4)> - // CHECK-SAME: (%[[Y]] in [4, 7]) + // CHECK-SAME: <(d0) -> (d0 - 4), domain: d0 in [4, 7]>(%[[Y]]) // CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]][%[[IN0]], %[[IN1]]] // CHECK: scf.yield %[[VAL]] // CHECK: } else { @@ -810,11 +806,11 @@ TEST_F(ElementalHloToMlirTest, ConvolutionSimple) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0)> - // CHECK-SAME: (%[[W]] in [0, 5])[%[[X]] in [0, 2]] + // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0), + // CHECK-SAME: d0 in [0, 5], s0 in [0, 2]>(%[[W]])[%[[X]]] // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0)> - // CHECK-SAME: (%[[H]] in [0, 7])[%[[Y]] in [0, 4]] + // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0), + // CHECK-SAME: d0 in [0, 7], s0 in [0, 4]>(%[[H]])[%[[Y]]] // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -856,11 +852,11 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithWindowStrides) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: affine_map<(d0)[s0] -> (d0 * 2 + s0)> - // CHECK-SAME: (%[[W]] in [0, 2])[%[[X]] in [0, 2]] + // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0), + // CHECK-SAME: d0 in [0, 2], s0 in [0, 2]>(%[[W]])[%[[X]]] // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: affine_map<(d0)[s0] -> (d0 * 2 + s0)> - // CHECK-SAME: (%[[H]] in [0, 3])[%[[Y]] in [0, 4]] + // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0), + // CHECK-SAME: d0 in [0, 3], s0 in [0, 4]>(%[[H]])[%[[Y]]] // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -903,21 +899,21 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithPadding) { // CHECK: %[[R0:.+]] = scf.for %[[X:.+]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[A0:.+]] = %[[INIT]]) -> (f32) { // CHECK-NEXT: %[[R1:.+]] = scf.for %[[Y:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[A1:.+]] = %[[A0]]) -> (f32) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { - // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> (d0 + s0)>(%[[W]] in [0, 7])[%[[X]] in [0, 2]] + // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0), domain: d0 in [0, 7], s0 in [0, 2]>(%[[W]])[%[[X]]] // CHECK-DAG: %[[TXGE:.+]] = arith.cmpi sge, %[[TESTX]], %[[C1]] : index // CHECK-DAG: %[[TXLE:.+]] = arith.cmpi sle, %[[TESTX]], %[[C8]] : index // CHECK-DAG: %[[TX:.+]] = arith.andi %[[TXGE]], %[[TXLE]] : i1 - // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> (d0 + s0)>(%[[H]] in [0, 11])[%[[Y]] in [0, 4]] + // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0), domain: d0 in [0, 11], s0 in [0, 4]>(%[[H]])[%[[Y]]] // CHECK-DAG: %[[TYGE:.+]] = arith.cmpi sge, %[[TESTY]], %[[C2]] : index // CHECK-DAG: %[[TYLE:.+]] = arith.cmpi sle, %[[TESTY]], %[[C13]] : index // CHECK-DAG: %[[TY:.+]] = arith.andi %[[TYGE]], %[[TYLE]] : i1 // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0 - 1)> - // CHECK-SAME: (%[[W]] in [0, 7])[%[[X]] in [0, 2]] + // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0 - 1), + // CHECK-SAME: d0 in [0, 7], s0 in [0, 2]>(%[[W]])[%[[X]]] // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0 - 2)> - // CHECK-SAME: (%[[H]] in [0, 11])[%[[Y]] in [0, 4]] + // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0 - 2), + // CHECK-SAME: d0 in [0, 11], s0 in [0, 4]>(%[[H]])[%[[Y]]] // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -957,17 +953,17 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithLhsDilation) { // CHECK: %[[R0:.+]] = scf.for %[[X:.+]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[A0:.+]] = %[[INIT]]) -> (f32) { // CHECK-NEXT: %[[R1:.+]] = scf.for %[[Y:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[A1:.+]] = %[[A0]]) -> (f32) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { - // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> ((d0 + s0) mod 2)>(%[[W]] in [0, 12])[%[[X]] in [0, 2]] + // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0)[s0] -> ((d0 + s0) mod 2), domain: d0 in [0, 12], s0 in [0, 2]>(%[[W]])[%[[X]]] // CHECK-DAG: %[[TX:.+]] = arith.cmpi eq, %[[TESTX]], %[[C0]] : index - // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> ((d0 + s0) mod 2)>(%[[H]] in [0, 18])[%[[Y]] in [0, 4]] + // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0)[s0] -> ((d0 + s0) mod 2), domain: d0 in [0, 18], s0 in [0, 4]>(%[[H]])[%[[Y]]] // CHECK-DAG: %[[TY:.+]] = arith.cmpi eq, %[[TESTY]], %[[C0]] : index // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: affine_map<(d0)[s0] -> ((d0 + s0) floordiv 2)> - // CHECK-SAME: (%[[W]] in [0, 12])[%[[X]] in [0, 2]] + // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> ((d0 + s0) floordiv 2), + // CHECK-SAME: d0 in [0, 12], s0 in [0, 2]>(%[[W]])[%[[X]]] // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: affine_map<(d0)[s0] -> ((d0 + s0) floordiv 2)> - // CHECK-SAME: (%[[H]] in [0, 18])[%[[Y]] in [0, 4]] + // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> ((d0 + s0) floordiv 2), + // CHECK-SAME: d0 in [0, 18], s0 in [0, 4]>(%[[H]])[%[[Y]]] // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1009,11 +1005,11 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithRhsDilation) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0 * 2)> - // CHECK-SAME: (%[[W]] in [0, 3])[%[[X]] in [0, 2]] + // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0 * 2), + // CHECK-SAME: d0 in [0, 3], s0 in [0, 2]>(%[[W]])[%[[X]]] // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0 * 2)> - // CHECK-SAME: (%[[H]] in [0, 3])[%[[Y]] in [0, 4]] + // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0 * 2), + // CHECK-SAME: d0 in [0, 3], s0 in [0, 4]>(%[[H]])[%[[Y]]] // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1055,17 +1051,14 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithFeatureGroupCount) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0)> - // CHECK-SAME: (%[[W]] in [0, 5]) - // CHECK-SAME: [%[[X]] in [0, 2]] + // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0), + // CHECK-SAME: d0 in [0, 5], s0 in [0, 2]>(%[[W]])[%[[X]]] // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0)> - // CHECK-SAME: (%[[H]] in [0, 7]) - // CHECK-SAME: [%[[Y]] in [0, 4]] + // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0), + // CHECK-SAME: d0 in [0, 7], s0 in [0, 4]>(%[[H]])[%[[Y]]] // CHECK: %[[XX2:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: affine_map<(d0)[s0] -> ((d0 floordiv 8) * 2 + s0)> - // CHECK-SAME: (%[[O]] in [0, 15]) - // CHECK-SAME: [%[[I]] in [0, 1]] + // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> ((d0 floordiv 8) * 2 + s0), + // CHECK-SAME: d0 in [0, 15], s0 in [0, 1]>(%[[O]])[%[[I]]] // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[XX2]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<2x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1109,13 +1102,11 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithBatchGroupCount) { // CHECK-NEXT: %[[R3:.+]] = scf.for %[[G:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A2]]) -> (f32) { // CHECK: %[[R4:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0)> - // CHECK-SAME: (%[[W]] in [0, 5]) - // CHECK-SAME: [%[[X]] in [0, 2]] + // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0), + // CHECK-SAME: d0 in [0, 5], s0 in [0, 2]>(%[[W]])[%[[X]]] // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0)> - // CHECK-SAME: (%[[H]] in [0, 7]) - // CHECK-SAME: [%[[Y]] in [0, 4]] + // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0), + // CHECK-SAME: d0 in [0, 7], s0 in [0, 4]>(%[[H]])[%[[Y]]] // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[G]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1581,8 +1572,8 @@ TEST_F(ElementalHloToMlirTest, MixedIndexingTuple) { // CHECK-SAME: %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index {{{.*}}} // CHECK: %[[A:.*]] = tensor.extract %[[P0]][%[[X]], %[[Y]]] // CHECK: %[[IDX:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: affine_map<(d0, d1) -> (d0 * 10 + d1)> - // CHECK-SAME: (%[[X]] in [0, 9], %[[Y]] in [0, 9]) + // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d0 * 10 + d1), + // CHECK-SAME: d0 in [0, 9], d1 in [0, 9]>(%[[X]], %[[Y]]) // CHECK: %[[B:.*]] = tensor.extract %[[P1]][%[[IDX]]] // CHECK: return %[[A]], %[[B]] )")); @@ -1605,8 +1596,8 @@ TEST_F(ElementalHloToMlirTest, NestedTuple) { // CHECK-SAME: %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index {{{.*}}} // CHECK: %[[P0_V:.*]] = xla_gpu.pure_call @main_p0 // CHECK: %[[IDX:.*]] = - // CHECK-SAME: affine_map<(d0, d1) -> (d0 * 10 + d1)> - // CHECK-SAME: (%[[X]] in [0, 9], %[[Y]] in [0, 9]) + // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d0 * 10 + d1), + // CHECK-SAME: d0 in [0, 9], d1 in [0, 9]>(%[[X]], %[[Y]]) // CHECK: %[[P1_V:.*]] = xla_gpu.pure_call @main_p1 // CHECK-SAME: (%[[P0]], %[[P1]], %[[IDX]]) // CHECK: return %[[P0_V]], %[[P1_V]], %[[P1_V]], %[[P1_V]], %[[P0_V]] diff --git a/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc b/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc index dfa4d056a80bda..57c1da7be8558b 100644 --- a/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc +++ b/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc @@ -121,62 +121,54 @@ void ApplyIndexingOp::build(OpBuilder& builder, OperationState& result, } void ApplyIndexingOp::build(OpBuilder& builder, OperationState& result, - ValueRange operands, - const IndexingMap& indexing_map) { - build(builder, result, operands, indexing_map.GetAffineMap(), - indexing_map.GetDimVars(), indexing_map.GetRangeVars()); + ValueRange operands, IndexingMap indexing_map) { + SmallVector result_types(indexing_map.GetAffineMap().getNumResults(), + builder.getIndexType()); + // ApplyIndexingOp cannot have any constraints. It may be better to enforce + // callers to do this, but for now this follows the previous behavior. + indexing_map.ClearConstraints(); + IndexingMapAttr indexing_map_attr = + IndexingMapAttr::get(builder.getContext(), indexing_map); + build(builder, result, result_types, operands, indexing_map_attr); } void ApplyIndexingOp::build(OpBuilder& builder, OperationState& result, ValueRange operands, AffineMap affine_map, ArrayRef dim_vars, ArrayRef range_vars) { - SmallVector lower_bounds, upper_bounds; - for (const DimVar& dim_var : dim_vars) { - lower_bounds.push_back(dim_var.bounds.lower); - upper_bounds.push_back(dim_var.bounds.upper); - } - for (const RangeVar& range_var : range_vars) { - lower_bounds.push_back(range_var.range.lower); - upper_bounds.push_back(range_var.range.upper); - } - build(builder, result, operands, affine_map, lower_bounds, upper_bounds); + IndexingMap indexing_map(affine_map, dim_vars, range_vars, {}); + build(builder, result, operands, indexing_map); } void ApplyIndexingOp::build(OpBuilder& builder, OperationState& result, ValueRange operands, AffineMap affine_map, ArrayRef lower_bounds, ArrayRef upper_bounds) { - SmallVector result_types(affine_map.getNumResults(), - builder.getIndexType()); - build(builder, result, result_types, operands, affine_map, lower_bounds, - upper_bounds); + unsigned num_dimensions = affine_map.getNumDims(); + std::vector dim_vars; + dim_vars.reserve(num_dimensions); + for (unsigned id = 0; id < num_dimensions; ++id) { + dim_vars.push_back(DimVar{Interval{lower_bounds[id], upper_bounds[id]}}); + } + unsigned num_symbols = affine_map.getNumSymbols(); + std::vector range_vars; + range_vars.reserve(num_symbols); + for (unsigned id = num_dimensions; id < num_symbols + num_dimensions; ++id) { + range_vars.push_back( + RangeVar{Interval{lower_bounds[id], upper_bounds[id]}}); + } + IndexingMap indexing_map(affine_map, std::move(dim_vars), + std::move(range_vars), /*rt_vars=*/{}); + build(builder, result, operands, indexing_map); } -// Parser a comma-separated list of type %operand in [lower_bound, upper_bound]. -// Adds the parsed elements to the provided containers. -mlir::ParseResult parseOperandsWithBoundsList( +// Parses a comma-separated list of operands, ex: %d1, %d2. +mlir::ParseResult parseOperands( mlir::OpAsmParser& parser, - SmallVector* operands, - SmallVector* lower_bounds, - SmallVector* upper_bounds) { - int64_t lower_bound, upper_bound; + SmallVector* operands) { mlir::OpAsmParser::UnresolvedOperand operand; - if (parser.parseCommaSeparatedList([&]() { - if (parser.parseOperand(operand) || parser.parseKeyword("in") || - parser.parseLSquare() || parser.parseInteger(lower_bound) || - parser.parseComma() || parser.parseInteger(upper_bound) || - parser.parseRSquare()) { - return failure(); - } - operands->push_back(operand); - lower_bounds->push_back(lower_bound); - upper_bounds->push_back(upper_bound); - return success(); - })) { - return failure(); - } - return success(); + return parser.parseCommaSeparatedList( + [&]() { return parser.parseOperand(operands->emplace_back()); }); } mlir::ParseResult ApplyIndexingOp::parse(mlir::OpAsmParser& parser, @@ -184,24 +176,21 @@ mlir::ParseResult ApplyIndexingOp::parse(mlir::OpAsmParser& parser, mlir::Builder& builder = parser.getBuilder(); auto index_type = builder.getIndexType(); - mlir::AffineMapAttr affine_map_attr; - if (parser.parseAttribute(affine_map_attr, "map", result.attributes)) { + IndexingMapAttr indexing_map_attr; + if (parser.parseAttribute(indexing_map_attr, "indexing_map_attr", + result.attributes)) { return failure(); } SmallVector operands; SmallVector lower_bounds, upper_bounds; if (succeeded(parser.parseOptionalLParen())) { - if (parseOperandsWithBoundsList(parser, &operands, &lower_bounds, - &upper_bounds) || - parser.parseRParen()) { + if (parseOperands(parser, &operands) || parser.parseRParen()) { return failure(); } } if (succeeded(parser.parseOptionalLSquare())) { - if (parseOperandsWithBoundsList(parser, &operands, &lower_bounds, - &upper_bounds) || - parser.parseRSquare()) { + if (parseOperands(parser, &operands) || parser.parseRSquare()) { return failure(); } } @@ -209,85 +198,78 @@ mlir::ParseResult ApplyIndexingOp::parse(mlir::OpAsmParser& parser, parser.parseOptionalAttrDict(result.attributes)) { return failure(); } - result.addAttribute("lower_bounds", - builder.getDenseI64ArrayAttr(lower_bounds)); - result.addAttribute("upper_bounds", - builder.getDenseI64ArrayAttr(upper_bounds)); - - auto map = affine_map_attr.getAffineMap(); + auto map = indexing_map_attr.getMap(); result.addTypes(SmallVector(map.getNumResults(), index_type)); return success(); } void ApplyIndexingOp::print(mlir::OpAsmPrinter& p) { - mlir::AffineMapAttr affine_map_attr = getMapAttr(); - AffineMap affine_map = affine_map_attr.getAffineMap(); - p << " " << affine_map_attr; + AffineMap affine_map = getIndexingMapAttr().getMap(); + p << " " << getIndexingMapAttr(); - auto lower_bounds = getLowerBounds(); - auto upper_bounds = getUpperBounds(); auto operands = getOperands(); unsigned num_dimensions = affine_map.getNumDims(); if (num_dimensions > 0) { p << '('; - for (int dim_id = 0; dim_id < num_dimensions; ++dim_id) { - p << operands[dim_id] << " in " << '[' << lower_bounds[dim_id] << ", " - << upper_bounds[dim_id] << ']'; - if (dim_id != num_dimensions - 1) { - p << ", "; - } - } + auto dimension_operands = operands.slice(0, num_dimensions); + llvm::interleaveComma(dimension_operands, p); p << ')'; } + unsigned num_symbols = affine_map.getNumSymbols(); if (num_symbols > 0) { p << '['; - for (int symbol_id = 0; symbol_id < num_symbols; ++symbol_id) { - unsigned operand_id = num_dimensions + symbol_id; - p << operands[operand_id] << " in " << '[' << lower_bounds[operand_id] - << ", " << upper_bounds[operand_id] << ']'; - if (symbol_id != num_symbols - 1) { - p << ", "; - } - } + auto symbol_operands = operands.slice(num_dimensions, num_symbols); + llvm::interleaveComma(symbol_operands, p); p << ']'; } - p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{ - "map", "lower_bounds", "upper_bounds"}); + + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"indexing_map_attr"}); } LogicalResult ApplyIndexingOp::verify() { - auto affine_map = getMapAttr().getAffineMap(); + auto affine_map = getIndexingMapAttr().getMap(); unsigned num_variables = affine_map.getNumDims() + affine_map.getNumSymbols(); - if (getOperands().size() != num_variables || - getLowerBounds().size() != num_variables || - getUpperBounds().size() != num_variables) { + if (getOperands().size() != num_variables) { return emitOpError( - "operand, lower_bounds, upper_bounds count and affine map dimension " - "and symbol count must match"); + "operand count must match the number of dimensions and symbols in the " + "affine map"); + } + if (!getIndexingMapAttr().getConstraints().empty()) { + return emitOpError("apply indexing op cannot have any constraints"); } return success(); } -IndexingMap ApplyIndexingOp::getIndexingMap() { - auto lower_bounds = getLowerBounds(); - auto upper_bounds = getUpperBounds(); +llvm::SmallVector ApplyIndexingOp::getLowerBounds() { + SmallVector lower_bounds; + lower_bounds.reserve(getNumOperands()); + for (const auto& dim_var : getIndexingMapAttr().getDimVars()) { + lower_bounds.push_back(dim_var.bounds.lower); + } + for (const auto& range_var : getIndexingMapAttr().getRangeVars()) { + lower_bounds.push_back(range_var.range.lower); + } + return lower_bounds; +} - AffineMap affine_map = getAffineMap(); - unsigned num_dimensions = affine_map.getNumDims(); - std::vector dim_vars; - dim_vars.reserve(num_dimensions); - for (unsigned id = 0; id < num_dimensions; ++id) { - dim_vars.push_back(DimVar{Interval{lower_bounds[id], upper_bounds[id]}}); +llvm::SmallVector ApplyIndexingOp::getUpperBounds() { + SmallVector upper_bounds; + upper_bounds.reserve(getNumOperands()); + for (const auto& dim_var : getIndexingMapAttr().getDimVars()) { + upper_bounds.push_back(dim_var.bounds.upper); } - unsigned num_symbols = affine_map.getNumSymbols(); - std::vector range_vars; - range_vars.reserve(num_symbols); - for (unsigned id = num_dimensions; id < num_symbols + num_dimensions; ++id) { - range_vars.push_back( - RangeVar{Interval{lower_bounds[id], upper_bounds[id]}}); + for (const auto& range_var : getIndexingMapAttr().getRangeVars()) { + upper_bounds.push_back(range_var.range.upper); } - return IndexingMap(affine_map, std::move(dim_vars), std::move(range_vars), + return upper_bounds; +} + +IndexingMap ApplyIndexingOp::getIndexingMap() { + return IndexingMap(getIndexingMapAttr().getMap(), + getIndexingMapAttr().getDimVars(), + getIndexingMapAttr().getRangeVars(), /*rt_vars=*/{}); } diff --git a/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h b/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h index f43786f4fde0ac..6462697a80b9ed 100644 --- a/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h +++ b/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h @@ -29,11 +29,11 @@ limitations under the License. #include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h" // IWYU pragma: keep #include "xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.h.inc" -#define GET_OP_CLASSES -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h.inc" -#undef GET_OP_CLASSES #define GET_ATTRDEF_CLASSES #include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h.inc" #undef GET_ATTRDEF_CLASSES +#define GET_OP_CLASSES +#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h.inc" +#undef GET_OP_CLASSES #endif // XLA_SERVICE_GPU_FUSIONS_MLIR_IR_XLA_GPU_OPS_H_ diff --git a/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.td b/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.td index c05f843465427c..ed70f2ced60fcc 100644 --- a/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.td +++ b/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.td @@ -23,6 +23,7 @@ include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.td" +include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.td" class XLAGPU_Op traits = []> : Op { @@ -242,16 +243,14 @@ def ApplyIndexingOp : XLAGPU_Op<"apply_indexing", [Pure]> { ``` }]; let arguments = (ins Variadic:$operands, - AffineMapAttr:$map, - DenseI64ArrayAttr:$lower_bounds, - DenseI64ArrayAttr:$upper_bounds); + XLAGPU_IndexingMapAttr:$indexing_map_attr); let results = (outs Variadic); let builders = [ OpBuilder<(ins "mlir::ValueRange":$dims, "mlir::ValueRange":$symbols, "const IndexingMap&":$indexing_map)>, OpBuilder<(ins "mlir::ValueRange":$operands, - "const IndexingMap&":$indexing_map)>, + "IndexingMap":$indexing_map)>, OpBuilder<(ins "mlir::ValueRange":$operands, "mlir::AffineMap":$affine_map, "llvm::ArrayRef":$dim_vars, "llvm::ArrayRef":$range_vars)>, @@ -263,8 +262,10 @@ def ApplyIndexingOp : XLAGPU_Op<"apply_indexing", [Pure]> { let extraClassDeclaration = [{ // Returns the indexing map constructed from affine_map and the bounds. xla::gpu::IndexingMap getIndexingMap(); + llvm::SmallVector getLowerBounds(); + llvm::SmallVector getUpperBounds(); // Extracts the affine map from the attribute. - mlir::AffineMap getAffineMap() { return getMapAttr().getAffineMap(); } + mlir::AffineMap getAffineMap() { return getIndexingMapAttr().getMap(); } }]; let hasCustomAssemblyFormat = 1; let hasVerifier = 1; diff --git a/xla/service/gpu/fusions/mlir/tests/canonicalize.mlir b/xla/service/gpu/fusions/mlir/tests/canonicalize.mlir index 17b0f8d9b45c88..34065f9c19d53a 100644 --- a/xla/service/gpu/fusions/mlir/tests/canonicalize.mlir +++ b/xla/service/gpu/fusions/mlir/tests/canonicalize.mlir @@ -1,27 +1,29 @@ // RUN: mlir_fusions_opt %s --split-input-file -canonicalize | FileCheck %s -#map0 = affine_map<()[s0, s1] -> (1 + s0 + s1 mod 3 - s1, s0 mod 2)> +#map0 = #xla_gpu.indexing_map<()[s0, s1] -> (1 + s0 + s1 mod 3 - s1, s0 mod 2), + domain: s0 in [-10, 10], s1 in [0, 2]> func.func @simplify_apply_indexing(%s0: index, %s1: index) -> (index, index) { - %0:2 = xla_gpu.apply_indexing #map0 [%s0 in [-10, 10], %s1 in [0, 2]] + %0:2 = xla_gpu.apply_indexing #map0 [%s0, %s1] func.return %0#0, %0#1 : index, index } -// CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 + 1, s0 mod 2)> +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<()[s0] -> (s0 + 1, s0 mod 2), +// CHECK-SAME: domain: s0 in [-10, 10]> // CHECK-LABEL: func.func @simplify_apply_indexing // CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -// CHECK: xla_gpu.apply_indexing #[[$MAP]][%[[ARG_0]] in [-10, 10]] +// CHECK: xla_gpu.apply_indexing #[[$MAP]][%[[ARG_0]]] // ----- -#map0 = affine_map<(d0, d1, d2)[s0, s1] -> (1 + s0 + s1 mod 4 - s1, s0 mod 2, d0 + d2)> +#map0 = #xla_gpu.indexing_map<(d0, d1, d2)[s0, s1] -> (1 + s0 + s1 mod 4 - s1, s0 mod 2, d0 + d2), + domain: d0 in [0, 1], d1 in [0, 2], d2 in [0, 3], s0 in [-11, 11], s1 in [0, 3]> func.func @simplify_apply_indexing_remove_dims(%d0: index, %d1: index, %d2: index, %s0: index, %s1: index) -> (index, index, index) { - %0:3 = xla_gpu.apply_indexing #map0 - (%d0 in [0, 1], %d1 in [0, 2], %d2 in [0, 3]) - [%s0 in [-11, 11], %s1 in [0, 3]] + %0:3 = xla_gpu.apply_indexing #map0(%d0, %d1, %d2)[%s0, %s1] func.return %0#0, %0#1, %0#2 : index, index, index } -// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1)[s0] -> (s0 + 1, s0 mod 2, d0 + d1)> +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0] -> (s0 + 1, s0 mod 2, d0 + d1), +// CHECK-SAME: domain: d0 in [0, 1], d1 in [0, 3], s0 in [-11, 11]> // CHECK-LABEL: func.func @simplify_apply_indexing_remove_dims // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: index, @@ -30,18 +32,19 @@ func.func @simplify_apply_indexing_remove_dims(%d0: index, %d1: index, // CHECK-SAME: %[[ARG_3:[a-zA-Z0-9_]+]]: index, // CHECK-SAME: %[[ARG_4:[a-zA-Z0-9_]+]]: index) // CHECK: xla_gpu.apply_indexing #[[$MAP]] -// CHECK-SAME: (%[[ARG_0]] in [0, 1], %[[ARG_2]] in [0, 3]) -// CHECK-SAME: [%[[ARG_3]] in [-11, 11]] +// CHECK-SAME: (%[[ARG_0]], %[[ARG_2]]) +// CHECK-SAME: [%[[ARG_3]]] // ----- -#map0 = affine_map<(d0, d1)[s0] -> (d0 + s0, 4, d1, 1, s0)> +#map0 = #xla_gpu.indexing_map<(d0, d1)[s0] -> (d0 + s0, 4, d1, 1, s0), + domain: d0 in [-10, 10], d1 in [0, 2], s0 in [-1, 1]> func.func @fold_indexing_map_results(%d0: index, %d1: index, %s0: index) -> (index, index, index, index, index) { - %0:5 = xla_gpu.apply_indexing #map0 (%d0 in [-10, 10], %d1 in [0, 2])[%s0 in [-1, 1]] + %0:5 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0] func.return %0#0, %0#1, %0#2, %0#3, %0#4 : index, index, index, index, index } -// CHECK: #[[$MAP:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0), // CHECK-LABEL: func.func @fold_indexing_map_results // CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) @@ -54,42 +57,45 @@ func.func @fold_indexing_map_results(%d0: index, %d1: index, %s0: index) // ----- -#map0 = affine_map<(d0, d1)[s0] -> (d0 + s0, s0 + 4, d1 mod 2, 1 + d1, s0)> +#map0 = #xla_gpu.indexing_map<(d0, d1)[s0] -> (d0 + s0, s0 + 4, d1 mod 2, 1 + d1, s0), + domain: d0 in [-10, 10], d1 in [0, 2], s0 in [-1, 1]> func.func @remove_unused_results(%d0: index, %d1: index, %s0: index) -> (index) { - %0:5 = xla_gpu.apply_indexing #map0 (%d0 in [-10, 10], %d1 in [0, 2])[%s0 in [-1, 1]] + %0:5 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0] func.return %0#2 : index } -// CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0 mod 2)> +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 mod 2), +// CHECK-SAME: domain: d0 in [0, 2]> // CHECK-LABEL: func.func @remove_unused_results // CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) -// CHECK: %[[NEW_RESULT:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[ARG_1]] in [0, 2]) +// CHECK: %[[NEW_RESULT:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[ARG_1]]) // CHECK: return %[[NEW_RESULT]] // ----- -#map0 = affine_map<(d0, d1)[s0, s1] -> (d0 + d1 + s0 + s1 mod 3)> +#map0 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + d1 + s0 + s1 mod 3), + domain: d0 in [0, 10], d1 in [0, 5], s0 in [-10, 10], s1 in [0, 4]> func.func @fold_operands(%d0: index) -> index { %d1 = arith.constant 1 : index %s0 = arith.constant 2 : index %s1 = arith.constant 3 : index - %0 = xla_gpu.apply_indexing #map0 (%d0 in [0, 10], %d1 in [0, 5]) - [%s0 in [-10, 10], %s1 in [0, 4]] + %0 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0, %s1] func.return %0 : index } -// CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0 + 3)> +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 + 3), +// CHECK-SAME: domain: d0 in [0, 10]> // CHECK-LABEL: func.func @fold_operands // CHECK-SAME: %[[ARG_0:.*]]: index) -// CHECK: xla_gpu.apply_indexing #[[$MAP]](%[[ARG_0]] in [0, 10]) +// CHECK: xla_gpu.apply_indexing #[[$MAP]](%[[ARG_0]]) // ----- func.func @fold_operands_and_results(%arg0: index, %arg1: index) -> (index, index) { - %0:2 = xla_gpu.apply_indexing affine_map<(d0, d1) -> (0, d1)> - (%arg0 in [0, 4], %arg1 in [0, 5]) + %0:2 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (0, d1), + domain: d0 in [0, 4], d1 in [0, 5]>(%arg0, %arg1) return %0#0, %0#1 : index, index } @@ -101,50 +107,53 @@ func.func @fold_operands_and_results(%arg0: index, %arg1: index) // ----- func.func @fold_sequence(%arg0: index, %arg1: index) -> index { - %0 = xla_gpu.apply_indexing affine_map<(d0, d1) -> (d0 + d1)> - (%arg0 in [0, 5], %arg1 in [0, 4]) - %1 = xla_gpu.apply_indexing affine_map<(d0) -> (d0 mod 100 + 42)> - (%0 in [0, 10000]) + %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), + domain: d0 in [0, 5], d1 in [0, 4]>(%arg0, %arg1) + %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0) -> (d0 mod 100 + 42), + domain: d0 in [0, 10000]>(%0) func.return %1 : index } -// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 + d1 + 42)> +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1 + 42), +// CHECK-SAME: domain: d0 in [0, 5], d1 in [0, 4]> // CHECK-LABEL: func.func @fold_sequence // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index) // CHECK-NEXT: xla_gpu.apply_indexing #[[$MAP]] -// CHECK-SAME: (%[[ARG0]] in [0, 5], %[[ARG1]] in [0, 4]) +// CHECK-SAME: (%[[ARG0]], %[[ARG1]]) // ----- func.func @fold_sequence_sym(%arg0: index, %arg1: index) -> index { - %0 = xla_gpu.apply_indexing affine_map<(d0, d1) -> (d0 + d1)> - (%arg0 in [0, 5], %arg1 in [0, 4]) - %1 = xla_gpu.apply_indexing affine_map<()[s0] -> (s0 mod 100 + 42)> - [%0 in [0, 10000]] + %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), + domain: d0 in [0, 5], d1 in [0, 4]>(%arg0, %arg1) + %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map<()[s0] -> (s0 mod 100 + 42), + domain: s0 in [0, 10000]>(%0) func.return %1 : index } -// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 + d1 + 42)> +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1 + 42), +// CHECK-SAME: domain: d0 in [0, 5], d1 in [0, 4]> // CHECK-LABEL: func.func @fold_sequence_sym // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index) // CHECK-NEXT: xla_gpu.apply_indexing #[[$MAP]] -// CHECK-SAME: (%[[ARG0]] in [0, 5], %[[ARG1]] in [0, 4]) +// CHECK-SAME: (%[[ARG0]], %[[ARG1]]) // ----- func.func @fold_sequence_shared_operands(%arg0: index, %arg1: index) -> index { - %0 = xla_gpu.apply_indexing affine_map<(d0, d1) -> (d0 + d1)> - (%arg0 in [0, 5], %arg1 in [0, 4]) - %1 = xla_gpu.apply_indexing affine_map<(d0, d1) -> (d0 + d1)> - (%arg1 in [0, 4], %0 in [0, 10000]) + %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), + domain: d0 in [0, 5], d1 in [0, 4]>(%arg0, %arg1) + %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), + domain: d0 in [0, 4], d1 in [0, 10000]>(%arg1, %0) func.return %1 : index } -// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)> +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 * 2 + d1), +// CHECK-SAME: domain: d0 in [0, 4], d1 in [0, 5]> // CHECK-LABEL: func.func @fold_sequence_shared_operands // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index) // CHECK-NEXT: xla_gpu.apply_indexing #[[$MAP]] -// CHECK-SAME: (%[[ARG1]] in [0, 4], %[[ARG0]] in [0, 5]) +// CHECK-SAME: (%[[ARG1]], %[[ARG0]]) // ----- diff --git a/xla/service/gpu/fusions/mlir/tests/flatten_tensors.mlir b/xla/service/gpu/fusions/mlir/tests/flatten_tensors.mlir index ee2c2ae9e9553d..21a8dc2a0b7e79 100644 --- a/xla/service/gpu/fusions/mlir/tests/flatten_tensors.mlir +++ b/xla/service/gpu/fusions/mlir/tests/flatten_tensors.mlir @@ -8,13 +8,13 @@ func.func @tensor_extract( : tensor<2x3xf32, dense<[0, 1]> : tensor<2xi64>> func.return %v : f32 } -// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d1 * 2 + d0)> +// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<(d0, d1) -> (d1 * 2 + d0), +// CHECK-SAME: domain: d0 in [0, 1], d1 in [0, 2]> // CHECK-LABEL: func.func @tensor_extract( // CHECK-SAME: %[[SRC:.*]]: tensor<6xf32>, // CHECK-SAME: %[[I:.*]]: index, %[[J:.*]]: index) -> f32 { -// CHECK: %[[INDEX:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]] -// CHECK-SAME: in [0, 1], %[[J]] in [0, 2]) +// CHECK: %[[INDEX:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]], %[[J]]) // CHECK: tensor.extract %[[SRC]][%[[INDEX]]] : tensor<6xf32> // ----- @@ -47,13 +47,14 @@ func.func @atomic_rmw(%in: tensor<2x4xf32>, %i: index, %j: index) } return %ret : tensor<2x4xf32> } -// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d0 * 4 + d1)> +// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 * 4 + d1), +// CHECK-SAME: domain: d0 in [0, 1], d1 in [0, 3]> // CHECK-LABEL: func.func @atomic_rmw( // CHECK-SAME: %[[TENSOR:.*]]: tensor<8xf32>, %[[I:.*]]: index, // CHECK-SAME: %[[J:.*]]: index) -> tensor<8xf32> { // CHECK: %[[INDEX:.*]] = xla_gpu.apply_indexing #[[$MAP]] -// CHECK-SAME: (%[[I]] in [0, 1], %[[J]] in [0, 3]) +// CHECK-SAME: (%[[I]], %[[J]]) // CHECK: xla_gpu.atomic_rmw %[[TENSOR]][%[[INDEX]]] : tensor<8xf32> // ----- @@ -74,8 +75,10 @@ func.func @for_loop(%t0: tensor<32x1024xf32>, %t1: tensor<64x8x4xf32>) return %for#0, %for#1, %c0_f32 : tensor<32x1024xf32>, tensor<64x8x4xf32>, f32 } -// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 + 1024)> -// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0 * 32 + 5)> +// CHECK: #[[$MAP0:.+]] = #xla_gpu.indexing_map<(d0) -> (d0 + 1024), +// CHECK-SAME: domain: d0 in [0, 1023]> +// CHECK: #[[$MAP1:.+]] = #xla_gpu.indexing_map<(d0) -> (d0 * 32 + 5), +// CHECK-SAME: domain: d0 in [0, 63]> // CHECK-LABEL: func.func @for_loop( // CHECK-SAME: %[[T0:.*]]: tensor<32768xf32>, // CHECK-SAME: %[[T1:.*]]: tensor<2048xf32>) -> (tensor<32768xf32>, tensor<2048xf32>, f32) { @@ -87,17 +90,20 @@ func.func @for_loop(%t0: tensor<32x1024xf32>, %t1: tensor<64x8x4xf32>) // CHECK: %[[FOR:.*]]:2 = scf.for %[[I:.*]] = %[[C0]] to %[[C64]] // CHECK-SAME: step %[[C32]] // CHECK-SAME: iter_args(%[[T0_:.*]] = %[[T0]], %[[T1_:.*]] = %[[T1]]) -// CHECK: %[[IND0:.*]] = xla_gpu.apply_indexing #[[$MAP0]](%[[I]] in [0, 1023]) +// CHECK: %[[IND0:.*]] = xla_gpu.apply_indexing #[[$MAP0]](%[[I]]) // CHECK: %[[UPD0:.*]] = tensor.insert %[[F32]] into %[[T0_]][%[[IND0]]] -// CHECK: %[[IND1:.*]] = xla_gpu.apply_indexing #[[$MAP1]](%[[I]] in [0, 63]) +// CHECK: %[[IND1:.*]] = xla_gpu.apply_indexing #[[$MAP1]](%[[I]]) // CHECK: %[[UPD1:.*]] = tensor.insert %[[F32]] into %[[T1_]][%[[IND1]]] // CHECK: scf.yield %[[UPD0]], %[[UPD1]] : tensor<32768xf32>, tensor<2048xf32> // ----- -#map = affine_map<(d0, d1) -> ((d1 * 128 + d0) floordiv 36)> -#map1 = affine_map<(d0, d1) -> (((d1 * 128 + d0) floordiv 9) mod 4)> -#map2 = affine_map<(d0, d1) -> ((d1 * 128 + d0) mod 9)> +#map = #xla_gpu.indexing_map<(d0, d1) -> ((d1 * 128 + d0) floordiv 36), + domain: d0 in [0, 127], d1 in [0, 393749]> +#map1 = #xla_gpu.indexing_map<(d0, d1) -> (((d1 * 128 + d0) floordiv 9) mod 4), + domain: d0 in [0, 127], d1 in [0, 393749]> +#map2 = #xla_gpu.indexing_map<(d0, d1) -> ((d1 * 128 + d0) mod 9), + domain: d0 in [0, 127], d1 in [0, 393749]> func.func @if_op(%arg0: tensor<4000x4x9xf32>, %arg1: tensor<1400x1xi32>, %arg2: tensor<1400x1x4x9xf32>, %arg3: tensor<4000x4x9xf32>) -> tensor<4000x4x9xf32> { @@ -105,13 +111,13 @@ func.func @if_op(%arg0: tensor<4000x4x9xf32>, %arg1: tensor<1400x1xi32>, %c3999 = arith.constant 3999 : index %th_x = gpu.thread_id x {xla.range = [0 : index, 127 : index]} %bl_x = gpu.block_id x {xla.range = [0 : index, 393749 : index]} - %0 = xla_gpu.apply_indexing #map(%th_x in [0, 127], %bl_x in [0, 393749]) + %0 = xla_gpu.apply_indexing #map(%th_x, %bl_x) %extracted = tensor.extract %arg1[%0, %c0] : tensor<1400x1xi32> %1 = arith.index_cast %extracted : i32 to index %2 = arith.cmpi ule, %1, %c3999 : index %3 = scf.if %2 -> (tensor<4000x4x9xf32>) { - %4 = xla_gpu.apply_indexing #map1(%th_x in [0, 127], %bl_x in [0, 393749]) - %5 = xla_gpu.apply_indexing #map2(%th_x in [0, 127], %bl_x in [0, 393749]) + %4 = xla_gpu.apply_indexing #map1(%th_x, %bl_x) + %5 = xla_gpu.apply_indexing #map2(%th_x, %bl_x) %elem = tensor.extract %arg2[%0, %c0, %4, %5] : tensor<1400x1x4x9xf32> %atomic_rmw = xla_gpu.atomic_rmw %arg3[%1, %4, %5] : tensor<4000x4x9xf32> { ^bb0(%arg4: f32): diff --git a/xla/service/gpu/fusions/mlir/tests/invalid.mlir b/xla/service/gpu/fusions/mlir/tests/invalid.mlir index fbef7c049db487..11b54c81214940 100644 --- a/xla/service/gpu/fusions/mlir/tests/invalid.mlir +++ b/xla/service/gpu/fusions/mlir/tests/invalid.mlir @@ -1,8 +1,31 @@ // RUN: mlir_fusions_opt %s -split-input-file -verify-diagnostics -#map0 = affine_map<(d0, d1)[s0] -> (d0, d1 + s0)> +#map0 = #xla_gpu.indexing_map< + (d0, d1)[s0] -> (d0, d1 + s0), + domain: + d0 in [1, 2], + d1 in [5, 8], + s0 in [0, 32] +> func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) { - // expected-error @+1 {{operand, lower_bounds, upper_bounds count and affine map dimension and symbol count must match}} - %0:2 = xla_gpu.apply_indexing #map0 (%d0 in [0, 2]) + // expected-error @+1 {{operand count must match the number of dimensions and symbols in the affine map}} + %0:2 = xla_gpu.apply_indexing #map0 (%d0) + func.return %0#0, %0#1 : index, index +} + +// ----- + +#map0 = #xla_gpu.indexing_map< + (d0, d1)[s0] -> (d0, d1 + s0), + domain: + d0 in [1, 2], + d1 in [5, 8], + s0 in [0, 32], + d0 mod 2 in [0, 1], + d0 + s0 in [1, 10] +> +func.func @cannot_have_constraints(%d0: index, %d1: index, %s0: index) -> (index, index) { + // expected-error @+1 {{apply indexing op cannot have any constraints}} + %0:2 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0] func.return %0#0, %0#1 : index, index } diff --git a/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir b/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir index 2125e6f4d70c8f..be8eb1eef94f8e 100644 --- a/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir +++ b/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir @@ -90,12 +90,12 @@ module { } } -// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d1 * 2 + d0)> +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d1 * 2 + d0), +// CHECK-SAME: domain: d0 in [0, 1], d1 in [0, 2]> // CHECK-LABEL: @layout( // CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr, // CHECK-SAME: %[[X:.*]]: index, %[[Y:.*]]: index -// CHECK: %[[IDX:.*]] = xla_gpu.apply_indexing #[[$MAP]] -// CHECK-SAME: (%[[X]] in [0, 1], %[[Y]] in [0, 2]) +// CHECK: %[[IDX:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[X]], %[[Y]]) // CHECK: %[[IDX_CAST:.*]] = arith.index_castui %[[IDX]] : index to i64 // CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %[[ARG0]][%[[IDX_CAST]]] // CHECK: llvm.load %[[PTR]] diff --git a/xla/service/gpu/fusions/mlir/tests/ops.mlir b/xla/service/gpu/fusions/mlir/tests/ops.mlir index c7f15073b5e0ed..459c588dafe1d9 100644 --- a/xla/service/gpu/fusions/mlir/tests/ops.mlir +++ b/xla/service/gpu/fusions/mlir/tests/ops.mlir @@ -1,6 +1,6 @@ -// R-UN: mlir_fusions_opt %s --split-input-file | FileCheck %s +// RUN: mlir_fusions_opt %s --split-input-file | FileCheck %s // Verify the printed output can be parsed. -// RU-N: mlir_fusions_opt %s --split-input-file | mlir_fusions_opt --split-input-file | FileCheck %s +// RUN: mlir_fusions_opt %s --split-input-file | mlir_fusions_opt --split-input-file | FileCheck %s // Verify the generic form can be parsed. // RUN: mlir_fusions_opt %s --split-input-file --mlir-print-op-generic | mlir_fusions_opt --split-input-file | FileCheck %s @@ -56,41 +56,71 @@ func.func @caller(%a: f32, %b: f32) -> f32 { // ----- -#map0 = affine_map<(d0, d1)[s0] -> (d0, d1 + s0)> +#map0 = #xla_gpu.indexing_map< +(d0, d1)[s0] -> (d0, d1 + s0), + domain: + d0 in [1, 2], + d1 in [5, 8], + s0 in [0, 32] +> func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) { - %0:2 = xla_gpu.apply_indexing #map0 (%d0 in [0, 2], %d1 in [1, 3])[%s0 in [2, 4]] + %0:2 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0] func.return %0#0, %0#1 : index, index } -// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0] -> (d0, d1 + s0)> +// CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map< +// CHECK-SAME: (d0, d1)[s0] -> (d0, d1 + s0) +// CHECK-SAME: domain: +// CHECK-SAME: d0 in [1, 2] +// CHECK-SAME: d1 in [5, 8] +// CHECK-SAME: s0 in [0, 32] +// CHECK-SAME: > // CHECK-LABEL: @apply_indexing // CHECK: (%[[d0:.*]]: index, %[[d1:.*]]: index, %[[s0:.*]]: index) // CHECK: xla_gpu.apply_indexing #[[$MAP0]] -// CHECK-SAME: (%[[d0]] in [0, 2], %[[d1]] in [1, 3])[%[[s0]] in [2, 4]] +// CHECK-SAME: (%[[d0]], %[[d1]])[%[[s0]]] // ----- -#map0 = affine_map<(d0, d1) -> (d0, d1)> +#map0 = #xla_gpu.indexing_map< +(d0, d1) -> (d0, d1), + domain: + d0 in [0, 2], + d1 in [1, 3] +> func.func @apply_indexing_no_symbols(%d0: index, %d1: index) -> (index, index) { - %0:2 = xla_gpu.apply_indexing #map0 (%d0 in [0, 2], %d1 in [1, 3]) + %0:2 = xla_gpu.apply_indexing #map0 (%d0, %d1) func.return %0#0, %0#1 : index, index } -// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map< +// CHECK-SAME: (d0, d1) -> (d0, d1) +// CHECK-SAME: domain: +// CHECK-SAME: d0 in [0, 2] +// CHECK-SAME: d1 in [1, 3] +// CHECK-SAME: > // CHECK-LABEL: @apply_indexing_no_symbols // CHECK: (%[[d0:.*]]: index, %[[d1:.*]]: index) // CHECK: xla_gpu.apply_indexing #[[$MAP0]] -// CHECK-SAME: (%[[d0]] in [0, 2], %[[d1]] in [1, 3]) +// CHECK-SAME: (%[[d0]], %[[d1]]) // ----- -#map0 = affine_map<()[s0] -> (s0, s0)> +#map0 = #xla_gpu.indexing_map< + ()[s0] -> (s0, s0), + domain: + s0 in [2, 4] +> func.func @apply_indexing_no_dims(%s0: index) -> (index, index) { - %0:2 = xla_gpu.apply_indexing #map0 [%s0 in [2, 4]] + %0:2 = xla_gpu.apply_indexing #map0 [%s0] func.return %0#0, %0#1 : index, index } -// CHECK: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0, s0)> +// CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map< +// CHECK-SAME: ()[s0] -> (s0, s0) +// CHECK-SAME: domain: +// CHECK-SAME: s0 in [2, 4] +// CHECK-SAME: > // CHECK-LABEL: @apply_indexing_no_dims // CHECK: (%[[s0:.*]]: index) -// CHECK: xla_gpu.apply_indexing #[[$MAP0]][%[[s0]] in [2, 4]] +// CHECK: xla_gpu.apply_indexing #[[$MAP0]][%[[s0]]] diff --git a/xla/service/gpu/fusions/mlir/tests/optimize_loops.mlir b/xla/service/gpu/fusions/mlir/tests/optimize_loops.mlir index 6f903f3ace4748..3b33df2d3eb7a1 100644 --- a/xla/service/gpu/fusions/mlir/tests/optimize_loops.mlir +++ b/xla/service/gpu/fusions/mlir/tests/optimize_loops.mlir @@ -1,8 +1,11 @@ // RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-optimize-loops | FileCheck %s -#map = affine_map<(d0) -> (d0 floordiv 8)> -#map1 = affine_map<(d0) -> (d0 mod 8)> -#map2 = affine_map<(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512)> +#map = #xla_gpu.indexing_map<(d0) -> (d0 floordiv 8), + domain: d0 in [0, 31]> +#map1 = #xla_gpu.indexing_map<(d0) -> (d0 mod 8), + domain: d0 in [0, 31]> +#map2 = #xla_gpu.indexing_map<(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512), + domain: d0 in [0, 1], d1 in [0, 255], s0 in [0, 7]> module { func.func @fully_unroll(%arg0: tensor<4x8x4096xf32>, %arg1: tensor<4096xbf16>, %arg2: tensor<4x8xf32>, %arg3: tensor<4096xbf16>, @@ -21,23 +24,23 @@ module { %1 = arith.cmpi eq, %0, %c0 : index %2 = arith.divui %thread_id_x, %c32 : index %3 = arith.cmpi ult, %thread_id_x, %c8 : index - %4 = xla_gpu.apply_indexing #map(%block_id_x in [0, 31]) - %5 = xla_gpu.apply_indexing #map1(%block_id_x in [0, 31]) + %4 = xla_gpu.apply_indexing #map(%block_id_x) + %5 = xla_gpu.apply_indexing #map1(%block_id_x) %extracted = tensor.extract %arg2[%4, %5] : tensor<4x8xf32> %6 = arith.mulf %extracted, %cst : f32 %7 = arith.addf %6, %cst : f32 %8 = math.rsqrt %7 : f32 %9:2 = scf.for %arg7 = %c0 to %c8 step %c1 iter_args(%arg8 = %arg6, %arg9 = %cst) -> (tensor<4x8x4096xf32>, f32) { - %18 = xla_gpu.apply_indexing #map2(%c0 in [0, 1], %thread_id_x in [0, 255])[%arg7 in [0, 7]] + %18 = xla_gpu.apply_indexing #map2(%c0, %thread_id_x)[%arg7] %19 = vector.transfer_read %arg1[%18], %cst_1 {in_bounds = [true]} : tensor<4096xbf16>, vector<2xbf16> - %20 = xla_gpu.apply_indexing #map2(%c0 in [0, 1], %thread_id_x in [0, 255])[%arg7 in [0, 7]] + %20 = xla_gpu.apply_indexing #map2(%c0, %thread_id_x)[%arg7] %21 = vector.transfer_read %arg3[%20], %cst_1 {in_bounds = [true]} : tensor<4096xbf16>, vector<2xbf16> - %22 = xla_gpu.apply_indexing #map2(%c0 in [0, 1], %thread_id_x in [0, 255])[%arg7 in [0, 7]] + %22 = xla_gpu.apply_indexing #map2(%c0, %thread_id_x)[%arg7] %23 = vector.transfer_read %arg4[%4, %5, %22], %cst_1 {in_bounds = [true]} : tensor<4x8x4096xbf16>, vector<2xbf16> - %24 = xla_gpu.apply_indexing #map2(%c0 in [0, 1], %thread_id_x in [0, 255])[%arg7 in [0, 7]] + %24 = xla_gpu.apply_indexing #map2(%c0, %thread_id_x)[%arg7] %25 = vector.transfer_read %arg0[%4, %5, %24], %cst {in_bounds = [true]} : tensor<4x8x4096xf32>, vector<2xf32> %26:2 = scf.for %arg10 = %c0 to %c2 step %c1 iter_args(%arg11 = %arg8, %arg12 = %arg9) -> (tensor<4x8x4096xf32>, f32) { - %27 = xla_gpu.apply_indexing #map2(%arg10 in [0, 1], %thread_id_x in [0, 255])[%arg7 in [0, 7]] + %27 = xla_gpu.apply_indexing #map2(%arg10, %thread_id_x)[%arg7] %28 = vector.extract %25[%arg10] : f32 from vector<2xf32> %29 = vector.extract %23[%arg10] : bf16 from vector<2xbf16> %30 = arith.extf %29 : bf16 to f32 @@ -124,7 +127,7 @@ module { } } -// CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0 + 1)> +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 + 1), // CHECK-LABEL: @pipeline_extract // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C30:.*]] = arith.constant 30 : index @@ -151,7 +154,7 @@ module { %cst = arith.constant dense<[0.0, 0.0]> : vector<2xf32> %cst0 = arith.constant 0.0 : f32 %ret = scf.for %i = %c0 to %c17 step %c1 iter_args (%iter = %cst) -> (vector<2xf32>) { - %base = xla_gpu.apply_indexing affine_map<(d0) -> (d0 * 2)>(%i in [0, 15]) + %base = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0) -> (d0 * 2), domain: d0 in [0, 15]>(%i) %val = vector.transfer_read %arg[%base], %cst0 : tensor<34xf32>, vector<2xf32> %log = math.log %val : vector<2xf32> %add = arith.addf %log, %iter : vector<2xf32> @@ -161,8 +164,8 @@ module { } } -// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> (d0 * 2)> -// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 + 1)> +// CHECK-DAG: #[[$MAP0:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 2), +// CHECK-DAG: #[[$MAP1:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 + 1), // CHECK-LABEL: @pipeline_transfer // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index diff --git a/xla/service/gpu/fusions/mlir/tests/simplify_affine.mlir b/xla/service/gpu/fusions/mlir/tests/simplify_affine.mlir index ec1a726da9db13..d51566a5b3dace 100644 --- a/xla/service/gpu/fusions/mlir/tests/simplify_affine.mlir +++ b/xla/service/gpu/fusions/mlir/tests/simplify_affine.mlir @@ -62,8 +62,9 @@ func.func @op_and_for_ranges(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.pt %0 = gpu.thread_id x %1 = gpu.block_id x scf.for %i = %c0 to %c4 step %c1 { - %2 = xla_gpu.apply_indexing affine_map<()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2 + (s1 floordiv 128) + (s2 floordiv 4))> - [%1 in [0, 3071], %0 in [0, 127], %i in [0, 3]] + %2 = xla_gpu.apply_indexing + #xla_gpu.indexing_map<()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2 + (s1 floordiv 128) + (s2 floordiv 4)), + domain: s0 in [0, 3071], s1 in [0, 127], s2 in [0, 3]>[%1, %0, %i] %3 = arith.index_castui %2 : index to i64 %4 = llvm.getelementptr %arg0[%3] : (!llvm.ptr, i64) -> !llvm.ptr, f32 %5 = llvm.load %4 invariant : !llvm.ptr -> f32 @@ -91,8 +92,8 @@ func.func @op_and_for_ranges(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.pt func.func @arg_ranges(%arg0: index, %arg1: index) -> index { %0 = xla_gpu.apply_indexing - affine_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100)> - [%arg0 in [0, 42], %arg1 in [0, 1000]] + #xla_gpu.indexing_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100), + domain: s0 in [0, 42], s1 in [0, 1000]>[%arg0, %arg1] return %0 : index } @@ -105,8 +106,8 @@ func.func @arg_ranges(%arg0: index, %arg1: index) -> index { func.func @cant_lower(%arg0: index, %arg1: index) -> (index, index) { %0:2 = xla_gpu.apply_indexing - affine_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100, s0 + s1)> - [%arg0 in [-10, 42], %arg1 in [0, 1000]] + #xla_gpu.indexing_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100, s0 + s1), + domain: s0 in [-10, 42], s1 in [0, 1000]>[%arg0, %arg1] return %0#0, %0#1 : index, index } @@ -123,8 +124,8 @@ func.func @order_summands(%arg1: index) { scf.for %arg2 = %c0 to %c4 step %c1 { scf.for %arg3 = %c0 to %c4 step %c1 { %0 = xla_gpu.apply_indexing - affine_map<()[s0, s1, s2] -> ((s0 + s1) floordiv 3 + s0 * 512 + s1 * 4 + s2 * 10)> - [%arg2 in [0, 3], %arg1 in [0, 3], %arg3 in [0, 3]] + #xla_gpu.indexing_map<()[s0, s1, s2] -> ((s0 + s1) floordiv 3 + s0 * 512 + s1 * 4 + s2 * 10), + domain: s0 in [0, 3], s1 in [0, 3], s2 in [0, 3]>[%arg2, %arg1, %arg3] "dummy.op"(%0) : (index) -> () } } diff --git a/xla/service/gpu/fusions/mlir/tests/simplify_arith.mlir b/xla/service/gpu/fusions/mlir/tests/simplify_arith.mlir index ee2e0ddbe29035..09c8901fab6000 100644 --- a/xla/service/gpu/fusions/mlir/tests/simplify_arith.mlir +++ b/xla/service/gpu/fusions/mlir/tests/simplify_arith.mlir @@ -247,7 +247,7 @@ func.func @refine_constraints(%tensor: tensor<100xf32>) -> tensor<100xf32> { %c42_f32 = arith.constant 42.0 : f32 %loop = scf.for %i = %c0 to %c3 step %c1 iter_args(%in_ = %tensor) -> (tensor<100xf32>) { - %0 = xla_gpu.apply_indexing affine_map<(d0) -> (d0 mod 4)> (%i in [0, 9]) + %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0) -> (d0 mod 4), domain: d0 in [0, 9]>(%i) %updated = tensor.insert %c42_f32 into %in_[%0] : tensor<100xf32> scf.yield %updated :tensor<100xf32> } @@ -261,8 +261,10 @@ func.func @refine_constraints(%tensor: tensor<100xf32>) -> tensor<100xf32> { // ----- -#map = affine_map<(d0, d1)[s0, s1] -> (((d0 * 4 + d1 * 512 + s1) floordiv 9 + s0 * 32768) mod 2400000)> -#map1 = affine_map<(d0, d1)[s0] -> ((d0 * 4 + d1 * 512 + s0) mod 9)> +#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (((d0 * 4 + d1 * 512 + s1) floordiv 9 + s0 * 32768) mod 2400000), + domain: d0 in [0, 127], d1 in [0, 575], s0 in [0, 73], s1 in [0, 3]> +#map1 = #xla_gpu.indexing_map<(d0, d1)[s0] -> ((d0 * 4 + d1 * 512 + s0) mod 9), + domain: d0 in [0, 127], d1 in [0, 575], s0 in [0, 3]> func.func @refine_constraints_for_symbol(%arg0: tensor<2400000x9xf32>, %arg1: tensor<2400000x9xf32>) -> tensor<2400000x9xf32> { %c0 = arith.constant 0 : index @@ -276,10 +278,8 @@ func.func @refine_constraints_for_symbol(%arg0: tensor<2400000x9xf32>, -> (tensor<2400000x9xf32>) { %2 = scf.for %j = %c0 to %c4 step %c1 iter_args(%arg5 = %arg3) -> (tensor<2400000x9xf32>) { - %3 = xla_gpu.apply_indexing #map(%th_x in [0, 127], %bl_x in [0, 575]) - [%i in [0, 73], %j in [0, 3]] - %4 = xla_gpu.apply_indexing #map1(%th_x in [0, 127], %bl_x in [0, 575]) - [%j in [0, 3]] + %3 = xla_gpu.apply_indexing #map(%th_x, %bl_x)[%i, %j] + %4 = xla_gpu.apply_indexing #map1(%th_x, %bl_x)[%j] %inserted = tensor.insert %c42_f32 into %arg5[%3, %4] : tensor<2400000x9xf32> scf.yield %inserted : tensor<2400000x9xf32> @@ -288,5 +288,5 @@ func.func @refine_constraints_for_symbol(%arg0: tensor<2400000x9xf32>, } return %0 : tensor<2400000x9xf32> } -// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> ((d0 * 4 + d1 * 512 + s1) floordiv 9 + s0 * 32768)> +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> ((d0 * 4 + d1 * 512 + s1) floordiv 9 + s0 * 32768), // CHECK-LABEL: func.func @refine_constraints_for_symbol diff --git a/xla/service/gpu/fusions/mlir/tests/vectorize_loads_stores.mlir b/xla/service/gpu/fusions/mlir/tests/vectorize_loads_stores.mlir index 1141d1581505ea..16e4498b0c5380 100644 --- a/xla/service/gpu/fusions/mlir/tests/vectorize_loads_stores.mlir +++ b/xla/service/gpu/fusions/mlir/tests/vectorize_loads_stores.mlir @@ -1,6 +1,7 @@ // RUN: mlir_fusions_opt -allow-unregistered-dialect %s -split-input-file -xla-gpu-vectorize-loads-stores -canonicalize -cse | FileCheck %s -#map = affine_map<(d0)[s0] -> (d0 * 2 + s0)> +#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0), + domain: d0 in [0, 63], s0 in [0, 1]> module { func.func @simple_read(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index @@ -10,7 +11,7 @@ module { %cst = arith.constant 0.0 : f32 %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 { %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]] + %idx = xla_gpu.apply_indexing #map(%i)[%j] %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> %added = arith.addf %iter1, %extracted : f32 scf.yield %added : f32 @@ -21,7 +22,7 @@ module { } } -// CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0 * 2)> +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 2), domain: d0 in [0, 63]> // CHECK-LABEL: @simple_read // CHECK-SAME: (%[[ARG0:.*]]: tensor // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index @@ -29,7 +30,7 @@ module { // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index // CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C64]] step %[[C1]] iter_args(%[[ITER:.*]] = -// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #map(%[[I]] in [0, 63]) +// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]]) // CHECK-NEXT: %[[V:.*]] = vector.transfer_read %[[ARG0]][%[[BASE]]] // CHECK-NEXT: scf.for %[[J:.*]] = %[[C0]] // CHECK-NEXT: vector.extract %[[V]][%[[J]]] @@ -66,7 +67,8 @@ module { // ----- -#map = affine_map<(d0)[s0] -> (d0 * 2 + s0 + 1)> +#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0 + 1), + domain: d0 in [0, 63], s0 in [0, 1]> module { func.func @misaligned_indexing_map(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index @@ -76,7 +78,7 @@ module { %cst = arith.constant 0.0 : f32 %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]] + %idx = xla_gpu.apply_indexing #map(%i)[%j] %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> %added = arith.addf %iter1, %extracted : f32 scf.yield %added : f32 @@ -92,7 +94,8 @@ module { // ----- -#map = affine_map<(d0)[s0] -> (d0 * 3 + s0)> +#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 3 + s0), + domain: d0 in [0, 63], s0 in [0, 1]> module { func.func @misaligned_indexing_map_2(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index @@ -102,7 +105,7 @@ module { %cst = arith.constant 0.0 : f32 %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]] + %idx = xla_gpu.apply_indexing #map(%i)[%j] %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> %added = arith.addf %iter1, %extracted : f32 scf.yield %added : f32 @@ -142,7 +145,8 @@ module { // ----- -#map = affine_map<(d0)[s0] -> (d0 + s0 * 2)> +#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0 * 2), + domain: d0 in [0, 63], s0 in [0, 1]> module { func.func @wrong_stride(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index @@ -152,7 +156,7 @@ module { %cst = arith.constant 0.0 : f32 %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]] + %idx = xla_gpu.apply_indexing #map(%i)[%j] %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> %added = arith.addf %iter1, %extracted : f32 scf.yield %added : f32 @@ -299,7 +303,8 @@ module { // ----- -#map = affine_map<(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512)> +#map = #xla_gpu.indexing_map<(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512), + domain: d0 in [0, 1], d1 in [0, 255], s0 in [0, 7]> module { func.func @multiple(%arg0: tensor<32x4096xf32>, %arg1: tensor<4096xbf16>, %arg2: tensor<32xf32>, %arg3: tensor<32x4096xf32>, @@ -312,7 +317,7 @@ module { %extracted1 = tensor.extract %arg2[%arg4] : tensor<32xf32> %0:2 = scf.for %i = %c0 to %c8 step %c1 iter_args(%iter0 = %arg3, %iter1 = %cst) -> (tensor<32x4096xf32>, f32) { %1:2 = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter2 = %iter0, %iter3 = %iter1) -> (tensor<32x4096xf32>, f32) { - %2 = xla_gpu.apply_indexing #map(%j in [0, 1], %arg4 in [0, 255])[%i in [0, 7]] + %2 = xla_gpu.apply_indexing #map(%j, %arg4)[%i] %extracted2 = tensor.extract %arg0[%i, %2] : tensor<32x4096xf32> %extracted3 = tensor.extract %arg1[%2] : tensor<4096xbf16> %3 = arith.extf %extracted3 : bf16 to f32 @@ -328,12 +333,13 @@ module { } } -// CHECK: #[[$MAP:.*]] = affine_map<(d0)[s0] -> (d0 * 2 + s0 * 512)> +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0 * 512), +// CHECK-SAME: domain: d0 in [0, 255], s0 in [0, 7]> // CHECK-LABEL: @multiple // CHECK-SAME: (%[[ARG0:.*]]: tensor{{.*}}, %[[ARG1:.*]]: tensor{{.*}}, %[[ARG2:.*]]: tensor{{.*}}, %[[ARG3:.*]]: tensor{{.*}}, %[[ARG4:.*]]: index) // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: scf.for %[[I:.*]] = %[[C0]] -// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[ARG4]] in [0, 255])[%[[I]] in [0, 7]] +// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[ARG4]])[%[[I]]] // CHECK: %[[READ1:.*]] = vector.transfer_read %[[ARG1]][%[[BASE]]] // CHECK: %[[READ2:.*]] = vector.transfer_read %[[ARG0]][%[[I]], %[[BASE]]] // CHECK: %[[INNER:.*]]:2 = scf.for %[[J:.*]] = %[[C0]] {{.*}} iter_args(%[[F:.*]] = {{.*}}, %[[V:.*]] = {{.*}}) -> (f32, vector<2xf32>) @@ -350,7 +356,8 @@ module { // ----- -#map = affine_map<(d0)[s0] -> ((d0 * 4) mod 64 + s0)> +#map = #xla_gpu.indexing_map<(d0)[s0] -> ((d0 * 4) mod 64 + s0), + domain: d0 in [0, 63], s0 in [0, 1]> module { func.func @remainder_with_modulo(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index @@ -360,7 +367,7 @@ module { %cst = arith.constant 0.0 : f32 %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]] + %idx = xla_gpu.apply_indexing #map(%i)[%j] %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> %added = arith.addf %iter1, %extracted : f32 scf.yield %added : f32 @@ -371,7 +378,7 @@ module { } } -// CHECK: #[[$MAP:.*]] = affine_map<(d0) -> ((d0 mod 16) * 4)> +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> ((d0 mod 16) * 4), // CHECK-LABEL: @remainder_with_modulo // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: scf.for %[[I:.*]] = %[[C0]] @@ -380,7 +387,8 @@ module { // ----- -#map = affine_map<(d0)[s0] -> ((d0 * 4) mod 65 + s0)> +#map = #xla_gpu.indexing_map<(d0)[s0] -> ((d0 * 4) mod 65 + s0), + domain: d0 in [0, 63], s0 in [0, 1]> module { func.func @remainder_with_modulo_misaligned(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index @@ -390,7 +398,7 @@ module { %cst = arith.constant 0.0 : f32 %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]] + %idx = xla_gpu.apply_indexing #map(%i)[%j] %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> %added = arith.addf %iter1, %extracted : f32 scf.yield %added : f32 diff --git a/xla/service/gpu/fusions/reduction_mlir_test.cc b/xla/service/gpu/fusions/reduction_mlir_test.cc index 761ecb4f31fe59..04b26d8dedc581 100644 --- a/xla/service/gpu/fusions/reduction_mlir_test.cc +++ b/xla/service/gpu/fusions/reduction_mlir_test.cc @@ -399,8 +399,8 @@ TEST_F(MlirRowReductionTest, NonPowerOfTwoRowReduction) { ROOT fusion = f32[100] fusion(a, c), kind=kInput, calls=fused_computation })"; TF_EXPECT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1)[s0] -> ((d1 mod 64) * 2 + s0 * 128 + d0)> - // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> ((d1 mod 64) * 2 + d0 + 512)> + // CHECK-DAG: #[[MAP1:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0] -> ((d1 mod 64) * 2 + s0 * 128 + d0), domain: d0 in [0, 1], d1 in [0, 255], s0 in [0, 3]> + // CHECK-DAG: #[[MAP2:.*]] = #xla_gpu.indexing_map<(d0, d1) -> ((d1 mod 64) * 2 + d0 + 512), domain: d0 in [0, 1], d1 in [0, 255]> // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index @@ -408,10 +408,10 @@ TEST_F(MlirRowReductionTest, NonPowerOfTwoRowReduction) { // CHECK: %[[FULL_TILES:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C4]] step %[[C1]] // CHECK-NEXT: scf.for %[[J:.*]] = %[[C0]] to %[[C2]] step %[[C1]] // CHECK-NOT: scf.if - // CHECK: xla_gpu.apply_indexing #[[MAP1]](%[[J]] in [0, 1], %thread_id_x in [0, 255])[%[[I]] in [0, 3]] + // CHECK: xla_gpu.apply_indexing #[[MAP1]](%[[J]], %thread_id_x)[%[[I]]] // CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%{{.*}} = %[[FULL_TILES]]) // CHECK: scf.if - // CHECK: xla_gpu.apply_indexing #[[MAP2]](%[[J]] in [0, 1], %thread_id_x in [0, 255]) + // CHECK: xla_gpu.apply_indexing #[[MAP2]](%[[J]], %thread_id_x) )")); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); } diff --git a/xla/service/gpu/fusions/scatter_mlir_test.cc b/xla/service/gpu/fusions/scatter_mlir_test.cc index 6b8d013a81f735..9220701092a522 100644 --- a/xla/service/gpu/fusions/scatter_mlir_test.cc +++ b/xla/service/gpu/fusions/scatter_mlir_test.cc @@ -187,8 +187,8 @@ TEST_F(MlirScatterFusionTest, Scatter_UniqueIndices) { } )"; TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0 floordiv 2)> - // CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 mod 2)> + // CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 floordiv 2) + // CHECK: #[[$MAP1:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 mod 2) // CHECK-LABEL: func.func @fused_computation( // CHECK-SAME: %[[OPERAND:[a-zA-Z0-9]*]]: tensor<10x5xf32> diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc index 8f07bbace4b3ad..d59b6cc48c7ac0 100644 --- a/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc +++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc @@ -214,7 +214,7 @@ ENTRY main { TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, FromOutputTileSizes({1, 127}), "triton_softmax_computation", R"( -CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0 * 127)> +CHECK: #[[MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 127) CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { CHECK: %[[PID:.*]] = tt.get_program_id x : i32 CHECK: arith.index_castui %[[PID]] : i32 to index @@ -272,7 +272,7 @@ ENTRY main { TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, FromOutputTileSizes({1, 127}), "triton_softmax_computation", R"( -CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0 * 127)> +CHECK: #[[MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 127) CHECK-LABEL: tt.func @triton_fn( CHECK-SAME: %[[P0:[A-Za-z0-9_]*]]: !tt.ptr CHECK-SAME: %[[P1:[A-Za-z0-9_]*]]: !tt.ptr @@ -339,7 +339,7 @@ ENTRY main { TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, FromOutputTileSizes({1, 1, 127}), "triton_softmax_computation", R"( -CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0 * 127)> +CHECK: #[[MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 127) CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P2:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P3:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { CHECK-DAG: %[[PID:.*]] = tt.get_program_id x : i32 CHECK-DAG: %[[PID_INDEX:.*]] = arith.index_castui %[[PID]] : i32 to index @@ -517,7 +517,7 @@ ENTRY main { TF_ASSERT_OK(CreateTritonIrAndFileCheck(this, kHloText, FromOutputTileSizes({1, 1, 16}), "triton_softmax_computation", R"( -// CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0 * 16)> +// CHECK: #[[MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 16) // CHECK-LABEL: tt.func @triton_fn( // CHECK-SAME: %[[P0:[A-Za-z0-9_]*]]: !tt.ptr // CHECK-SAME: %[[P1:[A-Za-z0-9_]*]]: !tt.ptr @@ -674,7 +674,7 @@ ENTRY main { TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, FromOutputTileSizes({1, 127}), "triton_softmax_computation", R"( -CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0 * 127)> +CHECK: #[[MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 127) CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { CHECK: %[[PID:.*]] = tt.get_program_id x : i32 CHECK: arith.index_castui %[[PID]] : i32 to index From aabd855392b0cd9c14569bd59d258f0d806204c5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 1 Aug 2024 01:29:21 -0700 Subject: [PATCH 357/376] Automated Code Change PiperOrigin-RevId: 658310275 --- xla/service/cpu/runtime/collective_thunk.cc | 8 -------- xla/service/cpu/runtime/collective_thunk.h | 3 --- 2 files changed, 11 deletions(-) diff --git a/xla/service/cpu/runtime/collective_thunk.cc b/xla/service/cpu/runtime/collective_thunk.cc index a0cd9f4936cb33..32a452a6bcdd0d 100644 --- a/xla/service/cpu/runtime/collective_thunk.cc +++ b/xla/service/cpu/runtime/collective_thunk.cc @@ -205,10 +205,6 @@ const Shape& CollectiveThunk::source_shape(int64_t index) const { return op_buffers_.source_shapes[index]; } -absl::Span CollectiveThunk::source_shapes() const { - return op_buffers_.source_shapes; -} - const BufferAllocation::Slice& CollectiveThunk::destination_buffer( int64_t index) const { return op_buffers_.destination_buffers[index]; @@ -223,8 +219,4 @@ const Shape& CollectiveThunk::destination_shape(int64_t index) const { return op_buffers_.destination_shapes[index]; } -absl::Span CollectiveThunk::destination_shapes() const { - return op_buffers_.destination_shapes; -} - } // namespace xla::cpu diff --git a/xla/service/cpu/runtime/collective_thunk.h b/xla/service/cpu/runtime/collective_thunk.h index 5bcf16b4e10d5c..5ae9c98844f887 100644 --- a/xla/service/cpu/runtime/collective_thunk.h +++ b/xla/service/cpu/runtime/collective_thunk.h @@ -77,7 +77,6 @@ class CollectiveThunk : public Thunk { OpBuffers op_buffers, OpResources op_resources); const OpParams& op_params() const { return op_params_; } - const OpBuffers& op_buffers() const { return op_buffers_; } // Resolves operation's device memory from the buffers and buffer allocations. absl::StatusOr GetOpDeviceMemory(const ExecuteParams& params); @@ -109,13 +108,11 @@ class CollectiveThunk : public Thunk { absl::Span source_buffers() const; const Shape& source_shape(int64_t index) const; - absl::Span source_shapes() const; const BufferAllocation::Slice& destination_buffer(int64_t index) const; absl::Span destination_buffers() const; const Shape& destination_shape(int64_t index) const; - absl::Span destination_shapes() const; private: OpParams op_params_; From 811aab1c00b7998f8e03137493196cb56b3b4f2a Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Thu, 1 Aug 2024 01:40:04 -0700 Subject: [PATCH 358/376] Enable mlir reduction emitter by default. This can be disabled with the flag --xla_gpu_mlir_emitter_level, setting it to any value < 4. Change some tests to still use the old emitters. We have separate IR tests for the new emitters, and keeping the old tests running with the old emitters ensures we still have coverage for the old emitters, in case we need to rollback. One notable change with enabling emitter level 4 is that the heuristic to avoid code duplication due to cache invalidation is disabled. This was always a a workaround, and the new emitters fixed the problem. This is the most common source of why the tests behave differently between the old and the new emitters. PiperOrigin-RevId: 658313306 --- xla/debug_options_flags.cc | 2 +- xla/service/gpu/fusion_merger_test.cc | 56 +++++++++++-------- .../gpu/model/gpu_performance_model_test.cc | 12 ++-- xla/service/gpu/multi_output_fusion_test.cc | 2 + xla/service/gpu/priority_fusion_test.cc | 2 + .../gpu/tests/gpu_kernel_tiling_test.cc | 9 ++- 6 files changed, 51 insertions(+), 32 deletions(-) diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc index 79e716b7aca4bb..6a664252577965 100644 --- a/xla/debug_options_flags.cc +++ b/xla/debug_options_flags.cc @@ -246,7 +246,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_nccl_p2p_max_nchannels(0); #if GOOGLE_CUDA - opts.set_xla_gpu_mlir_emitter_level(3); + opts.set_xla_gpu_mlir_emitter_level(4); #else opts.set_xla_gpu_mlir_emitter_level(0); #endif diff --git a/xla/service/gpu/fusion_merger_test.cc b/xla/service/gpu/fusion_merger_test.cc index de45a4b9d3273e..1afe8f615d0c2f 100644 --- a/xla/service/gpu/fusion_merger_test.cc +++ b/xla/service/gpu/fusion_merger_test.cc @@ -135,42 +135,42 @@ f32add { } comp0 { - p = (f32[100000000], f32[100000000], f32[100000000], f32[100000000]) parameter(0) - gte0 = f32[100000000] get-tuple-element(p), index=0 - gte1 = f32[100000000] get-tuple-element(p), index=1 - add.9 = f32[100000000] add(gte0, gte1) - gte2 = f32[100000000] get-tuple-element(p), index=2 - add.10 = f32[100000000] add(add.9, gte2) - gte3 = f32[100000000] get-tuple-element(p), index=3 - add.11 = f32[100000000] add(add.10, gte3) - p1 = (f32[100000000], f32[100000000], f32[100000000], f32[100000000]) parameter(1) - gte4 = f32[100000000] get-tuple-element(p1), index=0 - gte5 = f32[100000000] get-tuple-element(p1), index=1 - add.12 = f32[100000000] add(gte4, gte5) - gte6 = f32[100000000] get-tuple-element(p1), index=2 - add.13 = f32[100000000] add(add.12, gte6) - gte7 = f32[100000000] get-tuple-element(p1), index=3 - add.14 = f32[100000000] add(add.13, gte7) - ROOT r = f32[100000000] add(add.14, add.11) + p = (f32[2048], f32[2048], f32[2048], f32[2048]) parameter(0) + gte0 = f32[2048] get-tuple-element(p), index=0 + gte1 = f32[2048] get-tuple-element(p), index=1 + add.9 = f32[2048] add(gte0, gte1) + gte2 = f32[2048] get-tuple-element(p), index=2 + add.10 = f32[2048] add(add.9, gte2) + gte3 = f32[2048] get-tuple-element(p), index=3 + add.11 = f32[2048] add(add.10, gte3) + p1 = (f32[2048], f32[2048], f32[2048], f32[2048]) parameter(1) + gte4 = f32[2048] get-tuple-element(p1), index=0 + gte5 = f32[2048] get-tuple-element(p1), index=1 + add.12 = f32[2048] add(gte4, gte5) + gte6 = f32[2048] get-tuple-element(p1), index=2 + add.13 = f32[2048] add(add.12, gte6) + gte7 = f32[2048] get-tuple-element(p1), index=3 + add.14 = f32[2048] add(add.13, gte7) + ROOT r = f32[2048] add(add.14, add.11) } comp1 { - p = f32[100000000] parameter(0) + p = f32[2048] parameter(0) c0 = f32[] constant(0) ROOT r = f32[] reduce(p, c0), dimensions={0}, to_apply=f32add } comp2 { - p = f32[100000000] parameter(0) + p = f32[2048] parameter(0) c0 = f32[] constant(0) r = f32[] reduce(p, c0), dimensions={0}, to_apply=f32add ROOT n = f32[] negate(r) } ENTRY m.Computation2 { - p0 = (f32[100000000], f32[100000000], f32[100000000], f32[100000000]) parameter(0) - p1 = (f32[100000000], f32[100000000], f32[100000000], f32[100000000]) parameter(1) - fusion.0 = f32[100000000] fusion(p0, p1), kind=kLoop, calls=comp0 + p0 = (f32[2048], f32[2048], f32[2048], f32[2048]) parameter(0) + p1 = (f32[2048], f32[2048], f32[2048], f32[2048]) parameter(1) + fusion.0 = f32[2048] fusion(p0, p1), kind=kLoop, calls=comp0 fusion.1 = f32[] fusion(fusion.0), kind=kLoop, calls=comp1 fusion.2 = f32[] fusion(fusion.0), kind=kLoop, calls=comp2 ROOT tuple = (f32[], f32[]) tuple(fusion.1, fusion.2) @@ -362,14 +362,14 @@ TEST_F(FusionMergerTest, WillMergeReduceNotTooUnfriendlyLayouts) { f2_computation { f2_p0 = f32[16,16,256]{2,1,0} parameter(0) f2_zero = f32[] constant(0) - ROOT f2_root = f32[] reduce(f2_p0, f2_zero), dimensions={0,1,2}, + ROOT f2_root = f32[16,16] reduce(f2_p0, f2_zero), dimensions={2}, to_apply=add_computation } ENTRY entry { p0 = f32[16,16,256]{0,1,2} parameter(0) f1 = f32[16,16,256]{2,1,0} fusion(p0), kind=kLoop, calls=f1_computation - ROOT f2 = f32[] fusion(f1), kind=kInput, calls=f2_computation + ROOT f2 = f32[16,16] fusion(f1), kind=kInput, calls=f2_computation })") .value(); EXPECT_TRUE(fusion_merger_.Run(module.get()).value()); @@ -685,6 +685,12 @@ ENTRY entry { } )") .value(); + auto& debug_options = module->mutable_config().mutable_debug_options(); + // For some reason, we would not merge any fusions when using the MLIR + // reduction emitter. The cost model queries the reduction emitter regarding + // the launch dimensions, so it seems likely that it is caused by different + // launch dimensions. + debug_options.set_xla_gpu_mlir_emitter_level(3); EXPECT_TRUE(fusion_merger_.Run(module.get()).value()); } @@ -995,6 +1001,8 @@ ENTRY e { } )") .value(); + auto& debug_options = module->mutable_config().mutable_debug_options(); + debug_options.set_xla_gpu_mlir_emitter_level(3); EXPECT_FALSE(fusion_merger_.Run(module.get()).value()); } diff --git a/xla/service/gpu/model/gpu_performance_model_test.cc b/xla/service/gpu/model/gpu_performance_model_test.cc index 4c0c35e1a9e285..33a43f736d442a 100644 --- a/xla/service/gpu/model/gpu_performance_model_test.cc +++ b/xla/service/gpu/model/gpu_performance_model_test.cc @@ -674,16 +674,16 @@ add { } fused_computation.0 { - p0 = f32[4,28672,32] parameter(0) - tanh = f32[4,28672,32] tanh(p0) + p0 = f32[4,256,32] parameter(0) + tanh = f32[4,256,32] tanh(p0) c1 = f32[] constant(72) - broadcast = f32[4,28672,32] broadcast(c1), dimensions={} - ROOT mul = f32[4,28672,32] multiply(tanh, broadcast) + broadcast = f32[4,256, 32] broadcast(c1), dimensions={} + ROOT mul = f32[4,256,32] multiply(tanh, broadcast) } ENTRY fusion { - p0 = f32[4,28672,32] parameter(0) - fusion = f32[4,28672,32] fusion(p0), kind=kLoop, calls=fused_computation.0 + p0 = f32[4,256,32] parameter(0) + fusion = f32[4,256,32] fusion(p0), kind=kLoop, calls=fused_computation.0 c0 = f32[] constant(0) ROOT reduce = f32[4,32] reduce(fusion, c0), to_apply=add, dimensions={1} })"; diff --git a/xla/service/gpu/multi_output_fusion_test.cc b/xla/service/gpu/multi_output_fusion_test.cc index b333a04a841882..3cbaa26d49d723 100644 --- a/xla/service/gpu/multi_output_fusion_test.cc +++ b/xla/service/gpu/multi_output_fusion_test.cc @@ -1529,6 +1529,8 @@ ENTRY main { } )") .value(); + auto& debug_options = module->mutable_config().mutable_debug_options(); + debug_options.set_xla_gpu_mlir_emitter_level(3); EXPECT_FALSE(mof_.Run(module.get()).value()); } diff --git a/xla/service/gpu/priority_fusion_test.cc b/xla/service/gpu/priority_fusion_test.cc index 4f71a51b869b4f..b6a16249fa069d 100644 --- a/xla/service/gpu/priority_fusion_test.cc +++ b/xla/service/gpu/priority_fusion_test.cc @@ -856,6 +856,8 @@ TEST_F(PriorityFusionTest, DoNotFuseProducerConsumerMergedTooLarge) { ROOT fusion2 = pred[6]{0} fusion(fusion1), kind=kInput, calls=fused_computation.2 } )"); + auto& debug_options = module->mutable_config().mutable_debug_options(); + debug_options.set_xla_gpu_mlir_emitter_level(3); EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false)); } diff --git a/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/xla/service/gpu/tests/gpu_kernel_tiling_test.cc index e86f2c09b06cea..45e7af622d9814 100644 --- a/xla/service/gpu/tests/gpu_kernel_tiling_test.cc +++ b/xla/service/gpu/tests/gpu_kernel_tiling_test.cc @@ -34,7 +34,11 @@ class GpuKernelTilingTest : public GpuCodegenTest { // Most tests in this file want to skip layout assignment, but a few need it // enabled. HloModuleConfig ConfigWithLayoutAssignment() { - return GetModuleConfigForTest(); + HloModuleConfig config; + auto debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_mlir_emitter_level(3); + config.set_debug_options(debug_options); + return config; } HloModuleConfig ConfigWithoutLayoutAssignment() { @@ -42,6 +46,7 @@ class GpuKernelTilingTest : public GpuCodegenTest { auto debug_options = HloTestBase::GetDebugOptionsForTest(); // Disable layout_assignment to use the preassigned layouts. debug_options.add_xla_disable_hlo_passes("layout-assignment"); + debug_options.set_xla_gpu_mlir_emitter_level(3); config.set_debug_options(debug_options); return config; } @@ -635,6 +640,8 @@ TEST_F(GpuKernelTilingTest, RowReductionCorrectShmemUsage) { } )"; auto hlo_module = ParseAndReturnVerifiedModule(kHloString).value(); + auto &debug_options = hlo_module->mutable_config().mutable_debug_options(); + debug_options.set_xla_gpu_mlir_emitter_level(3); auto expected_ir = is_built_with_rocm_ ? R"( ; CHECK: %llvm.amdgcn.kernel.input_reduce_fusion.lds.t = type { [4 x [2 x float]] } ; CHECK: @llvm.amdgcn.kernel.input_reduce_fusion.lds = internal addrspace(3) global %llvm.amdgcn.kernel.input_reduce_fusion.lds.t poison From e116e85b78071eec140091d0066bc4fee1f887b3 Mon Sep 17 00:00:00 2001 From: Tori Baker Date: Thu, 1 Aug 2024 02:48:15 -0700 Subject: [PATCH 359/376] Clean up uses of bounds in ApplyIndexingOp since they are no longer necessary PiperOrigin-RevId: 658330538 --- .../gpu/fusions/mlir/ir/xla_gpu_ops.cc | 72 ++++--------------- .../gpu/fusions/mlir/ir/xla_gpu_ops.td | 6 -- .../gpu/fusions/mlir/optimize_loops.cc | 6 +- .../fusions/mlir/vectorize_loads_stores.cc | 3 +- 4 files changed, 18 insertions(+), 69 deletions(-) diff --git a/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc b/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc index 57c1da7be8558b..78a04976ebebf7 100644 --- a/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc +++ b/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc @@ -140,28 +140,6 @@ void ApplyIndexingOp::build(OpBuilder& builder, OperationState& result, build(builder, result, operands, indexing_map); } -void ApplyIndexingOp::build(OpBuilder& builder, OperationState& result, - ValueRange operands, AffineMap affine_map, - ArrayRef lower_bounds, - ArrayRef upper_bounds) { - unsigned num_dimensions = affine_map.getNumDims(); - std::vector dim_vars; - dim_vars.reserve(num_dimensions); - for (unsigned id = 0; id < num_dimensions; ++id) { - dim_vars.push_back(DimVar{Interval{lower_bounds[id], upper_bounds[id]}}); - } - unsigned num_symbols = affine_map.getNumSymbols(); - std::vector range_vars; - range_vars.reserve(num_symbols); - for (unsigned id = num_dimensions; id < num_symbols + num_dimensions; ++id) { - range_vars.push_back( - RangeVar{Interval{lower_bounds[id], upper_bounds[id]}}); - } - IndexingMap indexing_map(affine_map, std::move(dim_vars), - std::move(range_vars), /*rt_vars=*/{}); - build(builder, result, operands, indexing_map); -} - // Parses a comma-separated list of operands, ex: %d1, %d2. mlir::ParseResult parseOperands( mlir::OpAsmParser& parser, @@ -242,30 +220,6 @@ LogicalResult ApplyIndexingOp::verify() { return success(); } -llvm::SmallVector ApplyIndexingOp::getLowerBounds() { - SmallVector lower_bounds; - lower_bounds.reserve(getNumOperands()); - for (const auto& dim_var : getIndexingMapAttr().getDimVars()) { - lower_bounds.push_back(dim_var.bounds.lower); - } - for (const auto& range_var : getIndexingMapAttr().getRangeVars()) { - lower_bounds.push_back(range_var.range.lower); - } - return lower_bounds; -} - -llvm::SmallVector ApplyIndexingOp::getUpperBounds() { - SmallVector upper_bounds; - upper_bounds.reserve(getNumOperands()); - for (const auto& dim_var : getIndexingMapAttr().getDimVars()) { - upper_bounds.push_back(dim_var.bounds.upper); - } - for (const auto& range_var : getIndexingMapAttr().getRangeVars()) { - upper_bounds.push_back(range_var.range.upper); - } - return upper_bounds; -} - IndexingMap ApplyIndexingOp::getIndexingMap() { return IndexingMap(getIndexingMapAttr().getMap(), getIndexingMapAttr().getDimVars(), @@ -418,7 +372,8 @@ struct FoldApplyIndexingOperands LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op, PatternRewriter& rewriter) const override { - AffineMap affine_map = indexing_op.getAffineMap(); + IndexingMap indexing_map = indexing_op.getIndexingMap(); + AffineMap affine_map = indexing_map.GetAffineMap(); MLIRContext* ctx = affine_map.getContext(); unsigned num_operands = indexing_op->getNumOperands(); @@ -428,8 +383,6 @@ struct FoldApplyIndexingOperands SmallVector> constant_values(num_operands, std::nullopt); int num_constants = 0; - SmallVector dim_id_map(num_dims, -1); - SmallVector symbol_id_map(num_symbols, -1); for (auto& operand : indexing_op->getOpOperands()) { if (auto constant = operand.get().getDefiningOp()) { @@ -448,15 +401,15 @@ struct FoldApplyIndexingOperands unsigned new_num_operands = indexing_op->getNumOperands() - num_constants; SmallVector new_operands; new_operands.reserve(new_num_operands); - SmallVector new_lbs, new_ubs; - new_lbs.reserve(new_num_operands); - new_ubs.reserve(new_num_operands); + SmallVector new_dim_vars; + new_dim_vars.reserve(num_dims); + SmallVector new_range_vars; + new_range_vars.reserve(num_symbols); unsigned new_num_dims = 0; unsigned new_num_symbols = 0; - for (auto [operand, constant_value, lb, ub] : llvm::zip( - indexing_op->getOpOperands(), constant_values, - indexing_op.getLowerBounds(), indexing_op.getUpperBounds())) { + for (auto [operand, constant_value] : + llvm::zip(indexing_op->getOpOperands(), constant_values)) { unsigned operand_id = operand.getOperandNumber(); if (constant_value.has_value()) { if (operand_id < num_dims) { @@ -467,22 +420,23 @@ struct FoldApplyIndexingOperands getAffineConstantExpr(*constant_value, ctx)); } } else { + new_operands.push_back(operand.get()); if (operand_id < num_dims) { dim_replacements.push_back(getAffineDimExpr(new_num_dims++, ctx)); + new_dim_vars.push_back(indexing_map.GetDimVars(operand_id)); } else { symbol_replacements.push_back( getAffineSymbolExpr(new_num_symbols++, ctx)); + new_range_vars.push_back( + indexing_map.GetRangeVar(operand_id - num_dims)); } - new_operands.push_back(operand.get()); - new_lbs.push_back(lb); - new_ubs.push_back(ub); } } rewriter.replaceOpWithNewOp( indexing_op, new_operands, affine_map.replaceDimsAndSymbols(dim_replacements, symbol_replacements, new_num_dims, new_num_symbols), - new_lbs, new_ubs); + new_dim_vars, new_range_vars); return success(); } }; diff --git a/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.td b/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.td index ed70f2ced60fcc..836d4d31b5bb18 100644 --- a/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.td +++ b/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.td @@ -254,16 +254,10 @@ def ApplyIndexingOp : XLAGPU_Op<"apply_indexing", [Pure]> { OpBuilder<(ins "mlir::ValueRange":$operands, "mlir::AffineMap":$affine_map, "llvm::ArrayRef":$dim_vars, "llvm::ArrayRef":$range_vars)>, - OpBuilder<(ins "mlir::ValueRange":$operands, - "mlir::AffineMap":$affine_map, - "llvm::ArrayRef":$lower_bounds, - "llvm::ArrayRef":$upper_bounds)>, ]; let extraClassDeclaration = [{ // Returns the indexing map constructed from affine_map and the bounds. xla::gpu::IndexingMap getIndexingMap(); - llvm::SmallVector getLowerBounds(); - llvm::SmallVector getUpperBounds(); // Extracts the affine map from the attribute. mlir::AffineMap getAffineMap() { return getIndexingMapAttr().getMap(); } }]; diff --git a/xla/service/gpu/fusions/mlir/optimize_loops.cc b/xla/service/gpu/fusions/mlir/optimize_loops.cc index 6d5456f0150323..e32c9f5a7dce39 100644 --- a/xla/service/gpu/fusions/mlir/optimize_loops.cc +++ b/xla/service/gpu/fusions/mlir/optimize_loops.cc @@ -42,6 +42,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/model/indexing_map.h" namespace xla { namespace gpu { @@ -191,9 +192,10 @@ struct PipelineLoad : mlir::OpRewritePattern { auto plus_one_map = mlir::AffineMap::get( 1, 0, mlir::getAffineDimExpr(0, this->getContext()) + 1); b.setInsertionPoint(next_value); + IndexingMap indexing_map(plus_one_map, {DimVar{0, ub.getSExtValue() - 1}}, + /*range_vars=*/{}, /*rt_vars=*/{}); auto induction_plus_one = - b.create(new_for.getInductionVar(), plus_one_map, 0, - ub.getSExtValue() - 1) + b.create(new_for.getInductionVar(), indexing_map) ->getResult(0); // Create the new apply_indexing ops outside the if, to improve CSE. diff --git a/xla/service/gpu/fusions/mlir/vectorize_loads_stores.cc b/xla/service/gpu/fusions/mlir/vectorize_loads_stores.cc index 00079845867fb0..119b338057e859 100644 --- a/xla/service/gpu/fusions/mlir/vectorize_loads_stores.cc +++ b/xla/service/gpu/fusions/mlir/vectorize_loads_stores.cc @@ -214,8 +214,7 @@ std::optional> GetVectorBaseIndices( llvm::SmallVector ret = indices; ret.back() = - b.create(operands, map, apply_indexing.getLowerBounds(), - apply_indexing.getUpperBounds()) + b.create(operands, apply_indexing.getIndexingMap()) ->getResult(0); return ret; } From 68608e5a00113f31937a9000de8a97ab3d6fbe4f Mon Sep 17 00:00:00 2001 From: Tori Baker Date: Thu, 1 Aug 2024 05:41:33 -0700 Subject: [PATCH 360/376] Put the need to clear constraints to the caller of ApplyIndexingOp PiperOrigin-RevId: 658369726 --- xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc | 3 ++- xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h | 2 +- xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc | 3 --- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc b/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc index 59471a3fb337ea..ec589f299fc738 100644 --- a/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc +++ b/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc @@ -665,9 +665,10 @@ Value ApplyAffineExpr(mlir::AffineExpr expr, ValueRange dims, return b.createOrFold(expr, args); } -SmallVector ApplyIndexing(const IndexingMap& map, ValueRange dims, +SmallVector ApplyIndexing(IndexingMap map, ValueRange dims, ValueRange symbols, ImplicitLocOpBuilder& b) { + map.ClearConstraints(); SmallVector results; for (unsigned int i = 0; i < map.GetAffineMap().getNumResults(); ++i) { SmallVector result; diff --git a/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h b/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h index 1f52109e34c883..99941476828f3f 100644 --- a/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h +++ b/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h @@ -94,7 +94,7 @@ mlir::Value ApplyAffineExpr(mlir::AffineExpr expr, mlir::ValueRange dims, mlir::ImplicitLocOpBuilder& b); // Creates an `apply_indexing` op for the given map. -llvm::SmallVector ApplyIndexing(const IndexingMap& map, +llvm::SmallVector ApplyIndexing(IndexingMap map, mlir::ValueRange dims, mlir::ValueRange symbols, mlir::ImplicitLocOpBuilder& b); diff --git a/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc b/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc index 78a04976ebebf7..11b98059fefb2a 100644 --- a/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc +++ b/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc @@ -124,9 +124,6 @@ void ApplyIndexingOp::build(OpBuilder& builder, OperationState& result, ValueRange operands, IndexingMap indexing_map) { SmallVector result_types(indexing_map.GetAffineMap().getNumResults(), builder.getIndexType()); - // ApplyIndexingOp cannot have any constraints. It may be better to enforce - // callers to do this, but for now this follows the previous behavior. - indexing_map.ClearConstraints(); IndexingMapAttr indexing_map_attr = IndexingMapAttr::get(builder.getContext(), indexing_map); build(builder, result, result_types, operands, indexing_map_attr); From b1960d347238bc5369c67b2bc7f719e30c067f88 Mon Sep 17 00:00:00 2001 From: Dmitri Gribenko Date: Thu, 1 Aug 2024 06:39:10 -0700 Subject: [PATCH 361/376] Integrate LLVM at llvm/llvm-project@17ba4f4053e3 Updates LLVM usage to match [17ba4f4053e3](https://github.com/llvm/llvm-project/commit/17ba4f4053e3) PiperOrigin-RevId: 658383204 --- third_party/llvm/workspace.bzl | 4 ++-- third_party/shardy/temporary.patch | 10 +++++----- third_party/shardy/workspace.bzl | 4 ++-- third_party/tsl/third_party/llvm/workspace.bzl | 4 ++-- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 6429d9bd82a98c..e0f0b45d1f46d4 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "42d641ef5cc4bd82f98ef9959a593ca6db66d75d" - LLVM_SHA256 = "ec368e9c3b1e1c5eb646c21da65bb54a53060b417e61f2451f3917b35d743abd" + LLVM_COMMIT = "17ba4f4053e303be3e5408d34eaf687a49cefb06" + LLVM_SHA256 = "64c334c15f058ca090fe8eb1e2cd99fdcbaaaf6e57202760f32574d3a9d24d78" tf_http_archive( name = name, diff --git a/third_party/shardy/temporary.patch b/third_party/shardy/temporary.patch index 4d99610ad94bd8..aa892f8ca2205b 100644 --- a/third_party/shardy/temporary.patch +++ b/third_party/shardy/temporary.patch @@ -1,15 +1,15 @@ diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index 9345d8d..6429d9b 100644 +index 6429d9b..e0f0b45 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "d92a484e6f5c9063d82ca79405bb3557d88ad575" -- LLVM_SHA256 = "0e6cce920f7344248ed747443fc16c316faf398e33f6a7f9f11f41ede861f824" -+ LLVM_COMMIT = "42d641ef5cc4bd82f98ef9959a593ca6db66d75d" -+ LLVM_SHA256 = "ec368e9c3b1e1c5eb646c21da65bb54a53060b417e61f2451f3917b35d743abd" +- LLVM_COMMIT = "42d641ef5cc4bd82f98ef9959a593ca6db66d75d" +- LLVM_SHA256 = "ec368e9c3b1e1c5eb646c21da65bb54a53060b417e61f2451f3917b35d743abd" ++ LLVM_COMMIT = "17ba4f4053e303be3e5408d34eaf687a49cefb06" ++ LLVM_SHA256 = "64c334c15f058ca090fe8eb1e2cd99fdcbaaaf6e57202760f32574d3a9d24d78" tf_http_archive( name = name, diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index 200ac3f5fbd5a3..dc1800d246361c 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "df54e37427b0007e6527b62616ed1f66a68dda4a" - SHARDY_SHA256 = "2ebf03fd73c4578e721c539ad05b33d5fbfae6838abbb58b944e12f1eafbd9b2" + SHARDY_COMMIT = "a0d337eecb4957da862b235a2829efd9513a129c" + SHARDY_SHA256 = "ffbc55d51995da6fd149e7bc0e41bee4faba02dfa2984f2af54940c46578c0c7" tf_http_archive( name = "shardy", diff --git a/third_party/tsl/third_party/llvm/workspace.bzl b/third_party/tsl/third_party/llvm/workspace.bzl index 6429d9bd82a98c..e0f0b45d1f46d4 100644 --- a/third_party/tsl/third_party/llvm/workspace.bzl +++ b/third_party/tsl/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "42d641ef5cc4bd82f98ef9959a593ca6db66d75d" - LLVM_SHA256 = "ec368e9c3b1e1c5eb646c21da65bb54a53060b417e61f2451f3917b35d743abd" + LLVM_COMMIT = "17ba4f4053e303be3e5408d34eaf687a49cefb06" + LLVM_SHA256 = "64c334c15f058ca090fe8eb1e2cd99fdcbaaaf6e57202760f32574d3a9d24d78" tf_http_archive( name = name, From f104a1db09d054d36ed817b4d2ae48c38f21966c Mon Sep 17 00:00:00 2001 From: Penporn Koanantakool Date: Thu, 1 Aug 2024 06:50:50 -0700 Subject: [PATCH 362/376] [xla:cpu] Partial rollback of PR #15198 (turning on thunks interpreter) due to CI breakage. Switch back the classic runtime. But keep other test changes in the original PR. Reverts 331d668cf959e8fab53d328ede7ff9b65bd0e1ec PiperOrigin-RevId: 658385847 --- xla/debug_options_flags.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc index 6a664252577965..ee2f445a382f2e 100644 --- a/xla/debug_options_flags.cc +++ b/xla/debug_options_flags.cc @@ -82,7 +82,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { #ifdef XLA_CPU_USE_ACL opts.set_xla_cpu_use_acl(true); #endif - opts.set_xla_cpu_use_thunk_runtime(true); + opts.set_xla_cpu_use_thunk_runtime(false); opts.set_xla_cpu_enable_concurrency_optimized_scheduler(false); opts.set_xla_cpu_prefer_vector_width(256); From c83a127e3f8d2ade4bb9a193317051dd94da15b2 Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Thu, 1 Aug 2024 07:44:52 -0700 Subject: [PATCH 363/376] [XLA:GPU] Move `SKIP_TEST_IF_NUM_DEVICES_LESS_THAN` macro to hlo_test_base PiperOrigin-RevId: 658398507 --- xla/tests/collective_ops_test.cc | 83 +++++++++---------- .../collective_pipeline_parallelism_test.cc | 20 ++--- xla/tests/hlo_test_base.h | 8 ++ 3 files changed, 50 insertions(+), 61 deletions(-) diff --git a/xla/tests/collective_ops_test.cc b/xla/tests/collective_ops_test.cc index 460864c7513269..9cd874c9e03c13 100644 --- a/xla/tests/collective_ops_test.cc +++ b/xla/tests/collective_ops_test.cc @@ -39,23 +39,17 @@ limitations under the License. #include "tsl/platform/env.h" #include "tsl/platform/threadpool.h" +namespace xla { +namespace { + // Tests cross-GPU operations. // // Several tests requires at least four GPUs. For instructions on running this // within Google, see go/multi-gpu-unit-test. - -#define SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(x) \ - if (num_devices_ < x) { \ - GTEST_SKIP() << "Test requires at least " << x << " devices"; \ - } - -namespace xla { -namespace { - class CollectiveOpsTest : public HloTestBase { public: - CollectiveOpsTest() : num_devices_(backend().device_count()) { - VLOG(1) << "Running with " << num_devices_ << " devices"; + CollectiveOpsTest() { + VLOG(1) << "Running with " << num_devices() << " devices"; } protected: @@ -180,9 +174,6 @@ class CollectiveOpsTest : public HloTestBase { /*expected_value=*/to_literal({cast(-1), cast(-2), cast(-3)})); } } - - protected: - const int64_t num_devices_; }; // Returns the non-empty subsets of {0, 1, ..., n}. For example, @@ -370,7 +361,7 @@ XLA_TEST_F(CollectiveOpsTest, AllReduceOr_Pred) { XLA_TEST_F(CollectiveOpsTest, AllReduce_AllCombinations) { const int64_t kNumElems = 1024; - for (std::vector devices : PowerSetOfIota(num_devices_)) { + for (std::vector devices : PowerSetOfIota(num_devices())) { SCOPED_TRACE(absl::StrFormat("Running on devices {%s}", absl::StrJoin(devices, ", "))); @@ -494,7 +485,7 @@ XLA_TEST_F(CollectiveOpsTest, AllReduce_ThreeReplicaGroups) { // Test a prime number so it's not all powers of 2. const int64_t kNumElems = 137; const int64_t kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -541,7 +532,7 @@ XLA_TEST_F(CollectiveOpsTest, AllReduce_Degenerate) { } )"; static constexpr int kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -577,19 +568,19 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AsyncAllReduce)) { )"; HloModuleConfig config = - GetModuleConfigForTest(/*replica_count=*/num_devices_); + GetModuleConfigForTest(/*replica_count=*/num_devices()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleStr, config)); TF_ASSERT_OK_AND_ASSIGN( std::vector results, ExecuteReplicated(std::move(module), absl::Span{}, - num_devices_, + num_devices(), /*use_threads=*/true, /*run_hlo_passes=*/false)); - ASSERT_EQ(results.size(), num_devices_); + ASSERT_EQ(results.size(), num_devices()); // sum [0, num_devices) - uint32_t expected = num_devices_ * (num_devices_ - 1) / 2; - for (int i = 0; i < num_devices_; ++i) { + uint32_t expected = num_devices() * (num_devices() - 1) / 2; + for (int i = 0; i < num_devices(); ++i) { LiteralTestUtil::ExpectR0Equal(expected, results[i]); } } @@ -613,22 +604,22 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AsyncAllReduceTwoOperands)) { )"; HloModuleConfig config = - GetModuleConfigForTest(/*replica_count=*/num_devices_); + GetModuleConfigForTest(/*replica_count=*/num_devices()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleStr, config)); TF_ASSERT_OK_AND_ASSIGN( std::vector results, ExecuteReplicated(std::move(module), absl::Span{}, - num_devices_, + num_devices(), /*use_threads=*/true, /*run_hlo_passes=*/false)); - ASSERT_EQ(results.size(), num_devices_); + ASSERT_EQ(results.size(), num_devices()); // sum [0, num_devices) - uint32_t expected0 = num_devices_ * (num_devices_ - 1) / 2; + uint32_t expected0 = num_devices() * (num_devices() - 1) / 2; // sum squares [0, num_devices) uint32_t expected1 = - num_devices_ * (num_devices_ - 1) * (2 * num_devices_ - 1) / 6; - for (int i = 0; i < num_devices_; ++i) { + num_devices() * (num_devices() - 1) * (2 * num_devices() - 1) / 6; + for (int i = 0; i < num_devices(); ++i) { std::vector replica_results = results[i].DecomposeTuple(); LiteralTestUtil::ExpectR0Equal(expected0, replica_results[0]); LiteralTestUtil::ExpectR0Equal(expected1, replica_results[1]); @@ -645,18 +636,18 @@ XLA_TEST_F(CollectiveOpsTest, ReplicaId) { )"; HloModuleConfig config = - GetModuleConfigForTest(/*replica_count=*/num_devices_); + GetModuleConfigForTest(/*replica_count=*/num_devices()); TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModuleStr)); TF_ASSERT_OK_AND_ASSIGN( std::vector results, ExecuteReplicated(std::move(module), absl::Span{}, - num_devices_, + num_devices(), /*use_threads=*/true, /*run_hlo_passes=*/true)); - ASSERT_EQ(results.size(), num_devices_); - for (uint32_t i = 0; i < num_devices_; ++i) { + ASSERT_EQ(results.size(), num_devices()); + for (uint32_t i = 0; i < num_devices(); ++i) { EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR0(i), results[i])); } } @@ -680,7 +671,7 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectiveBroadcast_Simple)) { } )"; const int64_t kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -716,7 +707,7 @@ XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Simple) { } )"; const int64_t kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -753,7 +744,7 @@ XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Degenerate) { } )"; const int64_t kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -789,7 +780,7 @@ XLA_TEST_F(CollectiveOpsTest, CollectivePermute_NotDegenerate) { } )"; const int64_t kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -826,7 +817,7 @@ XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Rotate) { } )"; const int64_t kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -864,7 +855,7 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AsyncCollectivePermute)) { )"; const int64_t kNumReplicas = 2; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -906,7 +897,7 @@ XLA_TEST_F(CollectiveOpsTest, AllToAll_EmptyReplicaGroups) { } )"; const int64_t kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -952,7 +943,7 @@ XLA_TEST_F(CollectiveOpsTest, AllToAll_OrderedReplicaGroups) { } )"; const int64_t kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -992,7 +983,7 @@ XLA_TEST_F(CollectiveOpsTest, AllToAll_TwoReplicaGroups) { } )"; const int64_t kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -1024,7 +1015,7 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_SplitDimension)) { } )"; const int64_t kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -2003,7 +1994,7 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(SendRecv_Simple)) { )"; const int64_t kNumReplicas = 2; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -2083,7 +2074,7 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(SendRecv_TwoConcurrentChains)) { })"; const int64_t kNumReplicas = 2; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -2162,7 +2153,7 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(SendRecv_ValidationAttr1)) { })"; const int64_t kNumReplicas = 2; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -2263,7 +2254,7 @@ body { })"; const int64_t kNumReplicas = 2; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); diff --git a/xla/tests/collective_pipeline_parallelism_test.cc b/xla/tests/collective_pipeline_parallelism_test.cc index bfcf5e14adf5d1..3db920ad6f75f2 100644 --- a/xla/tests/collective_pipeline_parallelism_test.cc +++ b/xla/tests/collective_pipeline_parallelism_test.cc @@ -32,28 +32,18 @@ limitations under the License. #include "xla/tests/verified_hlo_module.h" #include "tsl/platform/statusor.h" +namespace xla { +namespace { + // Tests cross-GPU operations. // // Several tests requires at least four GPUs. For instructions on running this // within Google, see go/multi-gpu-unit-test. - -// TODO: Move this to hlo_test_base.h -#define SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(x) \ - if (num_devices_ < x) { \ - GTEST_SKIP() << "Test requires at least " << x << " devices"; \ - } - -namespace xla { -namespace { - class CollectivePipelineParallelismTest : public HloTestBase { public: - CollectivePipelineParallelismTest() : num_devices_(backend().device_count()) { - VLOG(1) << "Running with " << num_devices_ << " devices"; + CollectivePipelineParallelismTest() { + VLOG(1) << "Running with " << num_devices() << " devices"; } - - protected: - const int64_t num_devices_; }; XLA_TEST_F(CollectivePipelineParallelismTest, diff --git a/xla/tests/hlo_test_base.h b/xla/tests/hlo_test_base.h index 9858ed6f53997d..9774c06f0d3991 100644 --- a/xla/tests/hlo_test_base.h +++ b/xla/tests/hlo_test_base.h @@ -438,6 +438,7 @@ class HloTestBase : public ManifestCheckingTest { // Returns the backend owned by the test runner. Backend& backend(); + int64_t num_devices() { return backend().device_count(); } HloRunner test_runner_; HloRunner reference_runner_; @@ -513,6 +514,13 @@ class HloTestBase : public ManifestCheckingTest { se::Platform* test_platform); }; +#define SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(x) \ + int64_t num_devices = backend().device_count(); \ + if (num_devices < x) { \ + GTEST_SKIP() << "Test requires at least " << x << " devices (" \ + << num_devices << " available)"; \ + } + } // namespace xla #endif // XLA_TESTS_HLO_TEST_BASE_H_ From ce96385b5afd3605c0803798b9b94669613f4118 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Thu, 1 Aug 2024 07:58:35 -0700 Subject: [PATCH 364/376] Reverts 99353d88f96cb378562e396c9829ceb3032ff848 PiperOrigin-RevId: 658402030 --- xla/backends/interpreter/executor.h | 2 +- xla/stream_executor/cuda/BUILD | 1 - xla/stream_executor/cuda/cuda_executor.cc | 103 ++++-------------- xla/stream_executor/gpu/BUILD | 2 - xla/stream_executor/gpu/gpu_executor.h | 24 ++-- xla/stream_executor/gpu/gpu_executor_test.cc | 29 ----- xla/stream_executor/host/host_executor.h | 2 +- xla/stream_executor/host_memory_allocation.cc | 2 +- .../integrations/device_mem_allocator.h | 2 +- xla/stream_executor/mock_stream_executor.h | 3 +- xla/stream_executor/rocm/rocm_executor.cc | 14 --- xla/stream_executor/stream_executor.h | 2 +- xla/stream_executor/tpu/tpu_executor.h | 2 +- 13 files changed, 44 insertions(+), 144 deletions(-) diff --git a/xla/backends/interpreter/executor.h b/xla/backends/interpreter/executor.h index 0aee389be7bf94..3d2e89dd17aaba 100644 --- a/xla/backends/interpreter/executor.h +++ b/xla/backends/interpreter/executor.h @@ -103,7 +103,7 @@ class XlaInterpreterExecutor : public StreamExecutorCommon { uint64_t size) override { return std::make_unique(new char[size], size, this); } - void HostMemoryDeallocate(void *mem, uint64_t size) override { + void HostMemoryDeallocate(void *mem) override { delete[] static_cast(mem); } diff --git a/xla/stream_executor/cuda/BUILD b/xla/stream_executor/cuda/BUILD index ea7dc5431f865f..48592a1a92656b 100644 --- a/xla/stream_executor/cuda/BUILD +++ b/xla/stream_executor/cuda/BUILD @@ -824,7 +824,6 @@ cuda_only_cc_library( "@tsl//tsl/platform:errors", "@tsl//tsl/platform:fingerprint", "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:platform_port", "@tsl//tsl/platform:statusor", ] + if_cuda_is_configured([":delay_kernel_cuda"]), alwayslink = True, diff --git a/xla/stream_executor/cuda/cuda_executor.cc b/xla/stream_executor/cuda/cuda_executor.cc index b624709c22b2f8..0b90f27b8811d9 100644 --- a/xla/stream_executor/cuda/cuda_executor.cc +++ b/xla/stream_executor/cuda/cuda_executor.cc @@ -86,7 +86,6 @@ limitations under the License. #include "tsl/platform/errors.h" #include "tsl/platform/fingerprint.h" #include "tsl/platform/logging.h" -#include "tsl/platform/numa.h" #include "tsl/platform/statusor.h" // LOG(ERROR) uses a const named ERROR, so a macro with the same name is @@ -154,9 +153,6 @@ GpuExecutor::~GpuExecutor() { } } -static std::optional TryToReadNumaNode(const std::string& pci_bus_id, - int device_ordinal); - absl::Status GpuExecutor::Init() { TF_RETURN_IF_ERROR(GpuDriver::Init()); TF_RETURN_IF_ERROR(GpuDriver::GetDevice(device_ordinal_, &device_)); @@ -164,17 +160,6 @@ absl::Status GpuExecutor::Init() { GpuDriver::CreateContext(device_ordinal_, device_, &context_)); TF_RETURN_IF_ERROR( GpuDriver::GetComputeCapability(&cc_major_, &cc_minor_, device_)); - std::optional numa_node = TryToReadNumaNode( - absl::AsciiStrToLower(GpuDriver::GetPCIBusID(device_ordinal_)), - device_ordinal_); - if (!numa_node || *numa_node < 0) { - LOG(WARNING) << "NUMA node could not be determined for device " - << device_ordinal_ - << ", host memory allocations will not be NUMA-pinned"; - numa_node_ = tsl::port::kNUMANoAffinity; - } else { - numa_node_ = *numa_node; - } return absl::OkStatus(); } @@ -588,47 +573,6 @@ void GpuExecutor::Deallocate(DeviceMemoryBase* mem) { GpuDriver::DeviceDeallocate(context_, mem->opaque()); } -// CUDA allocation/registration functions are necessary because the driver -// internally sets up buffers for DMA operations (and page locks them). There's -// no external interface for us to otherwise control these DMA settings. -absl::StatusOr> -GpuExecutor::HostMemoryAllocate(uint64_t size) { - if (numa_node_ != tsl::port::kNUMANoAffinity) { - auto* buffer = - tsl::port::NUMAMalloc(numa_node_, size, /* minimum_alignment=*/16); - if (buffer == nullptr && size > 0) { - return absl::InternalError(absl::StrFormat( - "Failed to allocate host memory of size %d pinned to NUMA node %d", - size, numa_node_)); - } - if (size > 0 && !GpuDriver::HostRegister(context_, buffer, size)) { - return absl::InternalError( - absl::StrFormat("Failed to register host memory of size %d pinned to " - "NUMA node %d with the GPU driver", - size, numa_node_)); - } - return std::make_unique(buffer, size, this); - } else { - auto* buffer = GpuDriver::HostAllocate(context_, size); - if (buffer == nullptr && size > 0) { - return absl::InternalError( - absl::StrFormat("Failed to allocate HostMemory of size %d", size)); - } - return std::make_unique(buffer, size, this); - } -} - -void GpuExecutor::HostMemoryDeallocate(void* location, uint64_t size) { - if (numa_node_ != tsl::port::kNUMANoAffinity) { - if (size > 0) { - GpuDriver::HostUnregister(context_, location); - } - tsl::port::NUMAFree(location, size); - } else { - GpuDriver::HostDeallocate(context_, location); - } -} - bool GpuExecutor::SynchronizeAllActivity() { return GpuDriver::SynchronizeContext(context_); } @@ -829,22 +773,22 @@ std::unique_ptr GpuExecutor::CreateCommandBuffer( GpuContext* GpuExecutor::gpu_context() { return context_; } // Attempts to read the NUMA node corresponding to the GPU device's PCI bus out -// of SysFS. +// of SysFS. Returns -1 if it cannot. // // For anything more complicated/prod-focused than this, you'll likely want to -// turn to gsys' topology modeling. nvmlDeviceGetMemoryAffinity could also be -// used. -static std::optional TryToReadNumaNode(const std::string& pci_bus_id, - int device_ordinal) { +// turn to gsys' topology modeling. +static int TryToReadNumaNode(const std::string& pci_bus_id, + int device_ordinal) { #if defined(PLATFORM_WINDOWS) // Windows support for NUMA is not currently implemented. Return node 0. return 0; #else VLOG(2) << "trying to read NUMA node for device ordinal: " << device_ordinal; + static const int kUnknownNumaNode = -1; if (pci_bus_id.empty()) { LOG(INFO) << "no PCI bus ID for device ordinal: " << device_ordinal; - return std::nullopt; + return kUnknownNumaNode; } std::string filename = @@ -857,7 +801,7 @@ static std::optional TryToReadNumaNode(const std::string& pci_bus_id, if (file == nullptr) { LOG(INFO) << "could not open file to read NUMA node: " << filename << "\nYour kernel may have been built without NUMA support."; - return std::nullopt; + return kUnknownNumaNode; } std::string content; @@ -868,6 +812,17 @@ static std::optional TryToReadNumaNode(const std::string& pci_bus_id, int32_t value; if (absl::SimpleAtoi(content, &value)) { + if (value < 0) { // See http://b/18228951 for details on this path. + LOG(INFO) << "successful NUMA node read from SysFS had negative value (" + << value + << "), but there must be at least one NUMA node" + ", so returning NUMA node zero." + " See more at " + "https://github.com/torvalds/linux/blob/v6.0/Documentation/" + "ABI/testing/sysfs-bus-pci#L344-L355"; + fclose(file); + return 0; + } fclose(file); return value; } @@ -877,7 +832,7 @@ static std::optional TryToReadNumaNode(const std::string& pci_bus_id, << content; fclose(file); - return std::nullopt; + return kUnknownNumaNode; #endif } @@ -909,24 +864,8 @@ GpuExecutor::CreateDeviceDescription(int device_ordinal) { builder.set_pci_bus_id(pci_bus_id); // Read the NUMA node corresponding to the PCI bus ID out of sysfs. - std::optional numa_node = - TryToReadNumaNode(pci_bus_id, device_ordinal); - if (numa_node.has_value()) { - if (*numa_node < 0) { // See http://b/18228951 for details on this path. - LOG(INFO) - << "successful NUMA node read from SysFS had negative value (" - << *numa_node - << "), but there must be at least one NUMA node" - ", so returning NUMA node zero." - " See more at " - "https://github.com/torvalds/linux/blob/v6.0/Documentation/" - "ABI/testing/sysfs-bus-pci#L344-L355"; - numa_node = 0; - } - } else { - numa_node = -1; - } - builder.set_numa_node(*numa_node); + int numa_node = TryToReadNumaNode(pci_bus_id, device_ordinal); + builder.set_numa_node(numa_node); } { diff --git a/xla/stream_executor/gpu/BUILD b/xla/stream_executor/gpu/BUILD index bee21d83ac8541..403735af9fab19 100644 --- a/xla/stream_executor/gpu/BUILD +++ b/xla/stream_executor/gpu/BUILD @@ -235,7 +235,6 @@ gpu_only_cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", - "@tsl//tsl/platform:platform_port", "@tsl//tsl/platform:thread_annotations", ], ) @@ -799,7 +798,6 @@ xla_test( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", - "@tsl//tsl/platform:platform_port", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", ] + if_cuda([ diff --git a/xla/stream_executor/gpu/gpu_executor.h b/xla/stream_executor/gpu/gpu_executor.h index 116120d1e0e62b..f7dd572e918ccd 100644 --- a/xla/stream_executor/gpu/gpu_executor.h +++ b/xla/stream_executor/gpu/gpu_executor.h @@ -59,7 +59,6 @@ limitations under the License. #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_common.h" -#include "tsl/platform/numa.h" #include "tsl/platform/thread_annotations.h" namespace stream_executor { @@ -113,8 +112,7 @@ class GpuExecutor : public StreamExecutorCommon { device_ordinal_(device_ordinal), cc_major_(0), cc_minor_(0), - version_(0), - numa_node_(tsl::port::kNUMANoAffinity) {} + version_(0) {} // See the corresponding StreamExecutor methods for method comments on the // following overrides. @@ -169,10 +167,23 @@ class GpuExecutor : public StreamExecutorCommon { return GpuCollectives::CollectiveMemoryDeallocate(context_, location); } + // CUDA allocation/registration functions are necessary because the driver + // internally sets up buffers for DMA operations (and page locks them). + // There's no external interface for us to otherwise control these DMA + // settings. absl::StatusOr> HostMemoryAllocate( - uint64_t size) override; + uint64_t size) override { + auto* buffer = GpuDriver::HostAllocate(context_, size); + if (buffer == nullptr && size > 0) { + return absl::InternalError( + absl::StrFormat("Failed to allocate HostMemory of size %d", size)); + } + return std::make_unique(buffer, size, this); + } - void HostMemoryDeallocate(void* location, uint64_t size) override; + void HostMemoryDeallocate(void* location) override { + return GpuDriver::HostDeallocate(context_, location); + } absl::StatusOr GetPointerMemorySpace(const void* ptr) override { return GpuDriver::GetPointerMemorySpace( @@ -369,9 +380,6 @@ class GpuExecutor : public StreamExecutorCommon { // GPU ISA version for device_. int version_; - // NUMA node for device_. - int numa_node_; - // Type erased XLA specific state attached to GpuExecutor. Object xla_state_; diff --git a/xla/stream_executor/gpu/gpu_executor_test.cc b/xla/stream_executor/gpu/gpu_executor_test.cc index 9ac7be1a2c2210..c3c67bc03d8884 100644 --- a/xla/stream_executor/gpu/gpu_executor_test.cc +++ b/xla/stream_executor/gpu/gpu_executor_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" -#include "tsl/platform/numa.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" @@ -55,32 +54,4 @@ TEST_F(GetPointerMemorySpaceTest, Device) { executor->Deallocate(&mem); } -using HostMemoryAllocateTest = GpuExecutorTest; - -TEST_F(HostMemoryAllocateTest, Numa) { - Platform* platform = GetPlatform(); - const uint64_t kSize = 1024; - const int num_devices = platform->VisibleDeviceCount(); - for (int device = 0; device < num_devices; ++device) { - TF_ASSERT_OK_AND_ASSIGN(StreamExecutor * executor, - platform->ExecutorForDevice(device)); - ASSERT_TRUE(executor); - TF_ASSERT_OK_AND_ASSIGN(auto device_desc, - executor->CreateDeviceDescription()); - ASSERT_TRUE(device_desc); - TF_ASSERT_OK_AND_ASSIGN(auto host_ptr, executor->HostMemoryAllocate(kSize)); - ASSERT_TRUE(host_ptr); - EXPECT_NE(host_ptr->opaque(), nullptr); - const int numa_node = tsl::port::NUMAGetMemAffinity(host_ptr->opaque()); - if (numa_node == tsl::port::kNUMANoAffinity) { - // Could be because `executor` could not determine its own NUMA node, in - // which case numa_node() will be -1 or 0, depending on the failure mode. - EXPECT_LE(device_desc->numa_node(), 0); - EXPECT_GE(device_desc->numa_node(), -1); - } else { - EXPECT_EQ(device_desc->numa_node(), numa_node); - } - } -} - } // namespace stream_executor diff --git a/xla/stream_executor/host/host_executor.h b/xla/stream_executor/host/host_executor.h index 5f1c5d00a23463..7ab168d29f9d6d 100644 --- a/xla/stream_executor/host/host_executor.h +++ b/xla/stream_executor/host/host_executor.h @@ -84,7 +84,7 @@ class HostExecutor : public StreamExecutorCommon { uint64_t size) override { return std::make_unique(new char[size], size, this); } - void HostMemoryDeallocate(void* mem, uint64_t size) override { + void HostMemoryDeallocate(void* mem) override { delete[] static_cast(mem); } diff --git a/xla/stream_executor/host_memory_allocation.cc b/xla/stream_executor/host_memory_allocation.cc index 9772396b9cc61e..e77c5e8c69475c 100644 --- a/xla/stream_executor/host_memory_allocation.cc +++ b/xla/stream_executor/host_memory_allocation.cc @@ -27,7 +27,7 @@ HostMemoryAllocation::HostMemoryAllocation(void* ptr, uint64_t size, HostMemoryAllocation::~HostMemoryAllocation() { if (ptr_ != nullptr && executor_ != nullptr) { - executor_->HostMemoryDeallocate(ptr_, size_); + executor_->HostMemoryDeallocate(ptr_); } } diff --git a/xla/stream_executor/integrations/device_mem_allocator.h b/xla/stream_executor/integrations/device_mem_allocator.h index 8b31f8b6e5b291..736b62e051314a 100644 --- a/xla/stream_executor/integrations/device_mem_allocator.h +++ b/xla/stream_executor/integrations/device_mem_allocator.h @@ -82,7 +82,7 @@ class DeviceMemAllocator : public tsl::SubAllocator { auto status = stream_exec_->CollectiveMemoryDeallocate(ptr); CHECK(status.ok()) << status.message(); } else if (memory_type_ == MemoryType::kHost) { - stream_exec_->HostMemoryDeallocate(ptr, num_bytes); + stream_exec_->HostMemoryDeallocate(ptr); } else { DeviceMemoryBase device_ptr(ptr); stream_exec_->Deallocate(&device_ptr); diff --git a/xla/stream_executor/mock_stream_executor.h b/xla/stream_executor/mock_stream_executor.h index 03dd1115f3d6fb..9e4cdc08fcf62f 100644 --- a/xla/stream_executor/mock_stream_executor.h +++ b/xla/stream_executor/mock_stream_executor.h @@ -89,8 +89,7 @@ class MockStreamExecutor : public StreamExecutor { (override)); MOCK_METHOD(absl::StatusOr>, HostMemoryAllocate, (uint64_t size), (override)); - MOCK_METHOD(void, HostMemoryDeallocate, (void* mem, uint64_t size), - (override)); + MOCK_METHOD(void, HostMemoryDeallocate, (void* mem), (override)); MOCK_METHOD(bool, SynchronizeAllActivity, (), (override)); MOCK_METHOD(absl::Status, SynchronousMemZero, (DeviceMemoryBase * location, uint64_t size), (override)); diff --git a/xla/stream_executor/rocm/rocm_executor.cc b/xla/stream_executor/rocm/rocm_executor.cc index cb096c6485e10d..cf9fe323a9c939 100644 --- a/xla/stream_executor/rocm/rocm_executor.cc +++ b/xla/stream_executor/rocm/rocm_executor.cc @@ -447,20 +447,6 @@ void GpuExecutor::Deallocate(DeviceMemoryBase* mem) { GpuDriver::DeviceDeallocate(context_, mem->opaque()); } -absl::StatusOr> -GpuExecutor::HostMemoryAllocate(uint64_t size) { - auto* buffer = GpuDriver::HostAllocate(context_, size); - if (buffer == nullptr && size > 0) { - return absl::InternalError( - absl::StrFormat("Failed to allocate HostMemory of size %d", size)); - } - return std::make_unique(buffer, size, this); -} - -void GpuExecutor::HostMemoryDeallocate(void* location, uint64_t size) { - return GpuDriver::HostDeallocate(context_, location); -} - bool GpuExecutor::SynchronizeAllActivity() { return GpuDriver::SynchronizeContext(context_); } diff --git a/xla/stream_executor/stream_executor.h b/xla/stream_executor/stream_executor.h index 49929d4ce34c11..f5fb436dfd2274 100644 --- a/xla/stream_executor/stream_executor.h +++ b/xla/stream_executor/stream_executor.h @@ -200,7 +200,7 @@ class StreamExecutor { uint64_t size) = 0; // Deallocates a region of host memory allocated by HostMemoryAllocate(). - virtual void HostMemoryDeallocate(void* mem, uint64_t size) = 0; + virtual void HostMemoryDeallocate(void* mem) = 0; // Returns the memory space of the given pointer. virtual absl::StatusOr GetPointerMemorySpace(const void* ptr) { diff --git a/xla/stream_executor/tpu/tpu_executor.h b/xla/stream_executor/tpu/tpu_executor.h index c969c6d5d59d51..85646afbb68762 100644 --- a/xla/stream_executor/tpu/tpu_executor.h +++ b/xla/stream_executor/tpu/tpu_executor.h @@ -137,7 +137,7 @@ class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface { uint64_t size) override { LOG(FATAL) << "not yet implemented"; } - void HostMemoryDeallocate(void* mem, uint64_t size) override { + void HostMemoryDeallocate(void* mem) override { LOG(FATAL) << "not yet implemented"; } absl::Status SynchronousMemZero(DeviceMemoryBase* location, From 92c41f5548e2dcea4858d3011bfd7c221508d896 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 1 Aug 2024 08:13:37 -0700 Subject: [PATCH 365/376] [xla:cpu] Compute ThunkExecutor node priorities to guide nodes execution It might make sense to traverse ready nodes according to their priority to create as much work as possible for concurrent thunk executor. In preparation for the change add a simple priority computed as the number of reachable nodes. PiperOrigin-RevId: 658406518 --- xla/service/cpu/runtime/thunk_executor.cc | 24 +++++++++++-------- xla/service/cpu/runtime/thunk_executor.h | 7 +++--- .../cpu/runtime/thunk_executor_test.cc | 15 ++++++++++++ 3 files changed, 33 insertions(+), 13 deletions(-) diff --git a/xla/service/cpu/runtime/thunk_executor.cc b/xla/service/cpu/runtime/thunk_executor.cc index 805840ad855e93..bb2474535d0c9b 100644 --- a/xla/service/cpu/runtime/thunk_executor.cc +++ b/xla/service/cpu/runtime/thunk_executor.cc @@ -61,7 +61,7 @@ ThunkExecutor::ThunkExecutor(ThunkSequence thunk_sequence, } // Erase redundant edges between nodes. - int64_t num_erased_edges = TransitiveReduction(); + int64_t num_erased_edges = RunTransitiveReductionAndUpdatePriorities(); // Check if constructed execution DAG is sequential: every node depends on the // completion of the previous node. @@ -431,7 +431,7 @@ static int64_t EraseEdge(ThunkExecutor::NodeDef& from, return 0; } -int64_t ThunkExecutor::TransitiveReduction() { +int64_t ThunkExecutor::RunTransitiveReductionAndUpdatePriorities() { int64_t num_erased_edges = 0; // Keep workspace for DFS traversal between iterations. @@ -454,11 +454,11 @@ int64_t ThunkExecutor::TransitiveReduction() { stack.clear(); visited.assign(nodes_defs_.size(), false); - // Initialize stack with nodes reachable via immediate out nodes. We don't - // need to add source node and immediate out nodes to the visited set - // because graph is acyclic and we don't visit them again. + // Initialize stack with nodes reachable via immediate out nodes. We mark + // immediate out nodes as visited to correctly compute node priority below. for (int64_t out_id : source_node.out_edges) { NodeDef& out_node = nodes_defs_[out_id]; + visited[out_id] = true; for (int64_t start_id : out_node.out_edges) add_to_stack(start_id); } @@ -472,6 +472,9 @@ int64_t ThunkExecutor::TransitiveReduction() { for (int64_t out_id : node.out_edges) add_to_stack(out_id); } + + // Set node priority to the number of visited nodes in the DFS traversal. + source_node.priority = absl::c_count(visited, true); } return num_erased_edges; @@ -495,11 +498,12 @@ std::string ThunkExecutor::ToString() const { const Thunk& thunk = *thunk_sequence_[i]; bool is_source = absl::c_find(source_, i) != source_.end(); bool is_sink = absl::c_find(sink_, i) != sink_.end(); - absl::StrAppendFormat( - &str, - "\n thunk #%05d: op_name=%s, dependencies=[%s], source=%v, sink=%v", i, - thunk.info().op_name, absl::StrJoin(in_edges[i], ", "), is_source, - is_sink); + absl::StrAppendFormat(&str, + "\n thunk #%05d: op_name=%s, dependencies=[%s], " + "source=%v, sink=%v, priority=%d", + i, thunk.info().op_name, + absl::StrJoin(in_edges[i], ", "), is_source, is_sink, + nodes_defs_[i].priority); } return str; diff --git a/xla/service/cpu/runtime/thunk_executor.h b/xla/service/cpu/runtime/thunk_executor.h index 10df02c45a9383..539018e9e9db79 100644 --- a/xla/service/cpu/runtime/thunk_executor.h +++ b/xla/service/cpu/runtime/thunk_executor.h @@ -72,6 +72,7 @@ class ThunkExecutor { // NodeDef defines an execution order for all thunks in a sequence. struct NodeDef { NodeId id = kInvalidNodeId; + int64_t priority = 0; std::vector in_edges; std::vector out_edges; }; @@ -177,11 +178,11 @@ class ThunkExecutor { tsl::AsyncValuePtr node_event, ExecuteState::Node& node, ReadyQueue& ready_queue); - // Runs a transitive reduction on the NodeDef graph to remove redundant edges. - // Returns the number of removed edges. + // Runs a transitive reduction on the NodeDef graph to remove redundant edges, + // and updates nodes priorities. Returns the number of removed edges. // // See: https://en.wikipedia.org/wiki/Transitive_reduction - int64_t TransitiveReduction(); + int64_t RunTransitiveReductionAndUpdatePriorities(); ThunkSequence thunk_sequence_; Options options_; diff --git a/xla/service/cpu/runtime/thunk_executor_test.cc b/xla/service/cpu/runtime/thunk_executor_test.cc index 2bbb932a4a432c..78f4f7fa868a5d 100644 --- a/xla/service/cpu/runtime/thunk_executor_test.cc +++ b/xla/service/cpu/runtime/thunk_executor_test.cc @@ -237,6 +237,10 @@ TEST(ThunkExecutorTest, DependencyOrdering) { EXPECT_FALSE(executor.is_sequential()); EXPECT_THAT(executor.source(), ElementsAre(0, 1)); EXPECT_THAT(executor.sink(), ElementsAre(2)); + + EXPECT_EQ(executor.node_def(0).priority, 1); + EXPECT_EQ(executor.node_def(1).priority, 1); + EXPECT_EQ(executor.node_def(2).priority, 0); } TEST(ThunkExecutorTest, SequentialOrdering) { @@ -255,6 +259,10 @@ TEST(ThunkExecutorTest, SequentialOrdering) { EXPECT_TRUE(executor.is_sequential()); EXPECT_THAT(executor.source(), ElementsAre(0)); EXPECT_THAT(executor.sink(), ElementsAre(2)); + + EXPECT_EQ(executor.node_def(0).priority, 2); + EXPECT_EQ(executor.node_def(1).priority, 1); + EXPECT_EQ(executor.node_def(2).priority, 0); } TEST(ThunkExecutorTest, ResourceOrdering) { @@ -278,6 +286,9 @@ TEST(ThunkExecutorTest, ResourceOrdering) { EXPECT_TRUE(executor.is_sequential()); EXPECT_THAT(executor.source(), ElementsAre(0)); EXPECT_THAT(executor.sink(), ElementsAre(1)); + + EXPECT_EQ(executor.node_def(0).priority, 1); + EXPECT_EQ(executor.node_def(1).priority, 0); } TEST(ThunkExecutorTest, TransitiveReduction) { @@ -300,6 +311,10 @@ TEST(ThunkExecutorTest, TransitiveReduction) { EXPECT_THAT(executor.node_def(1).in_edges, ElementsAre(0)); EXPECT_THAT(executor.node_def(1).out_edges, ElementsAre(2)); EXPECT_THAT(executor.node_def(2).in_edges, ElementsAre(1)); + + EXPECT_EQ(executor.node_def(0).priority, 2); + EXPECT_EQ(executor.node_def(1).priority, 1); + EXPECT_EQ(executor.node_def(2).priority, 0); } TEST(ThunkExecutorTest, Execute) { From 66caede9b97982e6f3d35f2910c2f70232173876 Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Thu, 1 Aug 2024 08:33:54 -0700 Subject: [PATCH 366/376] [XLA:GPU] Enable pipeline parallelism test on TAP. PiperOrigin-RevId: 658412026 --- xla/tests/BUILD | 5 ----- 1 file changed, 5 deletions(-) diff --git a/xla/tests/BUILD b/xla/tests/BUILD index f725acf5791289..da3554e6a0ad44 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -2311,14 +2311,9 @@ xla_test( srcs = ["collective_pipeline_parallelism_test.cc"], args = ["--xla_force_host_platform_device_count=4"], backend_tags = { - # This test is tagged "manual" because it requires multiple GPUs, and Forge only supports - # single-GPU tests. Guitar skips "manual" tests unless they're also tagged "guitar". "gpu": [ - "guitar", - "manual", "multi_gpu", "no_oss", - "notap", ], "cpu": [ "notsan", From fc7da4035ebbee71049dd8830366a5c45064074c Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Thu, 1 Aug 2024 08:36:08 -0700 Subject: [PATCH 367/376] [XLA:GPU] Skip tests if too few devices Skip tests if too few devices for multi-GPU tests in collective_ops_e2e_test and replicated_io_feed_test. PiperOrigin-RevId: 658412608 --- xla/tests/collective_ops_e2e_test.cc | 15 +++++++++++++++ xla/tests/replicated_io_feed_test.cc | 3 +++ 2 files changed, 18 insertions(+) diff --git a/xla/tests/collective_ops_e2e_test.cc b/xla/tests/collective_ops_e2e_test.cc index f1d1c78d28bb61..99085382d07e8d 100644 --- a/xla/tests/collective_ops_e2e_test.cc +++ b/xla/tests/collective_ops_e2e_test.cc @@ -154,6 +154,7 @@ XLA_TEST_P(AsyncCollectiveOps, AsyncAllReduce) { )"; const int64_t kNumReplicas = 2; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); const bool enable_async_all_reduce = GetParam(); TF_ASSERT_OK_AND_ASSIGN(auto executable, CreateExecutable(kModuleStr, kNumReplicas)); @@ -190,6 +191,7 @@ XLA_TEST_P(AsyncCollectiveOps, AsyncAllGather) { } )"; const int64_t kNumReplicas = 2; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); const bool enable_async_all_gather = GetParam(); TF_ASSERT_OK_AND_ASSIGN(auto executable, @@ -231,6 +233,7 @@ XLA_TEST_P(AsyncCollectiveOps, AsyncAllGatherMixedTypes) { } )"; const int64_t kNumReplicas = 2; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); const bool enable_async_all_gather = GetParam(); TF_ASSERT_OK_AND_ASSIGN(auto executable, @@ -268,6 +271,7 @@ XLA_TEST_P(AsyncCollectiveOps, AsyncCollectiveBroadcast) { } )"; const int64_t kNumReplicas = 2; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); const bool enable_async_collective_broadcast = GetParam(); TF_ASSERT_OK_AND_ASSIGN(auto executable, CreateExecutable(kModuleStr, kNumReplicas)); @@ -300,6 +304,7 @@ XLA_TEST_P(AsyncCollectiveOps, AsyncCollectivePermute) { } )"; const int64_t kNumReplicas = 2; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); const bool enable_async_collective_permute = GetParam(); TF_ASSERT_OK_AND_ASSIGN(auto executable, CreateExecutable(kModuleStr, kNumReplicas)); @@ -343,6 +348,7 @@ XLA_TEST_P(AsyncCollectiveOps, AsyncReduceScatter) { )"; const int64_t kNumReplicas = 2; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); const bool enable_async_reduce_scatter = GetParam(); TF_ASSERT_OK_AND_ASSIGN(auto executable, CreateExecutable(kModuleStr, kNumReplicas)); @@ -376,6 +382,7 @@ XLA_TEST_P(AsyncCollectiveOps, AsyncAllToAllWithSplitDim) { } )"; const int64_t kNumReplicas = 2; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); const bool enable_async_all_to_all = GetParam(); TF_ASSERT_OK_AND_ASSIGN(auto executable, CreateExecutable(kModuleStr, kNumReplicas)); @@ -420,6 +427,7 @@ XLA_TEST_P(AsyncCollectiveOps, AsyncAllToAllWithoutSplitDim) { } )"; const int64_t kNumReplicas = 2; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); const bool enable_async_all_to_all = GetParam(); TF_ASSERT_OK_AND_ASSIGN(auto executable, CreateExecutable(kModuleStr, kNumReplicas)); @@ -472,6 +480,7 @@ TEST_P(AsyncCollectiveOps, MatmulReplicated) { } )"; const int64_t kNumReplicas = 4; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -592,6 +601,7 @@ TEST_F(CollectiveOpsTestE2E, WhileLoopReduceScatterCodeMotion) { )"; const int64_t kNumReplicas = 2; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); DebugOptions debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_enable_while_loop_reduce_scatter_code_motion(true); @@ -646,6 +656,7 @@ TEST_F(CollectiveOpsTestE2E, NoAllToAllDecomposition) { } )"; const int64_t kNumReplicas = 2; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -677,6 +688,7 @@ class CollectiveOpsTestE2EWindowedNonWindowed : public CollectiveOpsTestE2E { absl::string_view hlo_text, bool disable_dot_merger = false) { const int64_t kNumReplicas = 1; const int64_t kNumPartitions = 4; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas * kNumPartitions); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -959,6 +971,7 @@ ENTRY entry { )"; const int64_t kNumReplicas = 1; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); const int64_t kNumPartitions = 4; HloModuleConfig config = @@ -1052,6 +1065,7 @@ ENTRY entry { )"; const int64_t kNumReplicas = 1; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -1085,6 +1099,7 @@ ENTRY entry { )"; const int64_t kNumReplicas = 1; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); const int64_t kNumPartitions = 4; HloModuleConfig config = diff --git a/xla/tests/replicated_io_feed_test.cc b/xla/tests/replicated_io_feed_test.cc index 9ee34a7a17da8d..0164f8b6b30e69 100644 --- a/xla/tests/replicated_io_feed_test.cc +++ b/xla/tests/replicated_io_feed_test.cc @@ -50,7 +50,10 @@ XLA_TEST_F(ReplicatedIOFeedTest, InfeedAndOutfeed) { result = u32[] add(infeed.data, replica_id) outfeed = token[] outfeed(result, infeed.token), outfeed_shape=u32[] })"; + const int kNumReplicas = 4; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + auto config = GetModuleConfigForTest(); config.set_replica_count(kNumReplicas); std::unique_ptr module = From 280ff87738799281f62af010f3eae30b49ebf522 Mon Sep 17 00:00:00 2001 From: Zixuan Jiang Date: Thu, 1 Aug 2024 08:55:56 -0700 Subject: [PATCH 368/376] Ensure that `dims_mapping.conv_spatial_dims.empty()` before calling `InferDotOperandSharding`. This change is a fix for cl/658203651 (https://github.com/openxla/xla/commit/f9011c1dbee841dbbff73e189614111db692973f). In cl/658203651, we reshard dot/conv operand into its expected sharding by calling `InferDotOperandSharding`. We need to ensure that `dims_mapping.conv_spatial_dims.empty()` before calling `InferDotOperandSharding`. If `dims_mapping.conv_spatial_dims.empty()`, we should directly proceed with the last resort (fully replicate the operand). PiperOrigin-RevId: 658418417 --- xla/service/spmd/dot_handler.cc | 9 +++++---- xla/service/spmd/spmd_partitioner_test.cc | 23 +++++++++++++++++++++++ 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/xla/service/spmd/dot_handler.cc b/xla/service/spmd/dot_handler.cc index 16ace2aa39e86c..22f88cf0dad143 100644 --- a/xla/service/spmd/dot_handler.cc +++ b/xla/service/spmd/dot_handler.cc @@ -2562,7 +2562,8 @@ absl::StatusOr PartitionDotGroupOnNonContractingImpl( }; std::optional other_grouped = try_sharding_for_other_operand(other.sharding()); - if (!other_grouped && !other.sharding().IsReplicated()) { + if (!other_grouped && !other.sharding().IsReplicated() && + dims_mapping.conv_spatial_dims.empty()) { const HloSharding expected_other_sharding = hlo_sharding_util::InferDotOperandSharding( &output_sharding, &matching.sharding(), lhs_matching ? 1 : 0, @@ -2570,9 +2571,9 @@ absl::StatusOr PartitionDotGroupOnNonContractingImpl( // Try the expected sharding since it is no worse than the last resort // (replicated sharding). other_grouped = try_sharding_for_other_operand(expected_other_sharding); - if (!other_grouped) { - other = other.Replicate(); - } + } + if (!other_grouped) { + other = other.Replicate(); } matching = matching.Reshard(UngroupSharding(matching_grouped)); diff --git a/xla/service/spmd/spmd_partitioner_test.cc b/xla/service/spmd/spmd_partitioner_test.cc index e24e32f7899866..a52f68da226328 100644 --- a/xla/service/spmd/spmd_partitioner_test.cc +++ b/xla/service/spmd/spmd_partitioner_test.cc @@ -9163,6 +9163,29 @@ ENTRY main { EXPECT_THAT(root, op::AllReduce(dot)); } +TEST_P(SpmdPartitioningTest, ReplicateLHSofConv) { + const char* const hlo_string = R"( +HloModule module +ENTRY main { + lhs = bf16[128,8,8,1280] parameter(0), sharding={devices=[128,1,1,1]<=[128]} + rhs = bf16[3,3,1280,1280] parameter(1), sharding={devices=[1,1,1,8,16]<=[16,8]T(1,0) last_tile_dim_replicate} + ROOT conv = bf16[128,8,8,1280] convolution(lhs, rhs), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, sharding={devices=[1,1,1,8,16]<=[16,8]T(1,0) last_tile_dim_replicate} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, PartitionComputation(hlo_string, /*num_devices=*/128)); + VLOG(1) << module->ToString(); + + const auto lhs = AllOf(op::Shape("bf16[128,8,8,1280]"), + op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(), op::Parameter(0), _, _, _, _))); + const auto rhs = AllOf(op::Shape("bf16[3,3,1280,160]"), op::Parameter(1)); + const auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + AllOf(op::Shape("bf16[128,8,8,160]"), op::Convolution(lhs, rhs))); +} + TEST_P(SpmdPartitioningTest, ElementwiseTest_SubgroupSharding_TileToReplicate) { absl::string_view hlo_string = R"( HloModule module From adc78cb9be78b32f89f0b4251f52e5e5aa8247ad Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Thu, 1 Aug 2024 09:22:36 -0700 Subject: [PATCH 369/376] [XLA] Add a utility function to sort json strings. More specifically, this sorts the fields in json objects, first by key name and then by the string encoding of the value. This is not meant to be used for general JSON, but rather specifically to sort BackEnd Config Json strings. These need to be canonical because they are used as part of autotuning keys. PiperOrigin-RevId: 658426795 --- xla/BUILD | 27 +++++ xla/sort_json.cc | 257 ++++++++++++++++++++++++++++++++++++++++++ xla/sort_json.h | 35 ++++++ xla/sort_json_test.cc | 51 +++++++++ 4 files changed, 370 insertions(+) create mode 100644 xla/sort_json.cc create mode 100644 xla/sort_json.h create mode 100644 xla/sort_json_test.cc diff --git a/xla/BUILD b/xla/BUILD index 3b58f29ad59b27..2c159ed727e3ad 100644 --- a/xla/BUILD +++ b/xla/BUILD @@ -1312,6 +1312,33 @@ cc_library( visibility = ["//visibility:public"], ) +cc_library( + name = "sort_json", + srcs = ["sort_json.cc"], + hdrs = ["sort_json.h"], + visibility = ["//visibility:public"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "sort_json_test", + srcs = ["sort_json_test.cc"], + deps = [ + ":sort_json", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:status_matchers", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ], +) + # Needed to workaround https://github.com/bazelbuild/bazel/issues/21519 alias( name = "bazel_issue_21519", diff --git a/xla/sort_json.cc b/xla/sort_json.cc new file mode 100644 index 00000000000000..aaa1e197a3fa26 --- /dev/null +++ b/xla/sort_json.cc @@ -0,0 +1,257 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/sort_json.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace { + +void SkipWhitespace(absl::string_view json, size_t& index) { + while (index < json.size() && std::isspace(json[index])) { + ++index; + } +} + +absl::Status CheckNotEndOfString(absl::string_view json, int index, + absl::string_view expected) { + return index < json.size() + ? absl::OkStatus() + : absl::InvalidArgumentError(absl::StrCat( + "Prematurely reached end of JSON while looking for ", + expected, ".")); +} + +absl::Status Consume(absl::string_view json, size_t& index, char c, + bool optional = false) { + SkipWhitespace(json, index); + TF_RETURN_IF_ERROR(CheckNotEndOfString(json, index, std::string(1, c))); + if (json[index] == c) { + ++index; + SkipWhitespace(json, index); + } else if (!optional) { + return absl::InvalidArgumentError( + absl::StrCat("Expected '", std::string(1, c), "', but found '", + std::string(1, json[index]), "'.")); + } + return absl::OkStatus(); +} + +struct JsonArray; +struct JsonObject; + +using JsonValue = std::variant, + std::unique_ptr>; + +struct JsonField { + absl::string_view name; + JsonValue value; +}; + +template +struct JsonSequence { + std::vector elements; +}; + +struct JsonArray : public JsonSequence {}; +struct JsonObject : public JsonSequence {}; + +// This parses either an array or an object. +template +absl::StatusOr> ParseSequence(absl::string_view outer_json, + size_t& index, + ElemFn elem_fn) { + TF_RETURN_IF_ERROR(Consume(outer_json, index, begin)); + TF_RETURN_IF_ERROR(CheckNotEndOfString(outer_json, index, name)); + + auto seq = std::make_unique(); + while (outer_json[index] != end) { + TF_ASSIGN_OR_RETURN(auto elem, elem_fn(outer_json, index)); + seq->elements.emplace_back(std::move(elem)); + TF_RETURN_IF_ERROR(Consume(outer_json, index, ',', /*optional=*/true)); + TF_RETURN_IF_ERROR(CheckNotEndOfString(outer_json, index, name)); + } + TF_RETURN_IF_ERROR(Consume(outer_json, index, end)); + return seq; +} + +absl::Status EnsureValidLiteralStart(char c) { + if (c != '"' && c != '+' && c != '-' && c != 'f' && c != 't' && c != 'n' && + (c < '0' || c > '9')) { + return absl::InvalidArgumentError(absl::StrCat( + "Invalid first character of literal: '", std::string(1, c), "'.")); + } + return absl::OkStatus(); +} + +bool HandleEscape(absl::string_view outer_json, size_t& index, + bool& is_escaped) { + if (is_escaped) { + is_escaped = false; + ++index; + return true; + } + + if (outer_json[index] == '\\') { + is_escaped = true; + ++index; + return true; + } + return false; +} + +bool LiteralIsFinished(absl::string_view outer_json, size_t& index, + bool is_string_literal) { + char c = outer_json[index]; + if (is_string_literal) { + index += (c == '"' ? 1 : 0); + return c == '"'; + } + + return std::isspace(c) || c == ',' || c == '{' || c == '}' || c == '[' || + c == ']' || c == ':'; +} + +absl::StatusOr ParseLiteral(absl::string_view outer_json, + size_t& index) { + SkipWhitespace(outer_json, index); + TF_RETURN_IF_ERROR(CheckNotEndOfString(outer_json, index, "literal")); + + auto c = outer_json[index]; + TF_RETURN_IF_ERROR(EnsureValidLiteralStart(c)); + bool is_string_literal = c == '"'; + size_t start_index = index; + bool is_escaped = false; + ++index; + + while (index < outer_json.size()) { + if (HandleEscape(outer_json, index, is_escaped)) { + continue; + } + if (LiteralIsFinished(outer_json, index, is_string_literal)) { + break; + } + ++index; + } + return outer_json.substr(start_index, index - start_index); +} + +absl::StatusOr ParseField(absl::string_view outer_json, + size_t& index); + +absl::StatusOr ParseValue(absl::string_view outer_json, + size_t& index) { + JsonValue value; + SkipWhitespace(outer_json, index); + TF_RETURN_IF_ERROR(CheckNotEndOfString(outer_json, index, "value")); + auto c = outer_json[index]; + if (c == '{') { + constexpr static char kObject[] = "object"; + auto seq = ParseSequence(outer_json, index, + ParseField); + TF_ASSIGN_OR_RETURN(value, std::move(seq)); + } else if (c == '[') { + constexpr static char kArray[] = "array"; + auto seq = ParseSequence(outer_json, index, + ParseValue); + TF_ASSIGN_OR_RETURN(value, std::move(seq)); + } else { + TF_ASSIGN_OR_RETURN(value, ParseLiteral(outer_json, index)); + } + return value; +} + +absl::StatusOr ParseField(absl::string_view outer_json, + size_t& index) { + JsonField field; + TF_ASSIGN_OR_RETURN(field.name, ParseLiteral(outer_json, index)); + TF_RETURN_IF_ERROR(Consume(outer_json, index, ':')); + TF_ASSIGN_OR_RETURN(field.value, ParseValue(outer_json, index)); + return field; +} + +template +std::vector SerializedElements(const JsonSequence& seq) { + std::vector result; + for (const auto& field : seq.elements) { + result.push_back(""); + Serialize(field, result.back()); + } + return result; +} + +template +void Serialize(const JsonSequence& object, std::string& result) { + auto elems = SerializedElements(object); + if constexpr (std::is_same_v) { + std::sort(elems.begin(), elems.end()); + } + + result += begin_brace; + bool has_preceeding = false; + for (const auto& elem : elems) { + if (has_preceeding) { + result += ','; + } + result += elem; + has_preceeding = true; + } + result += end_brace; +} + +void Serialize(const JsonValue& value, std::string& result) { + if (auto* lit = std::get_if(&value)) { + absl::StrAppend(&result, *lit); + } else if (auto* object = std::get_if>(&value)) { + Serialize(**object, result); + } else if (auto* array = std::get_if>(&value)) { + Serialize(**array, result); + } +} + +void Serialize(const JsonField& field, std::string& result) { + absl::StrAppend(&result, field.name, ":"); + Serialize(field.value, result); +} + +} // namespace + +namespace xla { +absl::StatusOr SortJson(absl::string_view json) { + size_t index = 0; + TF_ASSIGN_OR_RETURN(auto value, ParseValue(json, index)); + SkipWhitespace(json, index); + if (index < json.size()) { + return absl::InvalidArgumentError("Found trailing characters in JSON."); + } + std::string result; + Serialize(value, result); + return result; +} +} // namespace xla diff --git a/xla/sort_json.h b/xla/sort_json.h new file mode 100644 index 00000000000000..b4283f556500ce --- /dev/null +++ b/xla/sort_json.h @@ -0,0 +1,35 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SORT_JSON_H_ +#define XLA_SORT_JSON_H_ + +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" + +namespace xla { + +// Sorts the given JSON string or returns an error if the JSON could not be +// parsed. Note that this function expects the input JSON to be valid and not +// all forms of invalid JSON are correctly recognized. This function completely +// ignores whitespace and the resulting JSON does not have any whitespace. +// Comments are not supported in the input JSON. +absl::StatusOr SortJson(absl::string_view json); + +} // namespace xla + +#endif // XLA_SORT_JSON_H_ diff --git a/xla/sort_json_test.cc b/xla/sort_json_test.cc new file mode 100644 index 00000000000000..f4ff0c1d785bc1 --- /dev/null +++ b/xla/sort_json_test.cc @@ -0,0 +1,51 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/sort_json.h" + +#include +#include +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace { + +using ::tsl::testing::IsOkAndHolds; + +TEST(SortJsonTest, SortsJson) { + EXPECT_THAT(SortJson(R"({"a": 1, "c": 3,"b": 2, "b": 1,})"), + IsOkAndHolds(R"({"a":1,"b":1,"b":2,"c":3})")); + + EXPECT_THAT(SortJson(R"({"a": 1 , "c": 1,"b": 1 })"), + IsOkAndHolds(R"({"a":1,"b":1,"c":1})")); + + EXPECT_THAT(SortJson(R"({"a": 1,"c": 3,"b": 2,"b": [3,2,1],})"), + IsOkAndHolds(R"({"a":1,"b":2,"b":[3,2,1],"c":3})")); + + EXPECT_THAT(SortJson(R"({"aa": 1, "a": {"c": "c", "b": "b"}})"), + IsOkAndHolds(R"({"a":{"b":"b","c":"c"},"aa":1})")); + + EXPECT_THAT( + SortJson( + R"({"x": true, "x": false, "x": null, "x": 0, "x": -0.5,"x": "a"})"), + IsOkAndHolds(R"({"x":"a","x":-0.5,"x":0,"x":false,"x":null,"x":true})")); + + EXPECT_THAT(SortJson(R"({"a": "a}", "a": "a"})"), + IsOkAndHolds(R"({"a":"a","a":"a}"})")); +} + +} // namespace +} // namespace xla From a865261e416cdc83ac9629049ca686f9dd244b30 Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Thu, 1 Aug 2024 09:25:19 -0700 Subject: [PATCH 370/376] [XLA:GPU] Enable more multi-GPU tests on TAP PiperOrigin-RevId: 658427789 --- xla/tests/BUILD | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/xla/tests/BUILD b/xla/tests/BUILD index da3554e6a0ad44..f71b3ab06d1841 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -2352,15 +2352,9 @@ xla_test( name = "collective_ops_e2e_test", srcs = ["collective_ops_e2e_test.cc"], backend_tags = { - # This test is tagged "manual" because it requires multiple GPUs, and - # Forge only supports single-GPU tests. Guitar skips "manual" tests - # unless they're also tagged "guitar". "gpu": [ - "guitar", - "manual", "multi_gpu", "no_oss", - "notap", ], }, backends = [ @@ -2404,15 +2398,9 @@ xla_test( name = "replicated_io_feed_test", srcs = ["replicated_io_feed_test.cc"], backend_tags = { - # This test is tagged "manual" because it requires multiple GPUs, and - # Forge only supports single-GPU tests. Guitar skips "manual" tests - # unless they're also tagged "guitar". "gpu": [ - "guitar", - "manual", "multi_gpu", "no_oss", - "notap", ], }, backends = ["gpu"], From 341366e63dcefe86264ade207743b9b3fa3b6169 Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Thu, 1 Aug 2024 09:28:49 -0700 Subject: [PATCH 371/376] [XLA:GPU] Add forgotten template arg PiperOrigin-RevId: 658429198 --- xla/tests/collective_pipeline_parallelism_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xla/tests/collective_pipeline_parallelism_test.cc b/xla/tests/collective_pipeline_parallelism_test.cc index 3db920ad6f75f2..e703d019199694 100644 --- a/xla/tests/collective_pipeline_parallelism_test.cc +++ b/xla/tests/collective_pipeline_parallelism_test.cc @@ -250,7 +250,7 @@ XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch4Replica4) { // Check pipeline output for last replica. // The combined effect of the pipeline is to scale the input data by 24.0. const float kExpectedFactor = 1.0 * 2.0 * 3.0 * 4.0; - Literal expected_output = LiteralUtil::CreateFingerprintMatixR2( + Literal expected_output = LiteralUtil::CreateFingerprintMatixR2( kMicrobatches, kInputSize, kExpectedFactor); EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected_output, results[3], ErrorSpec{1e-5, 1e-5})); From 68f8065af1e02b76e0fd35b3041ec303bba3fe8c Mon Sep 17 00:00:00 2001 From: Dmitri Gribenko Date: Thu, 1 Aug 2024 09:36:44 -0700 Subject: [PATCH 372/376] Integrate LLVM at llvm/llvm-project@e1451236a0a0 Updates LLVM usage to match [e1451236a0a0](https://github.com/llvm/llvm-project/commit/e1451236a0a0) PiperOrigin-RevId: 658431851 --- third_party/llvm/workspace.bzl | 4 ++-- third_party/shardy/temporary.patch | 10 +++++----- third_party/shardy/workspace.bzl | 4 ++-- third_party/tsl/third_party/llvm/workspace.bzl | 4 ++-- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index e0f0b45d1f46d4..bb32ac1a48e349 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "17ba4f4053e303be3e5408d34eaf687a49cefb06" - LLVM_SHA256 = "64c334c15f058ca090fe8eb1e2cd99fdcbaaaf6e57202760f32574d3a9d24d78" + LLVM_COMMIT = "e1451236a0a07f1ee4ba5fe3ae2464a82a37c25c" + LLVM_SHA256 = "22ade2cbcd9df84196461948456b6c965b536b8c25de92a9dc546534f550b63d" tf_http_archive( name = name, diff --git a/third_party/shardy/temporary.patch b/third_party/shardy/temporary.patch index aa892f8ca2205b..01b2cd6c7935b3 100644 --- a/third_party/shardy/temporary.patch +++ b/third_party/shardy/temporary.patch @@ -1,15 +1,15 @@ diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index 6429d9b..e0f0b45 100644 +index e0f0b45..bb32ac1 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "42d641ef5cc4bd82f98ef9959a593ca6db66d75d" -- LLVM_SHA256 = "ec368e9c3b1e1c5eb646c21da65bb54a53060b417e61f2451f3917b35d743abd" -+ LLVM_COMMIT = "17ba4f4053e303be3e5408d34eaf687a49cefb06" -+ LLVM_SHA256 = "64c334c15f058ca090fe8eb1e2cd99fdcbaaaf6e57202760f32574d3a9d24d78" +- LLVM_COMMIT = "17ba4f4053e303be3e5408d34eaf687a49cefb06" +- LLVM_SHA256 = "64c334c15f058ca090fe8eb1e2cd99fdcbaaaf6e57202760f32574d3a9d24d78" ++ LLVM_COMMIT = "e1451236a0a07f1ee4ba5fe3ae2464a82a37c25c" ++ LLVM_SHA256 = "22ade2cbcd9df84196461948456b6c965b536b8c25de92a9dc546534f550b63d" tf_http_archive( name = name, diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index dc1800d246361c..416cc79d485a5d 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "a0d337eecb4957da862b235a2829efd9513a129c" - SHARDY_SHA256 = "ffbc55d51995da6fd149e7bc0e41bee4faba02dfa2984f2af54940c46578c0c7" + SHARDY_COMMIT = "94e058dec0365df4f6af97367e2dd678d2041ade" + SHARDY_SHA256 = "970931265428bad56d45ef57562115c5a39c9b0c99ff3b9dab3aaa4bb98bed3c" tf_http_archive( name = "shardy", diff --git a/third_party/tsl/third_party/llvm/workspace.bzl b/third_party/tsl/third_party/llvm/workspace.bzl index e0f0b45d1f46d4..bb32ac1a48e349 100644 --- a/third_party/tsl/third_party/llvm/workspace.bzl +++ b/third_party/tsl/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "17ba4f4053e303be3e5408d34eaf687a49cefb06" - LLVM_SHA256 = "64c334c15f058ca090fe8eb1e2cd99fdcbaaaf6e57202760f32574d3a9d24d78" + LLVM_COMMIT = "e1451236a0a07f1ee4ba5fe3ae2464a82a37c25c" + LLVM_SHA256 = "22ade2cbcd9df84196461948456b6c965b536b8c25de92a9dc546534f550b63d" tf_http_archive( name = name, From 4b3c6570ec20360f0230f3003682d3d619bdb6ed Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Thu, 1 Aug 2024 10:04:51 -0700 Subject: [PATCH 373/376] Remove unused code in various Platform classes. PiperOrigin-RevId: 658441225 --- xla/stream_executor/cuda/cuda_platform.cc | 73 ++-------------------- xla/stream_executor/cuda/cuda_platform.h | 23 ------- xla/stream_executor/rocm/rocm_platform.cc | 75 ++--------------------- xla/stream_executor/rocm/rocm_platform.h | 22 ------- xla/stream_executor/sycl/BUILD | 1 + xla/stream_executor/sycl/sycl_platform.cc | 68 ++------------------ xla/stream_executor/sycl/sycl_platform.h | 22 ------- 7 files changed, 15 insertions(+), 269 deletions(-) diff --git a/xla/stream_executor/cuda/cuda_platform.cc b/xla/stream_executor/cuda/cuda_platform.cc index bdace571118435..ea86363ce27e9f 100644 --- a/xla/stream_executor/cuda/cuda_platform.cc +++ b/xla/stream_executor/cuda/cuda_platform.cc @@ -15,19 +15,14 @@ limitations under the License. #include "xla/stream_executor/cuda/cuda_platform.h" -#include -#include -#include #include #include #include -#include "absl/base/call_once.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_format.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/gpu/gpu_driver.h" @@ -35,65 +30,16 @@ limitations under the License. #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform/initialize.h" #include "xla/stream_executor/platform_manager.h" +#include "tsl/platform/errors.h" #include "tsl/platform/status.h" namespace stream_executor { namespace gpu { -CudaPlatform::CudaPlatform() - : name_("CUDA"), min_numa_node_(0), limit_numa_node_(0) {} +CudaPlatform::CudaPlatform() : name_("CUDA") {} CudaPlatform::~CudaPlatform() {} -// Due to legacy issues in user code, we can't currently call InpectNumaNodes -// at module initialization time, because non-GPU programs still include this -// plugin via various methods, so instead, it has to be init-on-reference. -void CudaPlatform::InspectNumaNodes() { - // To get NUMA node information, we need to create all executors, so we can - // examine their device descriptions to see their bus assignments. - static absl::once_flag once; - absl::call_once(once, [&] { - for (int i = 0; i < VisibleDeviceCount(); i++) { - StreamExecutor* exec = *ExecutorForDevice(i); - if (i == 0) { - // NUMA nodes may not start at 0, so set the minimum node based on the - // first executor we see. - min_numa_node_ = exec->GetDeviceDescription().numa_node(); - limit_numa_node_ = min_numa_node_ + 1; - } else { - min_numa_node_ = - std::min(min_numa_node_, exec->GetDeviceDescription().numa_node()); - limit_numa_node_ = std::max( - limit_numa_node_, exec->GetDeviceDescription().numa_node() + 1); - } - } - }); -} - -int CudaPlatform::BusCount() { - InspectNumaNodes(); - return limit_numa_node_ - min_numa_node_; -} - -int CudaPlatform::DeviceToBus(int device_ordinal) { - StreamExecutor* exec = *ExecutorForDevice(device_ordinal); - return exec->GetDeviceDescription().numa_node() - min_numa_node_; -} - -absl::StatusOr CudaPlatform::FirstExecutorForBus( - int bus_ordinal) { - InspectNumaNodes(); - CHECK_LT(bus_ordinal, BusCount()) << "bus ordinal out of available range"; - for (int i = 0; i < VisibleDeviceCount(); i++) { - if (DeviceToBus(i) == bus_ordinal) { - return *ExecutorForDevice(i); - } - } - - return absl::NotFoundError( - absl::StrFormat("Executor for bus %d not found.", bus_ordinal)); -} - Platform::Id CudaPlatform::id() const { return cuda::kCudaPlatformId; } int CudaPlatform::VisibleDeviceCount() const { @@ -133,24 +79,15 @@ absl::StatusOr CudaPlatform::GetExecutor( absl::StatusOr> CudaPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) { auto executor = std::make_unique(this, config.ordinal); - auto init_status = executor->Init(); - if (!init_status.ok()) { - return absl::InternalError(absl::StrFormat( - "failed initializing StreamExecutor for CUDA device ordinal %d: %s", - config.ordinal, init_status.ToString())); - } - + TF_RETURN_IF_ERROR(executor->Init()); return std::move(executor); } } // namespace gpu static void InitializeCudaPlatform() { - // Disabling leak checking, PlatformManager does not destroy its - // registered platforms. - - std::unique_ptr platform(new gpu::CudaPlatform); - TF_CHECK_OK(PlatformManager::RegisterPlatform(std::move(platform))); + TF_CHECK_OK( + PlatformManager::RegisterPlatform(std::make_unique())); } } // namespace stream_executor diff --git a/xla/stream_executor/cuda/cuda_platform.h b/xla/stream_executor/cuda/cuda_platform.h index 153282b26507e6..cd0004767f1d18 100644 --- a/xla/stream_executor/cuda/cuda_platform.h +++ b/xla/stream_executor/cuda/cuda_platform.h @@ -22,7 +22,6 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/stream_executor/executor_cache.h" #include "xla/stream_executor/platform.h" -#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" namespace stream_executor { @@ -41,16 +40,6 @@ class CudaPlatform : public Platform { CudaPlatform(); ~CudaPlatform() override; - // CudaPlatform-specific functionality - // Returns the number of distinct buses / NUMA nodes on the machine. - int BusCount(); - - // Returns the bus/NUMA node for the specified device ordinal. - int DeviceToBus(int device_ordinal); - - // Returns the lowest-ordinal-number StreamExecutor on the specified bus. - absl::StatusOr FirstExecutorForBus(int bus_ordinal); - // Platform interface implementation: // Returns the same value as kCudaPlatform above. Platform::Id id() const override; @@ -72,24 +61,12 @@ class CudaPlatform : public Platform { const StreamExecutorConfig& config) override; private: - // Determines the number of NUMA nodes and the assignment of executor to each. - void InspectNumaNodes(); - // This platform's name. std::string name_; // Cache of created executors. ExecutorCache executor_cache_; - // The smallest NUMA node value for any device managed by this machine - // manager. Used, along with limit_numa_node_, to convert NUMA nodes into bus - // ordinals. The NUMA node space occupied by GPUs is assumed to be dense./ - int min_numa_node_; - - // Larger than the NUMA node value for any device managed by this machine - // manager. - int limit_numa_node_; - CudaPlatform(const CudaPlatform&) = delete; void operator=(const CudaPlatform&) = delete; }; diff --git a/xla/stream_executor/rocm/rocm_platform.cc b/xla/stream_executor/rocm/rocm_platform.cc index 0ac3540c4e627d..ef7bc09be0c6e7 100644 --- a/xla/stream_executor/rocm/rocm_platform.cc +++ b/xla/stream_executor/rocm/rocm_platform.cc @@ -28,67 +28,10 @@ limitations under the License. namespace stream_executor { namespace gpu { -ROCmPlatform::ROCmPlatform() - : name_("ROCM"), min_numa_node_(0), limit_numa_node_(0) {} +ROCmPlatform::ROCmPlatform() : name_("ROCM") {} ROCmPlatform::~ROCmPlatform() {} -// Due to legacy issues in user code, we can't currently call InpectNumaNodes -// at module initialization time, because non-GPU programs still include this -// plugin via various methods, so instead, it has to be init-on-reference. -void ROCmPlatform::InspectNumaNodes() { - // To get NUMA node information, we need to create all executors, so we can - // examine their device descriptions to see their bus assignments. - absl::once_flag once; - absl::call_once(once, [&] { - StreamExecutorConfig config; - for (int i = 0; i < VisibleDeviceCount(); i++) { - config.ordinal = i; - StreamExecutor* exec = GetExecutor(config).value(); - if (i == 0) { - // NUMA nodes may not start at 0, so set the minimum node based on the - // first executor we see. - min_numa_node_ = exec->GetDeviceDescription().numa_node(); - limit_numa_node_ = min_numa_node_ + 1; - } else { - min_numa_node_ = - std::min(min_numa_node_, exec->GetDeviceDescription().numa_node()); - limit_numa_node_ = std::max( - limit_numa_node_, exec->GetDeviceDescription().numa_node() + 1); - } - } - }); -} - -int ROCmPlatform::BusCount() { - InspectNumaNodes(); - return limit_numa_node_ - min_numa_node_; -} - -int ROCmPlatform::DeviceToBus(int device_ordinal) { - StreamExecutorConfig config; - config.ordinal = device_ordinal; - StreamExecutor* exec = GetExecutor(config).value(); - return exec->GetDeviceDescription().numa_node() - min_numa_node_; -} - -absl::StatusOr ROCmPlatform::FirstExecutorForBus( - int bus_ordinal) { - InspectNumaNodes(); - CHECK_LT(bus_ordinal, BusCount()) << "bus ordinal out of available range"; - for (int i = 0; i < VisibleDeviceCount(); i++) { - if (DeviceToBus(i) == bus_ordinal) { - StreamExecutorConfig config; - config.ordinal = i; - return GetExecutor(config).value(); - } - } - - return absl::Status{ - absl::StatusCode::kNotFound, - absl::StrFormat("Executor for bus %d not found.", bus_ordinal)}; -} - Platform::Id ROCmPlatform::id() const { return rocm::kROCmPlatformId; } int ROCmPlatform::VisibleDeviceCount() const { @@ -130,27 +73,17 @@ absl::StatusOr ROCmPlatform::GetExecutor( absl::StatusOr> ROCmPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) { auto executor = std::make_unique(this, config.ordinal); - auto init_status = executor->Init(); - if (!init_status.ok()) { - return absl::Status{ - absl::StatusCode::kInternal, - absl::StrFormat( - "failed initializing StreamExecutor for ROCM device ordinal %d: %s", - config.ordinal, init_status.ToString().c_str())}; - } - + TF_RETURN_IF_ERROR(executor->Init()); return std::move(executor); } } // namespace gpu static void InitializeROCmPlatform() { - // Disabling leak checking, PlatformManager does not destroy its - // registered platforms. auto status = PlatformManager::PlatformWithName("ROCM"); if (!status.ok()) { - std::unique_ptr platform(new gpu::ROCmPlatform); - TF_CHECK_OK(PlatformManager::RegisterPlatform(std::move(platform))); + TF_CHECK_OK(PlatformManager::RegisterPlatform( + std::make_unique())); } } diff --git a/xla/stream_executor/rocm/rocm_platform.h b/xla/stream_executor/rocm/rocm_platform.h index 6d18cf4902dcda..2fc54d15e71de4 100644 --- a/xla/stream_executor/rocm/rocm_platform.h +++ b/xla/stream_executor/rocm/rocm_platform.h @@ -41,16 +41,6 @@ class ROCmPlatform : public Platform { ROCmPlatform(); ~ROCmPlatform() override; - // ROCmPlatform-specific functionality - // Returns the number of distinct buses / NUMA nodes on the machine. - int BusCount(); - - // Returns the bus/NUMA node for the specified device ordinal. - int DeviceToBus(int device_ordinal); - - // Returns the lowest-ordinal-number StreamExecutor on the specified bus. - absl::StatusOr FirstExecutorForBus(int bus_ordinal); - // Platform interface implementation: // Returns the same value as kROCmPlatform above. Platform::Id id() const override; @@ -72,9 +62,6 @@ class ROCmPlatform : public Platform { const StreamExecutorConfig& config) override; private: - // Determines the number of NUMA nodes and the assignment of executor to each. - void InspectNumaNodes(); - // This platform's name. std::string name_; @@ -84,15 +71,6 @@ class ROCmPlatform : public Platform { // Cache of created executors. ExecutorCache executor_cache_; - // The smallest NUMA node value for any device managed by this machine - // manager. Used, along with limit_numa_node_, to convert NUMA nodes into bus - // ordinals. The NUMA node space occupied by GPUs is assumed to be dense./ - int min_numa_node_; - - // Larger than the NUMA node value for any device managed by this machine - // manager. - int limit_numa_node_; - ROCmPlatform(const ROCmPlatform&) = delete; void operator=(const ROCmPlatform&) = delete; }; diff --git a/xla/stream_executor/sycl/BUILD b/xla/stream_executor/sycl/BUILD index 082946c6cd6b9a..5938f0aaf36320 100644 --- a/xla/stream_executor/sycl/BUILD +++ b/xla/stream_executor/sycl/BUILD @@ -48,6 +48,7 @@ cc_library( "//xla/stream_executor/gpu:gpu_driver_header", "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_collectives_header", + "@tsl//tsl/platform:errors", ]), alwayslink = True, # Registers itself with the PlatformManager. ) diff --git a/xla/stream_executor/sycl/sycl_platform.cc b/xla/stream_executor/sycl/sycl_platform.cc index 876775b5d3df05..ac6da36a5ea559 100644 --- a/xla/stream_executor/sycl/sycl_platform.cc +++ b/xla/stream_executor/sycl/sycl_platform.cc @@ -35,65 +35,16 @@ limitations under the License. #include "xla/stream_executor/platform/initialize.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/sycl/sycl_platform_id.h" +#include "tsl/platform/errors.h" #include "tsl/platform/status.h" namespace stream_executor { namespace gpu { -SyclPlatform::SyclPlatform() - : name_("SYCL"), min_numa_node_(0), limit_numa_node_(0) {} +SyclPlatform::SyclPlatform() : name_("SYCL") {} SyclPlatform::~SyclPlatform() {} -// Due to legacy issues in user code, we can't currently call InspectNumaNodes -// at module initialization time, because non-GPU programs still include this -// plugin via various methods, so instead, it has to be init-on-reference. -void SyclPlatform::InspectNumaNodes() { - // To get NUMA node information, we need to create all executors, so we can - // examine their device descriptions to see their bus assignments. - static absl::once_flag once; - absl::call_once(once, [&] { - for (int i = 0; i < VisibleDeviceCount(); i++) { - StreamExecutor* exec = *ExecutorForDevice(i); - if (i == 0) { - // NUMA nodes may not start at 0, so set the minimum node based on the - // first executor we see. - min_numa_node_ = exec->GetDeviceDescription().numa_node(); - limit_numa_node_ = min_numa_node_ + 1; - } else { - min_numa_node_ = - std::min(min_numa_node_, exec->GetDeviceDescription().numa_node()); - limit_numa_node_ = std::max( - limit_numa_node_, exec->GetDeviceDescription().numa_node() + 1); - } - } - }); -} - -int SyclPlatform::BusCount() { - InspectNumaNodes(); - return limit_numa_node_ - min_numa_node_; -} - -int SyclPlatform::DeviceToBus(int device_ordinal) { - StreamExecutor* exec = *ExecutorForDevice(device_ordinal); - return exec->GetDeviceDescription().numa_node() - min_numa_node_; -} - -absl::StatusOr SyclPlatform::FirstExecutorForBus( - int bus_ordinal) { - InspectNumaNodes(); - CHECK_LT(bus_ordinal, BusCount()) << "bus ordinal out of available range"; - for (int i = 0; i < VisibleDeviceCount(); i++) { - if (DeviceToBus(i) == bus_ordinal) { - return *ExecutorForDevice(i); - } - } - - return absl::NotFoundError( - absl::StrFormat("Executor for bus %d not found.", bus_ordinal)); -} - Platform::Id SyclPlatform::id() const { return sycl::kSyclPlatformId; } int SyclPlatform::VisibleDeviceCount() const { @@ -133,24 +84,15 @@ absl::StatusOr SyclPlatform::GetExecutor( absl::StatusOr> SyclPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) { auto executor = std::make_unique(this, config.ordinal); - auto init_status = executor->Init(); - if (!init_status.ok()) { - return absl::InternalError(absl::StrFormat( - "failed initializing StreamExecutor for SYCL device ordinal %d: %s", - config.ordinal, init_status.ToString())); - } - + TF_RETURN_IF_ERROR(executor->Init()); return std::move(executor); } } // namespace gpu static void InitializeSyclPlatform() { - // Disabling leak checking, PlatformManager does not destroy its - // registered platforms. - - std::unique_ptr platform(new gpu::SyclPlatform); - TF_CHECK_OK(PlatformManager::RegisterPlatform(std::move(platform))); + TF_CHECK_OK( + PlatformManager::RegisterPlatform(std::make_unique())); } } // namespace stream_executor diff --git a/xla/stream_executor/sycl/sycl_platform.h b/xla/stream_executor/sycl/sycl_platform.h index 0c687f4eee1179..ac164fccb5b398 100644 --- a/xla/stream_executor/sycl/sycl_platform.h +++ b/xla/stream_executor/sycl/sycl_platform.h @@ -41,16 +41,6 @@ class SyclPlatform : public Platform { SyclPlatform(); ~SyclPlatform() override; - // SyclPlatform-specific functionality - // Returns the number of distinct buses / NUMA nodes on the machine. - int BusCount(); - - // Returns the bus/NUMA node for the specified device ordinal. - int DeviceToBus(int device_ordinal); - - // Returns the lowest-ordinal-number StreamExecutor on the specified bus. - absl::StatusOr FirstExecutorForBus(int bus_ordinal); - // Platform interface implementation: // Returns the same value as kSyclPlatform above. Platform::Id id() const override; @@ -72,24 +62,12 @@ class SyclPlatform : public Platform { const StreamExecutorConfig& config) override; private: - // Determines the number of NUMA nodes and the assignment of executor to each. - void InspectNumaNodes(); - // This platform's name. std::string name_; // Cache of created executors. ExecutorCache executor_cache_; - // The smallest NUMA node value for any device managed by this machine - // manager. Used, along with limit_numa_node_, to convert NUMA nodes into bus - // ordinals. The NUMA node space occupied by GPUs is assumed to be dense. - int min_numa_node_; - - // Larger than the NUMA node value for any device managed by this machine - // manager. - int limit_numa_node_; - SyclPlatform(const SyclPlatform&) = delete; void operator=(const SyclPlatform&) = delete; }; From 336c9a99eeebb04a4ad9a8e6ce3b8d96e5ecd5e7 Mon Sep 17 00:00:00 2001 From: akhilgoe <114951738+akhilgoe@users.noreply.github.com> Date: Thu, 1 Aug 2024 10:05:45 -0700 Subject: [PATCH 374/376] PR #15252: [XLA:CPU][oneDNN] Fix oneDNN matmul test timeout Imported from GitHub PR https://github.com/openxla/xla/pull/15252 This PR addresses the test timeout observed in oneDNN matmul test file. In particular, this PR: 1. Replaces the test titled ConsecutiveBinaryAdd with a smaller test such that the test still hits the targeted failure case. 2. Shards the test file. 3. In addition, this PR also replaces all instances of the old proto definitions with the new ones. Copybara import of the project: -- 03e22ab3d7a20081dffcb3c1c6a9d96232b3bfd9 by Akhil Goel : Fix test timeout -- c3e08a6038da4d827a7ff042c0baec1699c684a9 by Akhil Goel : Address review comments Merging this change closes #15252 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/15252 from Intel-tensorflow:akhil/fix_mm_timeout c3e08a6038da4d827a7ff042c0baec1699c684a9 PiperOrigin-RevId: 658441635 --- xla/service/cpu/tests/BUILD | 1 + xla/service/cpu/tests/onednn_matmul_test.cc | 54 ++++++--------------- 2 files changed, 16 insertions(+), 39 deletions(-) diff --git a/xla/service/cpu/tests/BUILD b/xla/service/cpu/tests/BUILD index 7f8076e52608cc..7e6dfd0e697e7b 100644 --- a/xla/service/cpu/tests/BUILD +++ b/xla/service/cpu/tests/BUILD @@ -361,6 +361,7 @@ xla_cc_test( name = "onednn_matmul_test", srcs = ["onednn_matmul_test.cc"], copts = tsl_copts(), + shard_count = 4, tags = [ "no_oss", "notap", diff --git a/xla/service/cpu/tests/onednn_matmul_test.cc b/xla/service/cpu/tests/onednn_matmul_test.cc index d7fb39f0d33a90..8c877fb206357e 100644 --- a/xla/service/cpu/tests/onednn_matmul_test.cc +++ b/xla/service/cpu/tests/onednn_matmul_test.cc @@ -803,7 +803,9 @@ TEST_F(MatmulTest, TestNonScalarConstantEltwiseLinearF32) { ; CHECK: backend_config={ ; CHECK-DAG: "outer_dimension_partitions":[], ; CHECK-DAG: "onednn_matmul_config":{ - ; CHECK-NOT: "fused_ops":["LINEAR"] + ; CHECK-NOT: "fusions":{ + ; CHECK-NOT: "ops":["LINEAR"] + ; CHECK-NOT: } ; CHECK-DAG: } ; CHECK: } )"); @@ -1502,44 +1504,18 @@ TEST_F(MatmulTest, WeightsPrepackAndScratch) { TEST_F(MatmulTest, ConsecutiveBinaryAdd) { const char* matmul_module_str = R"( HloModule matmul.test.f32 - region_0.22 { - Arg_0.23 = f32[] parameter(0) - Arg_1.24 = f32[] parameter(1) - ROOT add.25 = f32[] add(Arg_0.23, Arg_1.24) - } - - region_1.29 { - Arg_0.30 = f32[] parameter(0) - Arg_1.31 = f32[] parameter(1) - ROOT add.32 = f32[] add(Arg_0.30, Arg_1.31) - } - - ENTRY main { - constant.2 = f32[] constant(1e-06) - broadcast.3 = f32[1000000] broadcast(constant.2), dimensions={} - constant.7 = f32[] constant(1) - broadcast.8 = f32[1000000,3] broadcast(constant.7), dimensions={} - Arg_0.1 = f32[3] parameter(0) - reshape.10 = f32[1,3] reshape(Arg_0.1) - broadcast.11 = f32[1,3] broadcast(reshape.10), dimensions={0,1} - reshape.12 = f32[3] reshape(broadcast.11) - broadcast.13 = f32[1000000,3] broadcast(reshape.12), dimensions={1} - subtract.14 = f32[1000000,3] subtract(broadcast.8, broadcast.13) - constant.4 = f32[] constant(0) - broadcast.5 = f32[3,3] broadcast(constant.4), dimensions={} - dot.15 = f32[1000000,3] dot(subtract.14, broadcast.5), lhs_contracting_dims={1}, rhs_contracting_dims={0} - dot.16 = f32[1000000,3] dot(broadcast.3, dot.15), lhs_batch_dims={0}, lhs_contracting_dims={}, rhs_batch_dims={0}, rhs_contracting_dims={} - dot.17 = f32[1000000,3] dot(broadcast.3, subtract.14), lhs_batch_dims={0}, lhs_contracting_dims={}, rhs_batch_dims={0}, rhs_contracting_dims={} - dot.18 = f32[1000000,3] dot(dot.17, broadcast.5), lhs_contracting_dims={1}, rhs_contracting_dims={1} - add.19 = f32[1000000,3] add(dot.16, dot.18) - constant.9 = f32[3] constant({1, 2, 3}) - dot.20 = f32[1000000,3] dot(broadcast.3, constant.9), lhs_contracting_dims={}, rhs_contracting_dims={} - add.21 = f32[1000000,3] add(add.19, dot.20) - constant.6 = f32[] constant(0) - reduce.26 = f32[3] reduce(add.21, constant.6), dimensions={0}, to_apply=region_0.22 - reshape.27 = f32[1,3] reshape(reduce.26) - negate.28 = f32[1,3] negate(reshape.27) - ROOT reduce.33 = f32[3] reduce(negate.28, constant.6), dimensions={0}, to_apply=region_1.29 + ENTRY matmul.test.f32 { + arg0.1 = f32[128,32,4,4] parameter(0) + arg0.2 = f32[128,32,4,4] parameter(1) + dot.7 = f32[128,32,4,4] dot(arg0.1, arg0.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + const.0 = f32[128,32] constant({...}) + bcast.1 = f32[128,32,4,4] broadcast(const.0), dimensions={0,1} + add.0 = f32[128,32,4,4] add(dot.7,bcast.1) + const.1 = f32[4] constant({1,2,3,4}) + bcast.2 = f32[128,32,4,4] broadcast(const.1), dimensions={3} + add.1 = f32[128,32,4,4] add(add.0, bcast.2) + tuple.12 = (f32[128,32,4,4]) tuple(add.1) + ROOT get-tuple-element.13 = f32[128,32,4,4] get-tuple-element(tuple.12), index=0 })"; EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); From 06098422b51878457b6dca8a2e360c1a064c326b Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Thu, 1 Aug 2024 10:06:23 -0700 Subject: [PATCH 375/376] [XLA:GPU] Share common HLO computations between some of the pipeline tests PiperOrigin-RevId: 658441952 --- .../collective_pipeline_parallelism_test.cc | 271 +++++++----------- 1 file changed, 109 insertions(+), 162 deletions(-) diff --git a/xla/tests/collective_pipeline_parallelism_test.cc b/xla/tests/collective_pipeline_parallelism_test.cc index e703d019199694..c41a194fded9e2 100644 --- a/xla/tests/collective_pipeline_parallelism_test.cc +++ b/xla/tests/collective_pipeline_parallelism_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include #include #include @@ -546,6 +547,59 @@ XLA_TEST_F(CollectivePipelineParallelismTest, ErrorSpec{1e-5, 1e-5})); } +std::string GetModuleStrWithCommonComputations( + const std::string name, const std::string more_computations) { + static constexpr char kCommonComputationsStr[] = R"( + read_buffer_mb5 { + buffer = f32[5,16] parameter(0) + offset = u32[] parameter(1) + index = u32[] parameter(2) + c0 = u32[] constant(0) + c5 = u32[] constant(5) + index_ = u32[] add(index, offset) + index__ = u32[] remainder(index_, c5) + slice = f32[1,16] dynamic-slice(buffer, index__, c0), + dynamic_slice_sizes={1,16} + ROOT slice_ = f32[16] reshape(slice) + } + + update_buffer_mb5 { + buffer = f32[5,16] parameter(0) + update = f32[16] parameter(1) + offset = u32[] parameter(2) + index = u32[] parameter(3) + c0 = u32[] constant(0) + c5 = u32[] constant(5) + index_ = u32[] add(index, offset) + index__ = u32[] remainder(index_, c5) + update_ = f32[1,16] reshape(update) + ROOT buffer_ = f32[5,16] dynamic-update-slice(buffer, update_, index__, c0) + } + + is_input_replica { + replica_id = u32[] replica-id() + c0 = u32[] constant(0) + ROOT predicate = pred[] compare(replica_id, c0), direction=EQ + } + + is_output_replica { + replica_id = u32[] replica-id() + c3 = u32[] constant(3) + ROOT predicate = pred[] compare(replica_id, c3), direction=EQ + } + + is_read_input_mb5 { + is_input_replica = pred[] call(), to_apply=is_input_replica + i = u32[] parameter(0) + c5 = u32[] constant(5) + is_input_iteration = pred[] compare(i, c5), direction=LT + ROOT is_read_input = pred[] and(is_input_replica, is_input_iteration) + } + )"; + return "HloModule " + name + "\n" + kCommonComputationsStr + "\n" + + more_computations; +} + // Naive implementation if pipeline parallelism: // - 4 devices // - 5 microbatches @@ -556,65 +610,7 @@ XLA_TEST_F(CollectivePipelineParallelismTest, // Every stage of the pipeline is a single linear layer. XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch5CircularRepeat2Replica4) { - const absl::string_view kModuleStr = R"( - HloModule test - - get_circ_buffer_index { - offset = u32[] parameter(0) - index = u32[] parameter(1) - size = u32[] parameter(2) - t0 = u32[] add(offset, index) - t1 = u32[] divide(t0, size) - t2 = u32[] multiply(t1, size) - ROOT t4 = u32[] subtract(t0, t2) - } - - read_buffer { - buffer = f32[5,16] parameter(0) - offset = u32[] parameter(1) - index = u32[] parameter(2) - c0 = u32[] constant(0) - c5 = u32[] constant(5) - index_ = u32[] add(index, offset) - index__ = u32[] remainder(index_, c5) - slice = f32[1,16] dynamic-slice(buffer, index__, c0), - dynamic_slice_sizes={1,16} - ROOT slice_ = f32[16] reshape(slice) - } - - update_buffer { - buffer = f32[5,16] parameter(0) - update = f32[16] parameter(1) - offset = u32[] parameter(2) - index = u32[] parameter(3) - c0 = u32[] constant(0) - c5 = u32[] constant(5) - index_ = u32[] add(index, offset) - index__ = u32[] remainder(index_, c5) - update_ = f32[1,16] reshape(update) - ROOT buffer_ = f32[5,16] dynamic-update-slice(buffer, update_, index__, c0) - } - - is_input_replica { - replica_id = u32[] replica-id() - c0 = u32[] constant(0) - ROOT predicate = pred[] compare(replica_id, c0), direction=EQ - } - - is_output_replica { - replica_id = u32[] replica-id() - c3 = u32[] constant(3) - ROOT predicate = pred[] compare(replica_id, c3), direction=EQ - } - - is_read_input { - is_input_replica = pred[] call(), to_apply=is_input_replica - i = u32[] parameter(0) - c5 = u32[] constant(5) - is_input_iteration = pred[] compare(i, c5), direction=LT - ROOT is_read_input = pred[] and(is_input_replica, is_input_iteration) - } - + constexpr char kMoreComputationsStr[] = R"( while_condition { tuple = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[]) parameter(0) @@ -630,43 +626,46 @@ XLA_TEST_F(CollectivePipelineParallelismTest, input = f32[5,16] get-tuple-element(tuple), index=1 output = f32[5,16] get-tuple-element(tuple), index=2 buffer = f32[5,16] get-tuple-element(tuple), index=3 - prev_iteration_compute_out = f32[16] get-tuple-element(tuple), index=4 + prev_iteration_compute_res = f32[16] get-tuple-element(tuple), index=4 i = u32[] get-tuple-element(tuple), index=5 c0 = u32[] constant(0) c1 = u32[] constant(1) c2 = u32[] constant(2) c3 = u32[] constant(3) + c4 = u32[] constant(4) c5 = u32[] constant(5) - input_idx = u32[] call(c0, i, c5), to_apply=get_circ_buffer_index - input_slice = f32[1,16] dynamic-slice(input, input_idx, c0), - dynamic_slice_sizes={1,16} - input_slice_ = f32[16] reshape(input_slice) - - buffer_slice = f32[16] call(buffer, c3, i), to_apply=read_buffer + // Read from buffers. + input_slice = f32[16] call(input, c0, i), to_apply=read_buffer_mb5 + buffer_slice = f32[16] call(buffer, c3, i), to_apply=read_buffer_mb5 + // Shift data to the next stage in the pipeline. + // Directly depends on the updated buffer of the previous iteration and, + // therefore, depends on the previous iteration's compute. is_output_replica = pred[] call(), to_apply=is_output_replica next_stage_slice = select(is_output_replica, buffer_slice, - prev_iteration_compute_out) - + prev_iteration_compute_res) prev_stage_slice = f32[16] collective-permute(next_stage_slice), source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}} - is_read_input = pred[] call(i), to_apply=is_read_input - compute_in = f32[16] select(is_read_input, input_slice_, prev_stage_slice) - - compute_out = f32[16] dot(weights, compute_in), lhs_contracting_dims={1}, + // Select compute argument from previous stage or from input and perform + // compute. + is_read_input = pred[] call(i), to_apply=is_read_input_mb5 + compute_arg = f32[16] select(is_read_input, input_slice, prev_stage_slice) + compute_res = f32[16] dot(weights, compute_arg), lhs_contracting_dims={1}, rhs_contracting_dims={0} - output_ = f32[5,16] call(output, compute_out, c2, i), to_apply=update_buffer - - buffer_ = f32[5,16] call(buffer, compute_out, c0, i), to_apply=update_buffer + // Update buffers. + output_ = f32[5,16] call(output, compute_res, c2, i), + to_apply=update_buffer_mb5 + buffer_ = f32[5,16] call(buffer, compute_res, c0, i), + to_apply=update_buffer_mb5 i_ = add(i, c1) ROOT tuple_ = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[]) - tuple(weights, input, output_, buffer_, compute_out, i_) + tuple(weights, input, output_, buffer_, compute_res, i_) } ENTRY main { @@ -676,11 +675,12 @@ XLA_TEST_F(CollectivePipelineParallelismTest, cf0 = f32[] constant(0) output = f32[5,16] broadcast(cf0), dimensions={} buffer = f32[5,16] broadcast(cf0), dimensions={} - prev_iteration_compute_out = f32[16] broadcast(cf0), dimensions={} + prev_iteration_compute_res = f32[16] broadcast(cf0), dimensions={} c0 = u32[] constant(0) + // Iterate through pipeline stages. tuple = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[]) - tuple(weights, input, output, buffer, prev_iteration_compute_out, c0) + tuple(weights, input, output, buffer, prev_iteration_compute_res, c0) tuple_ = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[]) while(tuple), condition=while_condition, body=while_body @@ -693,8 +693,11 @@ XLA_TEST_F(CollectivePipelineParallelismTest, HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(kModuleStr, config)); + TF_ASSERT_OK_AND_ASSIGN( + auto module, + ParseAndReturnVerifiedModule(GetModuleStrWithCommonComputations( + /*name=*/"test", kMoreComputationsStr), + config)); // This pipeline consists of a total of 8 layers (2 per replica), each of // which is a single linear layer. We assign the weights to the replicas such @@ -745,65 +748,7 @@ XLA_TEST_F(CollectivePipelineParallelismTest, // Every stage of the pipeline is a single linear layer. XLA_TEST_F(CollectivePipelineParallelismTest, NaiveWoDirectBufferDependencyDFSMicrobatch5CircularRepeat2Replica4) { - const absl::string_view kModuleStr = R"( - HloModule test - - get_circ_buffer_index { - offset = u32[] parameter(0) - index = u32[] parameter(1) - size = u32[] parameter(2) - t0 = u32[] add(offset, index) - t1 = u32[] divide(t0, size) - t2 = u32[] multiply(t1, size) - ROOT t4 = u32[] subtract(t0, t2) - } - - read_buffer { - buffer = f32[5,16] parameter(0) - offset = u32[] parameter(1) - index = u32[] parameter(2) - c0 = u32[] constant(0) - c5 = u32[] constant(5) - index_ = u32[] add(index, offset) - index__ = u32[] remainder(index_, c5) - slice = f32[1,16] dynamic-slice(buffer, index__, c0), - dynamic_slice_sizes={1,16} - ROOT slice_ = f32[16] reshape(slice) - } - - update_buffer { - buffer = f32[5,16] parameter(0) - update = f32[16] parameter(1) - offset = u32[] parameter(2) - index = u32[] parameter(3) - c0 = u32[] constant(0) - c5 = u32[] constant(5) - index_ = u32[] add(index, offset) - index__ = u32[] remainder(index_, c5) - update_ = f32[1,16] reshape(update) - ROOT buffer_ = f32[5,16] dynamic-update-slice(buffer, update_, index__, c0) - } - - is_input_replica { - replica_id = u32[] replica-id() - c0 = u32[] constant(0) - ROOT predicate = pred[] compare(replica_id, c0), direction=EQ - } - - is_output_replica { - replica_id = u32[] replica-id() - c3 = u32[] constant(3) - ROOT predicate = pred[] compare(replica_id, c3), direction=EQ - } - - is_read_input { - is_input_replica = pred[] call(), to_apply=is_input_replica - i = u32[] parameter(0) - c5 = u32[] constant(5) - is_input_iteration = pred[] compare(i, c5), direction=LT - ROOT is_read_input = pred[] and(is_input_replica, is_input_iteration) - } - + constexpr char kMoreComputationsStr[] = R"( while_condition { tuple = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[]) parameter(0) @@ -819,7 +764,7 @@ XLA_TEST_F(CollectivePipelineParallelismTest, input = f32[5,16] get-tuple-element(tuple), index=1 output = f32[5,16] get-tuple-element(tuple), index=2 buffer = f32[5,16] get-tuple-element(tuple), index=3 - prev_iteration_compute_out = f32[16] get-tuple-element(tuple), index=4 + prev_iteration_compute_res = f32[16] get-tuple-element(tuple), index=4 i = u32[] get-tuple-element(tuple), index=5 c0 = u32[] constant(0) @@ -829,38 +774,36 @@ XLA_TEST_F(CollectivePipelineParallelismTest, c4 = u32[] constant(4) c5 = u32[] constant(5) - input_idx = u32[] call(c0, i, c5), to_apply=get_circ_buffer_index - input_slice = f32[1,16] dynamic-slice(input, input_idx, c0), - dynamic_slice_sizes={1,16} - input_slice_ = f32[16] reshape(input_slice) - - buffer_slice = f32[16] call(buffer, c3, i), to_apply=read_buffer - - buffer_ = f32[5,16] call(buffer, prev_iteration_compute_out, c4, i), - to_apply=update_buffer + // Read from buffers before they are updated. + input_slice = f32[16] call(input, c0, i), to_apply=read_buffer_mb5 + buffer_slice = f32[16] call(buffer, c3, i), to_apply=read_buffer_mb5 + // Shift data to the next stage in the pipeline. // Depends on the non-updated buffer of the previous iteration and, // therefore, does not depend on the previous iteration's compute. is_output_replica = pred[] call(), to_apply=is_output_replica next_stage_slice = select(is_output_replica, buffer_slice, - prev_iteration_compute_out) - - + prev_iteration_compute_res) prev_stage_slice = f32[16] collective-permute(next_stage_slice), source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}} - is_read_input = pred[] call(i), to_apply=is_read_input - compute_in = f32[16] select(is_read_input, input_slice_, prev_stage_slice) - - compute_out = f32[16] dot(weights, compute_in), lhs_contracting_dims={1}, + // Select compute argument from previous stage or from input and perform + // compute. + is_read_input = pred[] call(i), to_apply=is_read_input_mb5 + compute_arg = f32[16] select(is_read_input, input_slice, prev_stage_slice) + compute_res = f32[16] dot(weights, compute_arg), lhs_contracting_dims={1}, rhs_contracting_dims={0} - output_ = f32[5,16] call(output, compute_out, c2, i), to_apply=update_buffer + // Update buffers. + buffer_ = f32[5,16] call(buffer, prev_iteration_compute_res, c4, i), + to_apply=update_buffer_mb5 + output_ = f32[5,16] call(output, compute_res, c2, i), + to_apply=update_buffer_mb5 i_ = add(i, c1) ROOT tuple_ = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[]) - tuple(weights, input, output_, buffer_, compute_out, i_) + tuple(weights, input, output_, buffer_, compute_res, i_) } ENTRY main { @@ -870,11 +813,12 @@ XLA_TEST_F(CollectivePipelineParallelismTest, cf0 = f32[] constant(0) output = f32[5,16] broadcast(cf0), dimensions={} buffer = f32[5,16] broadcast(cf0), dimensions={} - prev_iteration_compute_out = f32[16] broadcast(cf0), dimensions={} + prev_iteration_compute_res = f32[16] broadcast(cf0), dimensions={} c0 = u32[] constant(0) + // Iterate through pipeline stages. tuple = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[]) - tuple(weights, input, output, buffer, prev_iteration_compute_out, c0) + tuple(weights, input, output, buffer, prev_iteration_compute_res, c0) tuple_ = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[]) while(tuple), condition=while_condition, body=while_body @@ -887,8 +831,11 @@ XLA_TEST_F(CollectivePipelineParallelismTest, HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(kModuleStr, config)); + TF_ASSERT_OK_AND_ASSIGN( + auto module, + ParseAndReturnVerifiedModule(GetModuleStrWithCommonComputations( + /*name=*/"test", kMoreComputationsStr), + config)); // This pipeline consists of a total of 8 layers (2 per replica), each of // which is a single linear layer. We assign the weights to the replicas such From f7a7d5d7d658a7c19a66a61a5b6fcc73de5bd995 Mon Sep 17 00:00:00 2001 From: TJ Xu Date: Thu, 1 Aug 2024 18:20:15 +0000 Subject: [PATCH 376/376] Use cuda runtime api to determine if 2 ranks are on the same host or not. --- .../runtime/nccl_collective_permute_thunk.cc | 34 +++++++++++++++++-- .../runtime/nccl_collective_permute_thunk.h | 1 + xla/stream_executor/cuda/cuda_executor.cc | 2 ++ xla/stream_executor/gpu/gpu_executor.h | 2 ++ xla/stream_executor/stream_executor.h | 3 ++ 5 files changed, 40 insertions(+), 2 deletions(-) diff --git a/xla/service/gpu/runtime/nccl_collective_permute_thunk.cc b/xla/service/gpu/runtime/nccl_collective_permute_thunk.cc index 02a8a583d754ef..6c979996505f1c 100644 --- a/xla/service/gpu/runtime/nccl_collective_permute_thunk.cc +++ b/xla/service/gpu/runtime/nccl_collective_permute_thunk.cc @@ -53,6 +53,31 @@ absl::StatusOr GetCurrentId( : current_logical_id.computation_id; return current_id; } + +bool IsLocalPeerTransfer( + const NcclP2PConfig::SourceTargetMapEntry& source_target, + se::Stream& stream, const int64_t current_id, const int64_t device_count) { + const std::optional source_id = source_target.source; + const std::optional target_id = source_target.target; + // Since mixing nccl p2p with p2p memcopy will cause random deadlocks. + // We determine if it's a local peer by the following conditions: + // 1. Both source and target IDs are present and they are within a node + // 2. Source ID is present, but target ID is not. + // 3. Target ID is presetn, but source ID is not. + int64_t host_id = (current_id / device_count); + if (source_id && target_id) { + return (host_id == (*source_id / device_count)) && + (host_id == (*target_id / device_count)); + } + if (source_id) { + return (host_id == (*source_id / device_count)); + } + if (target_id) { + return (host_id == (*target_id / device_count)); + } + return false; +} + } // namespace NcclCollectivePermuteStartThunk::NcclCollectivePermuteStartThunk( @@ -133,6 +158,9 @@ NcclCollectivePermuteStartThunk::NcclCollectivePermuteStartThunk( absl::Status NcclCollectivePermuteStartThunk::Initialize( const InitializeParams& params) { TF_RETURN_IF_ERROR(NcclCollectiveThunk::Initialize(params)); + device_count_ = params.executor->GetDeviceCount(); + VLOG(5) << "Local device count: " << device_count_; + if (p2p_memcpy_enabled_) { TF_ASSIGN_OR_RETURN(const int64_t current_id, GetCurrentId(params.collective_params, config_)); @@ -157,9 +185,11 @@ absl::Status NcclCollectivePermuteStartThunk::RunNcclCollective( const NcclP2PConfig::SourceTargetMapEntry source_target = NcclP2PConfig::GetSourceTarget(config_.id_to_source_target, current_id); + bool is_local_peer = + IsLocalPeerTransfer(source_target, stream, current_id, device_count_); + VLOG(5) << "Is local peer : " << (is_local_peer ? "true" : "false"); - bool use_memcpy = comm_wrapper.is_local && - recv_ptr_map_.IsInitialized(current_id) && + bool use_memcpy = is_local_peer && recv_ptr_map_.IsInitialized(current_id) && p2p_memcpy_enabled_; return ::xla::gpu::RunCollectivePermute( diff --git a/xla/service/gpu/runtime/nccl_collective_permute_thunk.h b/xla/service/gpu/runtime/nccl_collective_permute_thunk.h index ab80cf1904f353..9dd2a2998c49c0 100644 --- a/xla/service/gpu/runtime/nccl_collective_permute_thunk.h +++ b/xla/service/gpu/runtime/nccl_collective_permute_thunk.h @@ -111,6 +111,7 @@ class NcclCollectivePermuteStartThunk : public NcclCollectiveThunk { const Buffer buffer_; RecvPtrMap recv_ptr_map_; bool p2p_memcpy_enabled_ = false; + int64_t device_count_; }; absl::Status RunCollectivePermute( diff --git a/xla/stream_executor/cuda/cuda_executor.cc b/xla/stream_executor/cuda/cuda_executor.cc index 0b90f27b8811d9..e21040f282ab09 100644 --- a/xla/stream_executor/cuda/cuda_executor.cc +++ b/xla/stream_executor/cuda/cuda_executor.cc @@ -618,6 +618,8 @@ absl::Status GpuExecutor::BlockHostUntilDone(Stream* stream) { return GpuDriver::SynchronizeStream(context_, AsGpuStreamValue(stream)); } +int64_t GpuExecutor::GetDeviceCount() { return GpuDriver::GetDeviceCount(); } + blas::BlasSupport* GpuExecutor::AsBlas() { absl::MutexLock lock(&mu_); if (blas_ != nullptr) { diff --git a/xla/stream_executor/gpu/gpu_executor.h b/xla/stream_executor/gpu/gpu_executor.h index f7dd572e918ccd..8635d870a290bf 100644 --- a/xla/stream_executor/gpu/gpu_executor.h +++ b/xla/stream_executor/gpu/gpu_executor.h @@ -292,6 +292,8 @@ class GpuExecutor : public StreamExecutorCommon { return true; } + int64_t GetDeviceCount() override; + uint64_t GetArgumentLoggingMode() const { return argument_logging_mode_; } // Creates an EventBasedTimer for the given stream. diff --git a/xla/stream_executor/stream_executor.h b/xla/stream_executor/stream_executor.h index f5fb436dfd2274..d9e8f7e9a9597e 100644 --- a/xla/stream_executor/stream_executor.h +++ b/xla/stream_executor/stream_executor.h @@ -324,6 +324,9 @@ class StreamExecutor { // Returns the memory limit in bytes supported by this executor. virtual int64_t GetMemoryLimitBytes() const = 0; + // Returns the total number of compute capable devices. + virtual int64_t GetDeviceCount() { return 0; }; + // The following methods access an internal log of some subset // of arguments passed to other class methods. // Used for testing/debugging purposes.