diff --git a/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc b/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc index 212e6b51e5445d..415ed7da7f2bf5 100644 --- a/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc +++ b/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "xla/ffi/ffi.h" #include "xla/ffi/ffi_api.h" #include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/custom_call_target_registry.h" #include "xla/service/gpu/backend_configs.pb.h" @@ -3650,6 +3651,101 @@ TEST_F(DynamicSliceFusionTest, ReduceScatterDegenerateSlice) { false, true, error)); } +TEST_F(DynamicSliceFusionTest, TestWithRewriter) { + const char* hlo = R"( + HloModule test_module, replica_count=2 + + add { + a = s32[] parameter(0) + b = s32[] parameter(1) + ROOT add = s32[] add(a, b) + } + + Body { + param = (s32[], s32[16, 32], s32[8, 32]) parameter(0) + i = s32[] get-tuple-element(param), index=0 + dest = s32[16,32] get-tuple-element(param), index=1 + src = s32[8,32] get-tuple-element(param), index=2 + eight = s32[] constant(8) + zero = s32[] constant(0) + thirty_two = s32[] constant(32) + add = s32[] add(eight, i) + add.2 = s32[] subtract(add, thirty_two) + compare = pred[] compare(add, thirty_two), direction=LT + offset = s32[] select(compare, add, add.2) + rs = s32[4,32] reduce-scatter(src), channel_id=0, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add + fusion = s32[16,32] dynamic-update-slice(dest, rs, offset, zero) + one = s32[] constant(1) + i_plus_one = s32[] add(i, one) + ROOT tuple = tuple(i_plus_one, fusion, src) + } + + Cond { + param = (s32[], s32[16,32], s32[8,32]) parameter(0) + loop_iter = s32[] get-tuple-element(param), index=0 + c32 = s32[] constant(32) + ROOT compare = pred[] compare(loop_iter, c32), direction=LT + } + + ENTRY main { + zero = s32[] constant(0) + dest = s32[16,32] parameter(0) + src = s32[8,32] parameter(1) + tuple = tuple(zero, dest, src) + ROOT while = while(tuple), body=Body, condition=Cond + } + )"; + + HloModuleConfig config; + DebugOptions dboptions; + dboptions.set_xla_gpu_enable_dynamic_slice_fusion(false); + config.set_debug_options(dboptions); + TF_ASSERT_OK_AND_ASSIGN(auto module0, + ParseAndReturnVerifiedModule(hlo, config)); + + TF_ASSERT_OK_AND_ASSIGN(auto module_without_fusion, + GetOptimizedModule(std::move(module0))); + dboptions.set_xla_gpu_enable_dynamic_slice_fusion(true); + config.set_debug_options(dboptions); + TF_ASSERT_OK_AND_ASSIGN(auto module1, + ParseAndReturnVerifiedModule(hlo, config)); + TF_ASSERT_OK_AND_ASSIGN(auto module_with_fusion, + GetOptimizedModule(std::move(module1))); + + ASSERT_EQ(GetDynamicSliceFusions(*module_without_fusion).size(), 0); + auto fusions = GetDynamicSliceFusions(*module_with_fusion); + ASSERT_EQ(fusions.size(), 1); + HloPrintOptions options; + options.set_print_large_constants(true) + .set_print_result_shape(false) + .set_print_operand_shape(false); + TF_ASSERT_OK_AND_ASSIGN(auto filecheck_fusion, + RunFileCheck(fusions[0]->ToString(options), + R"( + // CHECK-DAG: %[[rs:.+]] = reduce-scatter({{.+}}) + // CHECK-DAG: %[[offset_vals:.+]] = constant({8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7}) + // CHECK-DAG: %[[offset_as_arr:.+]] = dynamic-slice(%[[offset_vals]], {{.+}}), dynamic_slice_sizes={1} + // CHECK-DAG: %[[offset:.+]] = reshape(%[[offset_as_arr]]) + // CHECK-DAG: ROOT %{{.+}} = dynamic-update-slice({{.+}}, %[[rs]], %[[offset]], {{.+}}) + )")); + EXPECT_TRUE(filecheck_fusion); + TF_ASSERT_OK_AND_ASSIGN( + auto filecheck_while_loop, + RunFileCheck(fusions[0]->FusionInstruction()->parent()->ToString(options), + R"( + // CHECK-DAG: %[[p:.+]] = parameter(0) + // CHECK-DAG: %[[loop_counter:.+]] = get-tuple-element(%[[p]]), index=3 + // CHECK-DAG: %[[address_computation:.+]] = fusion({{.+}}, %[[loop_counter]]), kind=kCustom + // CHECK-DAG: %[[updated_loop_counter:.+]] = add(%[[loop_counter]], {{.+}}) + // CHECK-DAG: ROOT {{.+}} = tuple({{.+}}, %[[address_computation]], {{.+}}, %[[updated_loop_counter]]) + )")); + EXPECT_TRUE(filecheck_while_loop); + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + EXPECT_TRUE(RunAndCompareTwoModulesReplicated( + std::move(module_without_fusion), std::move(module_with_fusion), false, + true, error_spec)); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/transforms/BUILD b/xla/service/gpu/transforms/BUILD index 474c8423b9fedb..f99089b6c0a488 100644 --- a/xla/service/gpu/transforms/BUILD +++ b/xla/service/gpu/transforms/BUILD @@ -1423,19 +1423,23 @@ cc_library( hdrs = ["dynamic_slice_fusion_rewriter.h"], tags = ["gpu"], deps = [ + "//xla:literal_util", "//xla:shape_util", "//xla:util", "//xla/ffi:ffi_api", + "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service:custom_call_target_registry", "//xla/service:pattern_matcher", + "//xla/service:while_loop_analysis", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:cublas_cudnn", "//xla/service/gpu:gpu_constants", "//xla/service/gpu:hlo_traversal", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu/kernels:custom_fusion_library", + "//xla/tools:hlo_extractor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -1473,6 +1477,7 @@ xla_cc_test( "//xla/service/gpu:gpu_device_info_for_tests", "//xla/stream_executor", "//xla/stream_executor/gpu:gpu_types_header", + "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", diff --git a/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc b/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc index 8ea3bc3801062e..a58bed9d75697b 100644 --- a/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc +++ b/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc @@ -33,11 +33,14 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/ffi/ffi_api.h" +#include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" +#include "xla/literal_util.h" +#include "xla/primitive_util.h" #include "xla/service/custom_call_target_registry.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" @@ -45,8 +48,10 @@ limitations under the License. #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/pattern_matcher.h" +#include "xla/service/while_loop_analysis.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/tools/hlo_extractor.h" #include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" @@ -77,6 +82,9 @@ using DataflowPathsView = absl::Span; using InstructionSet = absl::flat_hash_set; +using OffsetValueMap = + absl::flat_hash_map>; + bool IsNoOp(const HloInstruction* hlo) { return HloPredicateIsOp(hlo); @@ -152,99 +160,424 @@ bool IsAlignedSlice(const HloInstruction* slice) { return true; } -// Pattern matches the following IR (generated by `jax.lax.scan`) to check if -// the offset is a loop iteration number: - -// clang-format off -// param = (s32[], s32[], s32[16]{0}, s32[16]{0}) parameter(0) -// // the index in `gte` has to be the loop iteration index -// gte = s32[] get-tuple-element(param), index=0 -// c0 = s32[] constant(0) compare = pred[] compare(gte, c0), direction=LT -// c_trip_count = s32[] constant(16) -// add = s32[] add(gte, c_trip_count) select = s32[] select(compare, add, gte) -// clang-format on - -bool IsLoopIterationNumber(const HloInstruction& offset) { - const HloComputation* parent = offset.parent(); - if (!parent->IsWhileBodyComputation()) return false; - - // Scan loops trip count must be known at compile time as it iterates over the - // leading dimension of the statically shaped input. - const HloInstruction* while_instr = parent->WhileCallInstruction(); - auto config = while_instr->backend_config(); - if (!config.ok() || !config->has_known_trip_count()) return false; - int32_t trip_count = config->known_trip_count().n(); - - // First lets check the offset computation pattern - if (!Match(&offset, m::Select(m::Lt(m::GetTupleElement(m::Parameter(0)), - m::ConstantScalar(0)), - m::Add(m::GetTupleElement(m::Parameter(0)), - m::ConstantScalar(trip_count)), - m::GetTupleElement(m::Parameter())))) { - return false; +// Function looks for while backend config. If this config is present, it +// returns the value of trip count, otherwise it runs the while loop analysis to +// compute trip count. `whileop` must be a while operaton. Returns +// `std::nullopt` if it cannot figure out the trip count. +std::optional GetWhileLoopTripCount(HloInstruction* whileop) { + CHECK(whileop->opcode() == HloOpcode::kWhile); + auto backend_config = whileop->backend_config(); + if (!backend_config.ok() || !backend_config.value().has_known_trip_count()) { + VLOG(4) << "Backend config not ok. Computing while loop trip count for " + << whileop->name(); + return ComputeWhileLoopTripCount(whileop); } + int trip_count = backend_config.value().known_trip_count().n(); + VLOG(4) << "Found trip count in backend config for " << whileop->name() + << ": " << trip_count; + return trip_count; +} - // Next, we check that the parameter used in offset computation is the loop - // induction variable - int64_t param_idx = offset.operand(2)->tuple_index(); - const HloInstruction* root = offset.parent()->root_instruction(); - if (root->opcode() != HloOpcode::kTuple) { - return false; +// Given an HLO operation `idx`, which is wrapped by while operation, this +// function tries to find the values of the variable in all the iterations as an +// array of literals. This is done by repeatedly executing the loop update +// operation(s) and the operation(s) to calculate the value of `idx` at each +// iteration. If this is successful, then the vector of literals is returned. If +// for some reason this is not successful then `std::nullopt` is returned. +std::optional> GetValues(const HloInstruction* idx) { + VLOG(3) << "Getting values for " << idx->name(); + const HloComputation* computation = idx->parent(); + if (!computation->IsWhileBodyComputation()) { + VLOG(3) << "While calculating offset values for " << idx->name() + << ", the parent computation(" << computation->name() + << ") is not a while computation"; + return std::nullopt; } - // Check the update operation - const HloInstruction* updated_var = - offset.parent()->root_instruction()->operand(param_idx); - if (!Match(updated_var, m::Add(m::GetTupleElement(m::Parameter(0), param_idx), - m::ConstantScalar(1)))) { - return false; + HloInstruction* whileop = computation->WhileCallInstruction(); + std::optional trip_count = GetWhileLoopTripCount(whileop); + if (trip_count == std::nullopt) { + VLOG(3) << "Unable to get trip count for " << whileop->name(); + return std::nullopt; } - // Check that the condition considers this. - const HloInstruction* condition_root = - while_instr->while_condition()->root_instruction(); - if (!Match(condition_root, - m::Lt(m::GetTupleElement(m::Parameter(0), param_idx), - m::ConstantScalar(trip_count)))) { - return false; + auto root_tuple = computation->root_instruction(); + if (root_tuple->opcode() != HloOpcode::kTuple) { + VLOG(3) << "Root operation " << root_tuple->name() << " of computation " + << computation->name() + << " expected to be a tuple because it is a while body. Found: " + << root_tuple->opcode(); + return std::nullopt; } - // Check init - const HloInstruction* init_loop_iter = - while_instr->operand(0)->operand(param_idx); - if (!Match(init_loop_iter, m::ConstantScalar(0))) { - return false; + std::optional loop_indvar_tuple_idx = + GetLoopInductionVarTupleIdx(whileop); + if (loop_indvar_tuple_idx == std::nullopt) { + VLOG(3) << "Unable to find tuple index for loop induction variable"; + return std::nullopt; + } + auto update_operation = + computation->root_instruction()->operand(*loop_indvar_tuple_idx); + HloInstruction* loop_indvar = nullptr; + for (auto instr : computation->instructions()) { + if (instr->opcode() == HloOpcode::kGetTupleElement && + instr->operand(0) == computation->parameter_instruction(0) && + instr->tuple_index() == *loop_indvar_tuple_idx) { + loop_indvar = instr; + } + } + if (loop_indvar == nullptr) { + VLOG(3) << "Unable to find get-tuple-element(" + << computation->parameter_instruction(0)->name() + << "), index=" << *loop_indvar_tuple_idx << " in " + << computation->name(); + return std::nullopt; } - return true; + // Extract the offset and update modules and verify that they only take the + // loop iteration counter as parameter. + // The operation we are extracting (update and offset) are from `computation`. + // In the `extract_selector`, we stop at the parameter (tuple) for this + // `computation` or at the loop induction variable and convert that to a + // parameter. If the operation depends on the tuple parameter, then the + // argument to the extracted module will have the shape of a tuple. So, if the + // extracted module has only one parameter and the shape of that parameter is + // same as the loop induction variable, then the operation only depends on the + // loop induction variable. We also have to ensure there are no `partition-id` + // or `replica-id` operations in the extracted module. + auto IsValidModule = + [loop_indvar](std::unique_ptr& module) -> bool { + if (module == nullptr || module->entry_computation()->num_parameters() != 1) + return false; + const HloInstruction* p0 = + module->entry_computation()->parameter_instruction(0); + if (p0->shape() != loop_indvar->shape()) { + VLOG(4) << "Extracted module must depend only on the loop induction " + "variable."; + return false; + }; + return llvm::all_of(module->entry_computation()->instructions(), + [](const HloInstruction* instr) { + return instr->opcode() != HloOpcode::kPartitionId && + instr->opcode() != HloOpcode::kReplicaId; + }); + }; + auto params = computation->parameter_instructions(); + if (params.size() != 1 || !params[0]->shape().IsTuple()) { + VLOG(3) << "While loop parameter is expected to be a tuple."; + return std::nullopt; + } + std::unique_ptr offset_module = ExtractModule( + /*instruction=*/ + idx, /*height=*/-1, + /*extract_selector=*/ + [loop_indvar, params](const HloInstruction* inst) -> bool { + return inst != loop_indvar && llvm::find(params, inst) == params.end(); + }, + /*replace_type_selector=*/ + [](const HloInstruction* inst) -> ReplaceType { + return ReplaceType::kReplaceParam; + }); + std::unique_ptr update_module = ExtractModule( + /*instruction=*/ + update_operation, /*height=*/-1, + /*extract_selector=*/ + [loop_indvar, params](const HloInstruction* inst) -> bool { + return inst != loop_indvar && llvm::find(params, inst) == params.end(); + }, + /*replace_type_selector=*/ + [](const HloInstruction* inst) -> ReplaceType { + return ReplaceType::kReplaceParam; + }); + if (!IsValidModule(offset_module) || !IsValidModule(update_module)) { + return std::nullopt; + } + VLOG(3) << "Successfully generated offset and update modules"; + + std::vector offset_values; + absl::Status status = [&]() -> absl::Status { + HloEvaluator evaluator; + const Literal& init = + whileop->operand(0)->operand(*loop_indvar_tuple_idx)->literal(); + std::unique_ptr updated_value = nullptr; + for (int64_t i = 0; i < *trip_count; i++) { + if (i == 0) { + evaluator.ResetVisitStates(); + TF_ASSIGN_OR_RETURN(offset_values.emplace_back(), + evaluator.Evaluate(*offset_module, {&init})); + CHECK(offset_values.back().shape() == idx->shape()); + evaluator.ResetVisitStates(); + TF_ASSIGN_OR_RETURN(Literal next_update_value, + evaluator.Evaluate(*update_module, {&init})); + updated_value = next_update_value.CloneToUnique(); + } else { + evaluator.ResetVisitStates(); + TF_ASSIGN_OR_RETURN( + offset_values.emplace_back(), + evaluator.Evaluate(*offset_module, {updated_value.get()})); + CHECK(offset_values.back().shape() == idx->shape()); + evaluator.ResetVisitStates(); + TF_ASSIGN_OR_RETURN( + Literal next_update_value, + evaluator.Evaluate(*update_module, {updated_value.get()})); + updated_value = next_update_value.CloneToUnique(); + } + } + VLOG(3) << "Offset values for " << idx->name() << ": " + << absl::StrJoin(offset_values, ",", + [](std::string* out, const Literal& l) { + out->append(l.ToString()); + }); + return absl::OkStatus(); + }(); + if (status.ok()) return offset_values; + return std::nullopt; } -// This returns true for the constants that are handled in the dynamic slice -// fusion runtime. These constants do not force a D2H copy and hence preserve -// the cuda graph. -bool IsHandledConstantForDynamicSliceFusion(const HloInstruction& offset) { - if (auto* cst = DynCast(&offset)) { - switch (cst->shape().element_type()) { - case PrimitiveType::S32: - case PrimitiveType::S64: - case PrimitiveType::U32: - case PrimitiveType::U64: - return true; - default: +// This function takes a while operation and adds a loop iteration counter +// variable as the last parameter in the loop. This is useful, especially +// because the loop induction variable might not be 0,1,2,3... and we need a +// variable of this form to access the array literal for offset. +absl::StatusOr AddLoopIterationParam(HloInstruction* whileop) { + CHECK(whileop->opcode() == HloOpcode::kWhile); + HloComputation* while_body = whileop->while_body(); + HloComputation* while_cond = whileop->while_condition(); + const HloInstruction* while_init = whileop->operand(0); + + // First handle the initial values. + CHECK(while_init->opcode() == HloOpcode::kTuple); + std::vector new_init_operands(while_init->operands().begin(), + while_init->operands().end()); + PrimitiveType indvar_type = + whileop->while_init() + ->operand(*GetLoopInductionVarTupleIdx(whileop)) + ->shape() + .element_type(); + new_init_operands.push_back(whileop->parent()->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0( + whileop->while_init() + ->operand(*GetLoopInductionVarTupleIdx(whileop)) + ->shape() + .element_type(), + 0)), + "zero")); + HloInstruction* new_while_init = whileop->parent()->AddInstruction( + HloInstruction::CreateTuple(new_init_operands)); + HloInstruction* new_whileop = whileop->parent()->AddInstruction( + whileop->CloneWithNewOperands(new_while_init->shape(), {new_while_init})); + if (whileop->IsRoot()) { + absl::InlinedVector tuple_entries; + tuple_entries.reserve(while_init->shape().tuple_shapes_size()); + for (auto i = 0; i < while_init->shape().tuple_shapes_size(); i++) { + tuple_entries.push_back(whileop->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement(new_whileop, i))); + } + HloInstruction* new_whileop_result = whileop->parent()->AddInstruction( + HloInstruction::CreateTuple(tuple_entries)); + TF_RETURN_IF_ERROR( + whileop->parent()->ReplaceInstruction(whileop, new_whileop_result)); + } else { + TF_RETURN_IF_ERROR(whileop->parent()->ReplaceInstructionWithDifferentShape( + whileop, new_whileop)); + } + + // Next, lets handle the condition + while_cond->ReplaceParameter(0, HloInstruction::CreateParameter( + 0, new_while_init->shape(), "new_param")); + + // Next, lets handle the body + HloInstruction* new_body_param = while_body->ReplaceParameter( + 0, + HloInstruction::CreateParameter(0, new_while_init->shape(), "new_param")); + + // Next, update the value of the param inside while op + HloInstruction* gte = while_body->AddInstruction( + HloInstruction::CreateGetTupleElement( + new_body_param, new_while_init->shape().tuple_shapes_size() - 1), + "loop_iteration_count"); + HloInstruction* c1 = while_body->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(indvar_type, 1)), + "one"); + HloInstruction* add = while_body->AddInstruction( + HloInstruction::CreateBinary(gte->shape(), HloOpcode::kAdd, gte, c1), + "updated_loop_iteration_count"); + absl::InlinedVector old_return_tuple_operands = + while_body->root_instruction()->operands(); + std::vector new_return_tuple_operands( + old_return_tuple_operands.begin(), old_return_tuple_operands.end()); + new_return_tuple_operands.push_back(add); + HloInstruction* new_return_tuple = while_body->AddInstruction( + HloInstruction::CreateTuple(new_return_tuple_operands)); + while_body->set_root_instruction(new_return_tuple, true); + return gte; +} + +// This function takes an array literal and gives a constant instruction with +// that literal. +std::unique_ptr GetAsConstantInstruction( + const std::vector& offset_values) { + if (offset_values.empty()) return nullptr; + std::unique_ptr value = + primitive_util::PrimitiveTypeSwitch>( + [&offset_values]( + auto primitive_type_constant) -> std::unique_ptr { + if constexpr (primitive_util::IsIntegralType( + primitive_type_constant)) { + using NativeT = typename primitive_util::PrimitiveTypeToNative< + primitive_type_constant>::type; + + Array constantLiterals({(int64_t)offset_values.size()}); + std::vector valuesAsTy; + valuesAsTy.reserve(offset_values.size()); + for (auto& i : offset_values) { + valuesAsTy.push_back( + static_cast(i.data()[0])); + } + constantLiterals.SetValues(valuesAsTy); + return HloInstruction::CreateConstant( + LiteralUtil::CreateFromArray(constantLiterals)); + } + return nullptr; + }, + offset_values[0].shape().element_type()); + return value; +} + +// This function takes an operation, and a reference to a map of +// {operation: array literals containing their values}. If the operation is a +// dynamic slicing operation, we populate the value map with the values of the +// offsets. This only returns true if it can successfully find values +// corresponding to all the offsets in the `matched_instrs`. If there is a +// single offset for which we cannot find the values, then we do not add +// anything to the value map, and return false. +bool PopulateOffsetValueMap(const HloInstruction* matched_instr, + OffsetValueMap& value_map) { + OffsetValueMap local_value_map; + if (auto dyn_idx_op = DynCast(matched_instr); + dyn_idx_op) { + for (auto indexop : dyn_idx_op->index_operands()) { + if (indexop->IsConstant()) continue; + if (local_value_map.contains(indexop) || value_map.contains(indexop)) + continue; + std::optional> values = GetValues(indexop); + if (values == std::nullopt) return false; + if (values->empty() || !primitive_util::IsIntegralType( + values->at(0).shape().element_type())) { return false; - }; + } + std::transform(values->begin(), values->end(), + std::back_inserter(local_value_map[indexop]), + [](Literal& l) { return std::move(l); }); + } + } + for (auto& [op, values] : local_value_map) { + std::transform(values.begin(), values.end(), + std::back_inserter(value_map[op]), + [](Literal& l) { return std::move(l); }); } - return false; + VLOG(2) << "Received " << local_value_map.size() << " new offsets."; + return true; } -// This checks whether a dynamic index operation has all offsets that are either -// constant or loop iteration offsets. -bool HasConstantOrLoopIterationOffsets( - const HloDynamicIndexInstruction& instr) { - return llvm::all_of(instr.index_operands(), [](const HloInstruction* offset) { - return IsLoopIterationNumber(*offset) || - IsHandledConstantForDynamicSliceFusion(*offset); - }); +// This function takes a list of fusion instructions, and a value map +// {operation: array literal containing its values across iterations}. These +// fusions take the value of offset as a input. So, the value of this offset is +// calculated outside the fusion. This function changes these fusions so that +// the fusion instead only takes the loop iteration number and the offset is +// read from a constant array. This constant array comes from the value map. On +// a high level, the transform looks like: +// +// clang-format off +// +// input-fusion(p0, p1, p2, offset, c0) { +// ds = dynamic-slice(p0, offset, c0, c0) +// gemm = custom-call(ds, p1) +// ROOT dus = dynamic-update-slice(p2, gemm, offset, c0, c0) +// } +// +// changes to +// +// output-fusion(p0, p1, p2, loop_counter, c0) { +// offset_values = constant({2,4,6,8,10}) +// offset_array = dynamic-slice(offset_values, loop_counter), slice_size={1} +// offset = reshape(offset_array) +// ds = dynamic-slice(p0, offset, c0, c0) +// gemm = custom-call(ds, p1) +// ROOT dus = dynamic-update-slice(p2, gemm, offset, c0, c0) +// } +// +// clang-format on +absl::Status ReplaceOffsetCalculationWithArrayAccess( + PtrVec fusions, OffsetValueMap& value_map) { + absl::flat_hash_map loop_iteration_param; + for (auto& [instr, _] : value_map) { + VLOG(2) << "Handling " << instr->name(); + if (!instr->parent()->IsWhileBodyComputation()) { + VLOG(2) << "It is not a while body computation"; + return absl::InternalError( + absl::StrFormat("%s is expected to be a while computation.", + instr->parent()->name())); + } + if (loop_iteration_param.find(instr->parent()) != + loop_iteration_param.end()) { + VLOG(2) << "This was already handled"; + continue; + } + VLOG(2) << "Adding loop iteration param for " << instr->parent()->name(); + TF_ASSIGN_OR_RETURN( + loop_iteration_param[instr->parent()], + AddLoopIterationParam(instr->parent()->WhileCallInstruction())); + } + for (auto fusion_instr : fusions) { + // Check that this fusion operation has something we need to replace: + for (auto maybe_offset : fusion_instr->operands()) { + if (value_map.find(maybe_offset) == value_map.end()) continue; + HloInstruction* loop_counter = + loop_iteration_param[fusion_instr->parent()]; + HloComputation* fusion = fusion_instr->fused_instructions_computation(); + loop_iteration_param[fusion] = + fusion_instr->AddFusionOperand(loop_counter); + break; + } + } + for (auto fusion_instr : fusions) { + absl::flat_hash_map param_replacement_map; + absl::InlinedVector parameters; + HloComputation* fusion_comp = + fusion_instr->fused_instructions_computation(); + for (auto [idx, maybe_offset] : llvm::enumerate(fusion_instr->operands())) { + HloInstruction* offset_param = + fusion_instr->fused_instructions_computation()->parameter_instruction( + idx); + if (value_map.find(maybe_offset) == value_map.end() || + param_replacement_map.contains(offset_param)) + continue; + std::vector& values = value_map.at(maybe_offset); + std::unique_ptr values_as_const_instruction = + GetAsConstantInstruction(values); + if (values_as_const_instruction == nullptr) { + return absl::InternalError( + "Unable to convert offsets into constant array."); + } + HloInstruction* array = fusion_comp->AddInstruction( + std::move(values_as_const_instruction), "offset_values"); + HloInstruction* ds = + fusion_comp->AddInstruction(HloInstruction::CreateDynamicSlice( + ShapeUtil::MakeShape(offset_param->shape().element_type(), {1}), + array, {loop_iteration_param[fusion_comp]}, {1})); + HloInstruction* offset = fusion_comp->AddInstruction( + HloInstruction::CreateReshape(offset_param->shape(), ds), "offset"); + param_replacement_map[offset_param] = offset; + parameters.push_back(offset_param); + } + for (auto param = parameters.rbegin(); param != parameters.rend(); + param++) { + auto offset = param_replacement_map[*param]; + TF_RETURN_IF_ERROR(fusion_comp->ReplaceInstruction(*param, offset)); + } + } + return absl::OkStatus(); } -UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr) { +UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr, + OffsetValueMap& value_map) { UseDefDataflowPaths sliced_operand_paths; // This set is used to avoid duplicates in the matched results. It contains @@ -292,14 +625,9 @@ UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr) { }); if (maybe_slice_instr == std::nullopt) continue; - auto dynamic_index_operation = - DynCast(maybe_slice_instr.value()); - bool valid_slice_found = - slice_found && - ((dynamic_index_operation && - HasConstantOrLoopIterationOffsets(*dynamic_index_operation)) || - (*maybe_slice_instr)->opcode() == HloOpcode::kSlice); - if (valid_slice_found || + bool valid_slice_status = + PopulateOffsetValueMap(*maybe_slice_instr, value_map); + if ((valid_slice_status && slice_found) || processed_instrs.contains(maybe_slice_instr.value())) { // Even in the case of stopping at a match that has been processed, we // still need to add instructions encountered in the sliced operand path @@ -320,7 +648,8 @@ UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr) { // vector. // Each entry contains the sliced paths for that user, i.e. the sequence of ops // following the dataflow from the user itself to the DUS (included). -DefUseDataflowPaths GetSlicedUserPaths(const HloInstruction* instr) { +DefUseDataflowPaths GetSlicedUserPaths(const HloInstruction* instr, + OffsetValueMap& value_map) { DefUseDataflowPaths sliced_user_paths; // This set is used to avoid duplicates in the matched results. It contains // the matched instructions that we have seen so far. @@ -347,12 +676,10 @@ DefUseDataflowPaths GetSlicedUserPaths(const HloInstruction* instr) { }, /*visit_operands=*/false); if (maybe_dus_instr == std::nullopt) return; - auto dynamic_index_operation = - DynCast(maybe_dus_instr.value()); - bool valid_dus_found = - dus_found && dynamic_index_operation && - HasConstantOrLoopIterationOffsets(*dynamic_index_operation); - if (valid_dus_found || processed_instrs.contains(maybe_dus_instr.value())) { + bool valid_slice_status = + PopulateOffsetValueMap(*maybe_dus_instr, value_map); + if ((valid_slice_status && dus_found) || + processed_instrs.contains(maybe_dus_instr.value())) { // Even in the case of stopping at a match that has been processed, we // still need to add instructions encountered in the sliced user path // during the latest traversal. @@ -519,6 +846,8 @@ absl::StatusOr DynamicSliceFusionRewriter::Run( matches_kv; std::vector matches; + OffsetValueMap value_map; + // Collect all potential custom call matches in the non-fusion computations. for (HloComputation* computation : module->computations()) { if (computation->IsFusionComputation()) continue; @@ -526,9 +855,30 @@ absl::StatusOr DynamicSliceFusionRewriter::Run( if ((instr->opcode() == HloOpcode::kReduceScatter && instr->shape().IsArray()) || IsLegacyCublasMatmul(*instr) || IsCustomCall(instr, platform_name_)) { - UseDefDataflowPaths sliced_operand_paths = GetSlicedOperandPaths(instr); + UseDefDataflowPaths sliced_operand_paths = + GetSlicedOperandPaths(instr, value_map); + VLOG(1) << "For operation: " << instr->name() << ", operands: " + << absl::StrJoin( + sliced_operand_paths, ",", + [](std::string* out, const HloInstruction* inst) { + out->append(inst->name()); + }); bool has_sliced_operand_paths = sliced_operand_paths.size() > 1; - DefUseDataflowPaths sliced_user_paths = GetSlicedUserPaths(instr); + DefUseDataflowPaths sliced_user_paths = + GetSlicedUserPaths(instr, value_map); + VLOG(1) << "For operation: " << instr->name() << ", users: " + << absl::StrJoin( + sliced_user_paths, ",", + [](std::string* out, const DefUseDataflowPath& path) { + out->append( + "{" + + absl::StrJoin(path, ",", + [](std::string* out, + const HloInstruction* inst) { + out->append(inst->name()); + }) + + "}"); + }); bool has_sliced_user_paths = absl::c_any_of( sliced_user_paths, [&](auto& sliced_user_path) { return !sliced_user_path.empty(); }); @@ -552,6 +902,8 @@ absl::StatusOr DynamicSliceFusionRewriter::Run( if (matches.empty()) return false; + PtrVec fusions; + for (HloInstruction* hero : matches) { auto& paths = matches_kv[hero]; auto& [sliced_operand_paths, sliced_user_paths] = paths; @@ -580,7 +932,7 @@ absl::StatusOr DynamicSliceFusionRewriter::Run( HloInstruction * fusion, CreateFusionInstruction(module, hero, captures, fusion_body, has_dynamic_slices)); - + fusions.push_back(fusion); HloComputation* parent = hero->parent(); if (fusion->shape().IsTuple()) { TF_RETURN_IF_ERROR(parent->ReplaceInstructionWithDifferentShape( @@ -624,6 +976,9 @@ absl::StatusOr DynamicSliceFusionRewriter::Run( } } + TF_RETURN_IF_ERROR( + ReplaceOffsetCalculationWithArrayAccess(fusions, value_map)); + return true; } diff --git a/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc b/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc index 9a71c9930adc78..9fb49afb1847b2 100644 --- a/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc +++ b/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc @@ -33,6 +33,7 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/stream.h" +#include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" @@ -1857,12 +1858,15 @@ TEST_F(DynamicSliceFusionRewriterTest, ReduceScatterDUSLoopIterationOffset) { })"; const char* expected = R"( // CHECK: %dynamic-slice-fusion{{.*}}{ - // CHECK: {{.+}} = {{.*}}reduce-scatter({{.+}}) - // CHECK: {{.+}} = {{.*}}dynamic-update-slice({{.+}}) + // CHECK: {{.+}} = {{.*}} reduce-scatter({{.+}}) + // CHECK: {{.+}} = s32[128]{0} constant({{.+}}) + // CHECK: {{.+}} = {{.+}} dynamic-slice({{.+}}) + // CHECK: {{.+}} = {{.+}} reshape({{.+}}) + // CHECK: {{.+}} = {{.*}} dynamic-update-slice({{.+}}) // CHECK: } // CHECK: Body{{.+}}{ - // CHECK-NOT: {{.+}} = {{.*}}reduce-scatter({{.+}}) - // CHECK: {{.+}} = {{.+}}fusion({{.+}}), kind=kCustom, calls=%dynamic-slice-fusion{{.*}}"name":"dynamic_address_computation" + // CHECK-NOT: {{.+}} = {{.*}} reduce-scatter({{.+}}) + // CHECK: {{.+}} = {{.+}} fusion({{.+}}), kind=kCustom, calls=%dynamic-slice-fusion{{.*}}"name":"dynamic_address_computation" // CHECK: } )"; RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected); @@ -1881,26 +1885,10 @@ TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmLoopIteration) { bitcast.41 = f16[8,8]{1,0} bitcast(p0) bitcast.42 = f16[8,8]{1,0} bitcast(p1) - custom-call.1 = f16[8,8]{1,0} custom-call(bitcast.41, bitcast.42), custom_call_target="__cublas$gemm", backend_config={"gemm_backend_config":{ - "alpha_real":1, - "beta":0, - "dot_dimension_numbers":{ - "lhs_contracting_dimensions":["1"], - "rhs_contracting_dimensions":["0"], - "lhs_batch_dimensions":[], - "rhs_batch_dimensions":[] - }, - "alpha_imag":0, - "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, - "epilogue":"DEFAULT", - "lhs_stride":"64", - "rhs_stride":"64", - "grad_x":false, - "grad_y":false - }} + custom-call.1 = f16[8,8]{1,0} custom-call(bitcast.41, bitcast.42), custom_call_target="__cublas$gemm" bitcast.43 = f16[1,8,8]{2,1,0} bitcast(custom-call.1) c0 = u32[] constant(0) - c_trip_count = u32[] constant(11) + c_trip_count = u32[] constant(8) compare = pred[] compare(loop_iter, c0), direction=LT add = u32[] add(loop_iter, c_trip_count) offset = u32[] select(compare, add, loop_iter) @@ -1913,7 +1901,7 @@ TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmLoopIteration) { %Cond { %param.1 = (f16[1,8,8]{2,1,0}, f16[1,8,8]{2,1,0}, f16[4,8,8]{2,1,0}, u32[]) parameter(0) %i.1 = u32[] get-tuple-element(%param.1), index=3 - %trip_count = u32[] constant(11) + %trip_count = u32[] constant(8) ROOT %done = pred[] compare(u32[] %i.1, u32[] %trip_count), direction=LT } @@ -1923,19 +1911,32 @@ TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmLoopIteration) { %p2.1 = f16[4,8,8]{2,1,0} parameter(2) %c0.1 = u32[] constant(0) %initial_tuple = tuple(%p0.1, %p1.1, %p2.1, u32[] %c0.1) - ROOT %while = while(%initial_tuple), condition=%Cond, body=%Body, backend_config={"known_trip_count":{"n":"11"}} + ROOT %while = while(%initial_tuple), condition=%Cond, body=%Body, backend_config={"known_trip_count":{"n":"8"}} })"; const char* expected = R"( + // CHECK: %dynamic-slice-fusion{{.*}} { + // CHECK-DAG: %[[p0:.+]] = f16[8,8]{1,0} parameter(0) + // CHECK-DAG: %[[p1:.+]] = f16[8,8]{1,0} parameter(1) + // CHECK-DAG: %[[p2:.+]] = f16[4,8,8]{2,1,0} parameter(2) + // CHECK-DAG: %[[gemm:.+]] = f16[8,8]{1,0} custom-call(%[[p0]], %[[p1]]), custom_call_target="__cublas$gemm" + // CHECK-DAG: %[[bc_gemm:.+]] = f16[1,8,8]{2,1,0} bitcast(%[[gemm]]) + // CHECK-DAG: %[[offset_values:.+]] = u32[8]{0} constant({0, 1, 2, 3, 4, 5, 6, 7}) + // CHECK-DAG: %[[p4:.+]] = u32[] parameter(4) + // CHECK-DAG: %[[offset_as_array:.+]] = u32[1]{0} dynamic-slice(%[[offset_values]], %[[p4]]), dynamic_slice_sizes={1} + // CHECK-DAG: %[[offset:.+]] = u32[] reshape(%[[offset_as_array]]) + // CHECK-DAG: %[[p3:.+]] = u32[] parameter(3) + // CHECK-DAG: ROOT %{{.+}} = f16[4,8,8]{2,1,0} dynamic-update-slice(%[[p2]], %[[bc_gemm]], %[[offset]], %[[p3]], %[[p3]]) + // CHECK: } // CHECK: %Body{{.+}}{ // CHECK: %[[PARAM:.+]] = {{.+}} parameter(0) - // CHECK: %[[LOOP_ITER:.+]] = u32[] get-tuple-element(%[[PARAM]]), index=3 - // CHECK: %[[OFFSET:.+]] = u32[] select({{.+}}) - // CHECK: %[[ADDRESS_COMPUTATION:.+]] = {{.+}} fusion({{.+}}, {{.+}}, {{.+}}, %[[OFFSET]], %{{.+}}), kind=kCustom, calls=%dynamic-slice-fusion, {{.+}}"name":"dynamic_address_computation" - // CHECK: ROOT %tuple = {{.+}} tuple(%{{.+}}, %{{.+}}, %[[ADDRESS_COMPUTATION]], %{{.+}}) + // CHECK: %[[LOOP_ITER:.+]] = u32[] get-tuple-element(%[[PARAM]]), index=4 + // CHECK: %[[ADDRESS_COMPUTATION:.+]] = {{.+}} fusion(%{{.+}}, %loop_iteration_count), kind=kCustom, calls=%dynamic-slice-fusion + // CHECK: ROOT %{{.+}} = {{.+}} tuple(%{{.+}}, %{{.+}}, %[[ADDRESS_COMPUTATION]], %{{.+}}) // CHECK: } // CHECK: ENTRY %test{{.+}}{ - // CHECK: ROOT %{{.+}} = {{.+}} while(%{{.+}}), condition=%{{.+}}, body=%Body{{.*}}, backend_config={"known_trip_count":{"n":"11"}} + // CHECK: %{{.+}} = {{.+}} while(%{{.+}}), condition=%{{.+}}, body=%Body{{.*}}, backend_config={"known_trip_count":{"n":"8"}} + // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}}) } )"; @@ -1984,6 +1985,86 @@ TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmParameterOffset) { std::nullopt); } +TEST_F(DynamicSliceFusionRewriterTest, DUSOffsetAsFunctionOfLoopIteration) { + const char* hlo = R"( + HloModule test_module, replica_count=2 + + add { + a = s64[] parameter(0) + b = s64[] parameter(1) + ROOT add = s64[] add(a, b) + } + + Body { + param = (s64[], s64[16, 32], s64[8, 32]) parameter(0) + i = s64[] get-tuple-element(param), index=0 + dest = s64[16,32] get-tuple-element(param), index=1 + src = s64[8,32] get-tuple-element(param), index=2 + eight = s64[] constant(8) + zero = s64[] constant(0) + thirty_two = s64[] constant(32) + add = s64[] add(eight, i) + add.2 = s64[] subtract(add, thirty_two) + compare = pred[] compare(add, thirty_two), direction=LT + offset = s64[] select(compare, add, add.2) + rs = s64[4,32] reduce-scatter(src), channel_id=1, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add + fusion = s64[16,32] dynamic-update-slice(dest, rs, offset, zero) + one = s64[] constant(1) + i_plus_one = s64[] add(i, one) + ROOT tuple = tuple(i_plus_one, fusion, src) + } + + Cond { + param = (s64[], s64[16,32], s64[8,32]) parameter(0) + loop_iter = s64[] get-tuple-element(param), index=0 + c16 = s64[] constant(16) + ROOT compare = pred[] compare(loop_iter, c16), direction=LT + } + + ENTRY main { + zero = s64[] constant(0) + dest = s64[16,32] parameter(0) + src = s64[8,32] parameter(1) + tuple = tuple(zero, dest, src) + ROOT while = while(tuple), body=Body, condition=Cond + } + )"; + + const char* expected = R"( + // CHECK: %dynamic-slice-fusion{{.*}} { + // CHECK-DAG: %[[p1:.*]] = s64[16,32]{1,0} parameter(1) + // CHECK-DAG: %[[p0:.*]] = s64[8,32]{1,0} parameter(0) + // CHECK-DAG: %[[rs:.+]] = s64[4,32]{1,0} reduce-scatter(s64[8,32]{1,0} %[[p0]]), channel_id=1 + // CHECK-DAG: %[[offset_values:.+]] = s64[16]{0} constant({8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}) + // CHECK-DAG: %[[p3:.+]] = s64[] parameter(3) + // CHECK-DAG: %[[ds:.+]] = s64[1]{0} dynamic-slice(s64[16]{0} %[[offset_values]], s64[] %[[p3]]), dynamic_slice_sizes={1} + // CHECK-DAG: %[[offset:.+]] = s64[] reshape(s64[1]{0} %[[ds]]) + // CHECK-DAG: %[[p2:.+]] = s64[] parameter(2) + // CHECK-DAG: ROOT %{{.+}} = s64[16,32]{1,0} dynamic-update-slice(s64[16,32]{1,0} %[[p1:.*]], s64[4,32]{1,0} %[[rs]], s64[] %[[offset]], s64[] %[[p2]]) + // CHECK: } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(hlo)); + TF_ASSERT_OK_AND_ASSIGN( + auto changed, + RunHloPass(DynamicSliceFusionRewriter("gpu"), module.get())); + EXPECT_TRUE(changed); + std::vector fusions; + for (auto computation : module->computations()) { + if (computation->IsFusionComputation()) { + fusions.push_back(computation); + } + } + ASSERT_EQ(fusions.size(), 1); + const HloComputation* dynamic_slice_fusion = fusions[0]; + TF_ASSERT_OK_AND_ASSIGN( + auto filecheck_match, + RunFileCheck(dynamic_slice_fusion->ToString( + HloPrintOptions{}.set_print_large_constants(true)), + expected)); + EXPECT_TRUE(filecheck_match); +} + TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmLaxScan) { const char* hlo = R"( HloModule lax_scan @@ -1995,68 +2076,71 @@ TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmLaxScan) { // ans = jax.lax.scan(lambda carry, x : (init, x@carry), init, inp) Body { - arg_tuple.15 = (s32[], f32[128,128]{1,0}, f32[128,128,128]{2,1,0}, f32[128,128,128]{2,1,0}, f32[128,128]{1,0}) parameter(0) + arg_tuple.15 = (s32[], f32[8,8]{1,0}, f32[8,8,8]{2,1,0}, f32[8,8,8]{2,1,0}, f32[8,8]{1,0}) parameter(0) get-tuple-element.16 = s32[] get-tuple-element(arg_tuple.15), index=0 constant.21 = s32[] constant(1) add.2 = s32[] add(get-tuple-element.16, constant.21) - get-tuple-element.30 = f32[128,128]{1,0} get-tuple-element(arg_tuple.15), index=4 - get-tuple-element.18 = f32[128,128,128]{2,1,0} get-tuple-element(arg_tuple.15), index=2 - get-tuple-element.19 = f32[128,128,128]{2,1,0} get-tuple-element(arg_tuple.15), index=3 + get-tuple-element.30 = get-tuple-element(arg_tuple.15), index=4 + get-tuple-element.18 = get-tuple-element(arg_tuple.15), index=2 + get-tuple-element.19 = get-tuple-element(arg_tuple.15), index=3 constant.23 = s32[] constant(0) compare.2 = pred[] compare(get-tuple-element.16, constant.23), direction=LT - constant.22 = s32[] constant(128) + constant.22 = s32[] constant(8) add.3 = s32[] add(get-tuple-element.16, constant.22) select.1 = s32[] select(compare.2, add.3, get-tuple-element.16) - dynamic-slice.1 = f32[1,128,128]{2,1,0} dynamic-slice(get-tuple-element.19, select.1, constant.23, constant.23), dynamic_slice_sizes={1,128,128} - bitcast.72 = f32[128,128]{1,0} bitcast(dynamic-slice.1) - get-tuple-element.17 = f32[128,128]{1,0} get-tuple-element(arg_tuple.15), index=1 - custom-call.1 = (f32[128,128]{1,0}, s8[131072]{0}) custom-call(bitcast.72, get-tuple-element.17), custom_call_target="__cublas$gemm" - get-tuple-element = f32[128,128]{1,0} get-tuple-element(custom-call.1), index=0 - bitcast.77 = f32[1,128,128]{2,1,0} bitcast(get-tuple-element) - dynamic-update-slice.1 = f32[128,128,128]{2,1,0} dynamic-update-slice(get-tuple-element.18, bitcast.77, select.1, constant.23, constant.23) + dynamic-slice.1 = f32[1,8,8]{2,1,0} dynamic-slice(get-tuple-element.19, select.1, constant.23, constant.23), dynamic_slice_sizes={1,8,8} + bitcast.72 = f32[8,8]{1,0} bitcast(dynamic-slice.1) + get-tuple-element.17 = f32[8,8]{1,0} get-tuple-element(arg_tuple.15), index=1 + custom-call.1 = (f32[8,8]{1,0}, s8[131072]{0}) custom-call(bitcast.72, get-tuple-element.17), custom_call_target="__cublas$gemm" + get-tuple-element = f32[8,8]{1,0} get-tuple-element(custom-call.1), index=0 + bitcast.77 = f32[1,8,8]{2,1,0} bitcast(get-tuple-element) + dynamic-update-slice.1 = f32[8,8,8]{2,1,0} dynamic-update-slice(get-tuple-element.18, bitcast.77, select.1, constant.23, constant.23) ROOT tuple.38 = tuple(add.2, get-tuple-element.30, dynamic-update-slice.1, get-tuple-element.19, get-tuple-element.30) } // Body Cond { - arg_tuple.40 = (s32[], f32[128,128]{1,0}, f32[128,128,128]{2,1,0}, f32[128,128,128]{2,1,0}, f32[128,128]{1,0}) parameter(0) + arg_tuple.40 = (s32[], f32[8,8]{1,0}, f32[8,8,8]{2,1,0}, f32[8,8,8]{2,1,0}, f32[8,8]{1,0}) parameter(0) get-tuple-element.41 = s32[] get-tuple-element(arg_tuple.40), index=0 - constant.46 = s32[] constant(128) + constant.46 = s32[] constant(8) ROOT compare.3 = pred[] compare(get-tuple-element.41, constant.46), direction=LT } ENTRY main { constant.4 = s32[] constant(0) - Arg_1.2 = f32[128,128]{1,0} parameter(1) + Arg_1.2 = f32[8,8]{1,0} parameter(1) constant.5 = f32[] constant(0) - broadcast.1 = f32[128,128,128]{2,1,0} broadcast(constant.5), dimensions={} - Arg_2.3 = f32[128,128,128]{2,1,0} parameter(2) - Arg_0.1 = f32[128,128]{1,0} parameter(0) + broadcast.1 = f32[8,8,8]{2,1,0} broadcast(constant.5), dimensions={} + Arg_2.3 = f32[8,8,8]{2,1,0} parameter(2) + Arg_0.1 = f32[8,8]{1,0} parameter(0) tuple.7 = tuple(constant.4, Arg_1.2, broadcast.1, Arg_2.3, Arg_0.1) - while.48 = while(tuple.7), condition=Cond, body=Body, backend_config={"known_trip_count":{"n":"128"}} - get-tuple-element.50 = f32[128,128]{1,0} get-tuple-element(while.48), index=1 - get-tuple-element.51 = f32[128,128,128]{2,1,0} get-tuple-element(while.48), index=2 - ROOT tuple.54 = (f32[128,128]{1,0}, f32[128,128,128]{2,1,0}) tuple(get-tuple-element.50, get-tuple-element.51) + while.48 = while(tuple.7), condition=Cond, body=Body, backend_config={"known_trip_count":{"n":"8"}} + get-tuple-element.50 = get-tuple-element(while.48), index=1 + get-tuple-element.51 = get-tuple-element(while.48), index=2 + ROOT tuple.54 = tuple(get-tuple-element.50, get-tuple-element.51) } // main.55 )"; - auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); const char* expected = R"( - // CHECK: %dynamic-slice-fusion{{.*}} {{.+}} { - // CHECK: {{.+}} = {{.+}}dynamic-slice - // CHECK: {{.+}} = {{.+}}custom-call - // CHECK: {{.+}} = {{.+}}dynamic-update-slice - // CHECK: } - // CHECK: %Body{{.+}}{ - // CHECK: %[[PARAM:.+]] = {{.+}} parameter(0) - // CHECK: %[[LOOP_ITER:.+]] = s32[] get-tuple-element(%[[PARAM]]), index=0 - // CHECK: %[[OFFSET:.+]] = s32[] select({{.+}}) - // CHECK: %[[ADDRESS_COMPUTATION:.+]] = {{.+}} fusion({{.+}}, %[[OFFSET]], %{{.+}}), kind=kCustom, calls=%dynamic-slice-fusion{{.+}}"name":"dynamic_address_computation" - // CHECK: %[[GTE:.+]] = {{.+}} get-tuple-element(%[[ADDRESS_COMPUTATION]]), index=0 - // CHECK: ROOT %{{.+}} = {{.+}} tuple(%{{.+}}, %[[GTE]], %{{.+}}) - // CHECK: } - // CHECK: ENTRY %main{{.+}}{ - // CHECK: %{{.+}} = {{.+}} while(%{{.+}}), condition=%{{.+}}, body=%Body{{.*}}, backend_config={"known_trip_count":{"n":"128"}} - // CHECK: } + // CHECK: %dynamic-slice-fusion{{.*}} {{.+}} { + // CHECK-DAG: %[[ITER:.+]] = s32[] parameter(4) + // CHECK-DAG: %[[OFFSET_VALUES:.+]] = s32[8]{0} constant({0, 1, 2, 3, 4, 5, 6, 7}) + // CHECK-DAG: %[[OFFSET_ARR:.+]] = s32[1]{0} dynamic-slice(%[[OFFSET_VALUES]], %[[ITER]]), dynamic_slice_sizes={1} + // CHECK-DAG: %[[OFFSET:.+]] = s32[] reshape(%[[OFFSET_ARR]]) + // CHECK-DAG: %[[DS:.+]] = f32[1,8,8]{2,1,0} dynamic-slice({{.+}}, %[[OFFSET]], {{.+}}), dynamic_slice_sizes={1,8,8} + // CHECK-DAG: %[[BITCAST:.+]] = {{.+}} bitcast(%[[DS]]) + // CHECK-DAG: %[[GEMM:.+]] = {{.+}} custom-call(%[[BITCAST]], {{.+}}), custom_call_target="__cublas$gemm" + // CHECK-DAG: %[[DUS:.+]] = {{.+}} dynamic-update-slice({{.+}}, %[[OFFSET]], {{.+}}) + // CHECK: } + // CHECK: %Body{{.+}}{ + // CHECK-DAG: %[[PARAM:.+]] = {{.+}} parameter(0) + // CHECK-DAG: %[[LOOP_ITER:.+]] = s32[] get-tuple-element(%[[PARAM]]), index=5 + // CHECK-DAG: %[[ADDRESS_COMPUTATION:.+]] = {{.+}} fusion({{.+}}, %{{.+}}, %[[LOOP_ITER]]), kind=kCustom, calls=%dynamic-slice-fusion{{.+}}"name":"dynamic_address_computation" + // CHECK-DAG: %[[GTE:.+]] = {{.+}} get-tuple-element(%[[ADDRESS_COMPUTATION]]), index=0 + // CHECK-DAG: ROOT %{{.+}} = {{.+}} tuple(%{{.+}}, %[[GTE]], %{{.+}}) + // CHECK: } + // CHECK: ENTRY %main{{.+}}{ + // CHECK: %{{.+}} = {{.+}} while(%{{.+}}), condition=%{{.+}}, body=%Body{{.*}}, backend_config={"known_trip_count":{"n":"8"}} + // CHECK: } )"; RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected); } @@ -2144,4 +2228,90 @@ TEST_F(DynamicSliceFusionRewriterTest, ReduceScatterDynamicSlice) { RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected); } +// This is not a requirement from the DynamicSliceFusionRewriter, but this tests +// the current behavior so that the removal of this is intentional. +TEST_F(DynamicSliceFusionRewriterTest, ReplicaIdAndPartitionIdAsOffset) { + const char* hlo = R"( + HloModule test_module, replica_count=2, num_partitions=2 + ENTRY main { + p0 = s32[32,32] parameter(0) + p1 = s32[32,32] parameter(1) + p2 = s32[64,32] parameter(2) + c10 = u32[] constant(10) + c0 = u32[] constant(0) + + // This should get fused. + call1 = s32[32,32] custom-call(p0, p1), custom_call_target="__cublas$gemm" + dus1 = s32[64,32] dynamic-update-slice(p2, call1, c10, c0) + + // This should not get fused. + replica = u32[] replica-id() + call2 = s32[32,32] custom-call(p0, p1), custom_call_target="__cublas$gemm" + dus2 = s32[64,32] dynamic-update-slice(p2, call2, replica, c0) + + // This should not get fused. + partition = u32[] partition-id() + call3 = s32[32,32] custom-call(p0, p1), custom_call_target="__cublas$gemm" + dus3 = s32[64,32] dynamic-update-slice(p2, call3, partition, c0) + ROOT tuple = tuple(dus1, dus2, dus3) + } + )"; + + const char* expected = R"( + // CHECK: dynamic-slice-fusion{{.*}} { + // CHECK: custom-call + // CHECK: dynamic-update-slice + // CHECK: } + // CHECK: ENTRY {{.+}} { + // CHECK-DAG: %{{.+}} = {{.+}} fusion({{.+}}) + // CHECK-DAG: %[[call2:.+]] = {{.+}} custom-call({{.+}}) + // CHECK-DAG: %[[replica:.+]] = u32[] replica-id() + // CHECK-DAG: %{{.+}} = {{.+}} dynamic-update-slice({{.+}} %[[call2]], %[[replica]], {{.+}}) + // CHECK-DAG: %[[partition:.+]] = u32[] partition-id() + // CHECK-DAG: %[[call3:.+]] = {{.+}} custom-call({{.+}}) + // CHECK-DAG: %{{.+}} = {{.+}} dynamic-update-slice({{.+}} %[[call3]], %[[partition]], {{.+}}) + // CHECK-DAG: ROOT {{.+}} = {{.+}} tuple({{.+}}) + // CHECK: } + )"; + + RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected); +} + +TEST_F(DynamicSliceFusionRewriterTest, ParameterOffsetThroughWhileLoop) { + const char* hlo = R"( + HloModule test + Body { + p = (s32[], s32[32,32], s32[32,32], s32[64,32], s32[]) parameter(0) + i = get-tuple-element(p), index=0 + p0 = get-tuple-element(p), index=1 + p1 = get-tuple-element(p), index=2 + p2 = s32[64,32] get-tuple-element(p), index=3 + offset = s32[] get-tuple-element(p), index=4 + c0 = s32[] constant(0) + call = s32[32,32] custom-call(p0, p1), custom_call_target="__cublas$gemm" + dus = s32[64,32] dynamic-update-slice(p2, call, offset, c0) + c1 = s32[] constant(1) + i_plus_one = add(i, c1) + ROOT tuple = tuple(i_plus_one, p1, p0, dus, offset) + } + Cond { + p = (s32[], s32[32,32], s32[32,32], s32[64,32], s32[]) parameter(0) + i = get-tuple-element(p), index=0 + c4 = s32[] constant(4) + ROOT compare = compare(i, c4), direction=LT + } + ENTRY main { + offset = s32[] parameter(0) + p0 = s32[32,32] parameter(1) + p1 = s32[32,32] parameter(2) + p2 = s32[64,32] parameter(3) + c0 = s32[] constant(0) + tuple = tuple(c0, p0, p1, p2, offset) + ROOT while = while(tuple), body=Body, condition=Cond + } + )"; + RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), + std::nullopt); +} + } // namespace xla::gpu