diff --git a/xla/service/BUILD b/xla/service/BUILD index 5f1c0407d99fd0..89ce2748243e6a 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -1218,6 +1218,7 @@ cc_library( ":hlo_domain_isolator", "//xla:status_macros", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "@com_google_absl//absl/container:flat_hash_map", diff --git a/xla/service/call_inliner.cc b/xla/service/call_inliner.cc index a879e560c1cd17..ff7c882dd4a5ab 100644 --- a/xla/service/call_inliner.cc +++ b/xla/service/call_inliner.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/service/call_inliner.h" #include +#include #include #include @@ -34,6 +35,7 @@ limitations under the License. #include "xla/service/hlo_domain_isolator.h" #include "xla/status_macros.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -152,6 +154,29 @@ CallInliner::Inline(HloInstruction* call) { const auto& callees = call->called_computations(); TF_RET_CHECK(callees.size() == 1); HloComputation* callee = callees[0]; + + // Propagate the frontend attributes related to fusion from the call to the + // inlined instructions. + if (call->has_frontend_attributes()) { + const FrontendAttributes& call_attributes = call->frontend_attributes(); + std::string has_fuse = + call_attributes.map().contains("MUST_FUSE") ? "MUST_FUSE" + : call_attributes.map().contains("MAXIMAL_FUSE") ? "MAXIMAL_FUSE" + : ""; + if (!has_fuse.empty()) { + for (auto instruction : callee->instructions()) { + // Do so for only fusible instructions. + if (instruction->IsFusible()) { + FrontendAttributes frontend_attributes = + instruction->frontend_attributes(); + frontend_attributes.mutable_map()->insert( + {has_fuse, call_attributes.map().at(has_fuse)}); + instruction->set_frontend_attributes(frontend_attributes); + } + } + } + } + // We visit the callee, cloning its body into its caller. SubcomputationInsertionVisitor visitor(call); TF_RETURN_IF_ERROR(callee->Accept(&visitor)); @@ -160,7 +185,6 @@ CallInliner::Inline(HloInstruction* call) { bool CallInliner::IsInlineableCallOp(HloInstruction* instruction) const { return instruction->opcode() == HloOpcode::kCall && - !instruction->has_backend_config() && !instruction->parent()->IsAsyncComputation(); }