diff --git a/xla/service/BUILD b/xla/service/BUILD index 85ee3d9a35c84..1f92ccc1df0a6 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -1221,12 +1221,14 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/service/spmd/shardy:constants", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:statusor", diff --git a/xla/service/call_inliner.cc b/xla/service/call_inliner.cc index c06d9a0eaa33b..6c5550aa34014 100644 --- a/xla/service/call_inliner.cc +++ b/xla/service/call_inliner.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/match.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -33,6 +34,7 @@ limitations under the License. #include "xla/service/call_graph.h" #include "xla/service/hlo_dce.h" #include "xla/service/hlo_domain_isolator.h" +#include "xla/service/spmd/shardy/constants.h" #include "xla/status_macros.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -135,6 +137,21 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { CallInliner::InlinedInstructionMap subcomputation_hlo_to_new_hlo_; }; +// Specific inlining rules when needing to round-trip from MLIR->HLO->MLIR when +// using Shardy (github.com/openxla/shardy). +// +// - shmap_body: We don't want to inline the bodies of JAX shard maps in order +// to import them into an `sdy.ManualComputationOp`. This is for the MHLO +// round-trip pipeline +// - kManualComputationBodyFuncName: Same as shmap_body except for the SDY +// round-trip pipeline. +bool InlineUnderShardy(HloInstruction* instruction) { + return !(instruction->GetModule()->config().use_shardy_partitioner() && + (absl::StrContains(instruction->to_apply()->name(), "shmap_body") || + absl::StartsWith(instruction->to_apply()->name(), + sdy::kManualComputationBodyFuncName.str()))); +} + } // namespace /* static */ absl::StatusOr @@ -186,7 +203,8 @@ CallInliner::Inline(HloInstruction* call) { bool CallInliner::IsInlineableCallOp(HloInstruction* instruction) const { return instruction->opcode() == HloOpcode::kCall && !instruction->has_backend_config() && - !instruction->parent()->IsAsyncComputation(); + !instruction->parent()->IsAsyncComputation() && + InlineUnderShardy(instruction); } absl::StatusOr CallInliner::Run( diff --git a/xla/service/call_inliner_test.cc b/xla/service/call_inliner_test.cc index ad6ee73eb14e8..f60fe1a0b0338 100644 --- a/xla/service/call_inliner_test.cc +++ b/xla/service/call_inliner_test.cc @@ -16,15 +16,14 @@ limitations under the License. #include "xla/service/call_inliner.h" #include -#include #include #include "absl/log/log.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/hlo_parser.h" #include "xla/shape.h" @@ -377,5 +376,89 @@ TEST_F(CallInlinerTest, InlineCompositeCall) { EXPECT_TRUE((*inst)->frontend_attributes().map().empty()); } +TEST_F(CallInlinerTest, UseShardyMhloToHloShmapBodyNotInlined) { + const char* const hloString = R"( + HloModule jit_f, entry_computation_layout={(f32[8,8]{1,0})->f32[8,8]{1,0}} + + %prefix_shmap_body_suffix.4 (Arg_0.5: f32[1,8]) -> f32[1,8] { + %Arg_0.5 = f32[1,8]{1,0} parameter(0) + ROOT %add.6 = f32[1,8]{1,0} add(f32[1,8]{1,0} %Arg_0.5, f32[1,8]{1,0} %Arg_0.5), metadata={source_file="-" source_line=11} + } + + ENTRY %main.10 (Arg_0.1: f32[8,8]) -> f32[8,8] { + %Arg_0.1 = f32[8,8]{1,0} parameter(0) + %custom-call.2 = f32[8,8]{1,0} custom-call(f32[8,8]{1,0} %Arg_0.1), custom_call_target="Sharding", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=3} + %custom-call.3 = f32[1,8]{1,0} custom-call(f32[8,8]{1,0} %custom-call.2), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={source_file="-" source_line=4} + %call.7 = f32[1,8]{1,0} call(f32[1,8]{1,0} %custom-call.3), to_apply=%prefix_shmap_body_suffix.4 + %custom-call.8 = f32[1,8]{1,0} custom-call(f32[1,8]{1,0} %call.7), custom_call_target="Sharding", sharding={manual}, metadata={source_file="-" source_line=6} + ROOT %custom-call.9 = f32[8,8]{1,0} custom-call(f32[1,8]{1,0} %custom-call.8), custom_call_target="SPMDShardToFullShape", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=7} + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hloString)); + module->mutable_config().set_use_shardy_partitioner(true); + TF_ASSERT_OK_AND_ASSIGN(bool changed, CallInliner().Run(module.get())); + VLOG(1) << module->ToString(); + // The single call in the module is not inlined. + EXPECT_FALSE(changed); + + HloInstruction* call = FindInstruction(module.get(), xla::HloOpcode::kCall); + EXPECT_NE(call, nullptr); + EXPECT_TRUE(call->has_to_apply()); + EXPECT_EQ(call->to_apply()->name(), "prefix_shmap_body_suffix.4"); +} + +// Don't inline when the name starts with "xla.sdy.manual_computation_body". +TEST_F(CallInlinerTest, UseShardManualComputationBodyNotInlined) { + const char* const hloString = R"( + HloModule jit_f, entry_computation_layout={(f32[8,8]{1,0})->f32[8,8]{1,0}} + + %xla.sdy.manual_computation_body.4 (Arg_0.5: f32[1,8]) -> f32[1,8] { + %Arg_0.5 = f32[1,8]{1,0} parameter(0) + ROOT %add.6 = f32[1,8]{1,0} add(f32[1,8]{1,0} %Arg_0.5, f32[1,8]{1,0} %Arg_0.5), metadata={source_file="-" source_line=11} + } + + ENTRY %main.10 (Arg_0.1: f32[8,8]) -> f32[8,8] { + %Arg_0.1 = f32[8,8]{1,0} parameter(0) + %custom-call.3 = f32[1,8]{1,0} custom-call(f32[8,8]{1,0} %Arg_0.1), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={source_file="-" source_line=4} + %call.7 = f32[1,8]{1,0} call(f32[1,8]{1,0} %custom-call.3), to_apply=%xla.sdy.manual_computation_body.4 + ROOT %custom-call.9 = f32[8,8]{1,0} custom-call(f32[1,8]{1,0} %call.7), custom_call_target="SPMDShardToFullShape", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=7} + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hloString)); + module->mutable_config().set_use_shardy_partitioner(true); + TF_ASSERT_OK_AND_ASSIGN(bool changed, CallInliner().Run(module.get())); + VLOG(1) << module->ToString(); + // The single call in the module is not inlined. + EXPECT_FALSE(changed); + + HloInstruction* call = FindInstruction(module.get(), xla::HloOpcode::kCall); + EXPECT_NE(call, nullptr); + EXPECT_TRUE(call->has_to_apply()); + EXPECT_EQ(call->to_apply()->name(), "xla.sdy.manual_computation_body.4"); +} + +// Inliner only checks if the name of the function has +// "xla.sdy.manual_computation_body" a prefix, not if it contains it. +TEST_F(CallInlinerTest, UseShardManualComputationBodyInlined) { + const char* const hloString = R"( + HloModule jit_f, entry_computation_layout={(f32[8,8]{1,0})->f32[8,8]{1,0}} + + %prefix_xla.sdy.manual_computation_body.4 (Arg_0.5: f32[1,8]) -> f32[1,8] { + %Arg_0.5 = f32[1,8]{1,0} parameter(0) + ROOT %add.6 = f32[1,8]{1,0} add(f32[1,8]{1,0} %Arg_0.5, f32[1,8]{1,0} %Arg_0.5), metadata={source_file="-" source_line=11} + } + + ENTRY %main.10 (Arg_0.1: f32[8,8]) -> f32[8,8] { + %Arg_0.1 = f32[8,8]{1,0} parameter(0) + %custom-call.3 = f32[1,8]{1,0} custom-call(f32[8,8]{1,0} %Arg_0.1), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={source_file="-" source_line=4} + %call.7 = f32[1,8]{1,0} call(f32[1,8]{1,0} %custom-call.3), to_apply=%prefix_xla.sdy.manual_computation_body.4 + ROOT %custom-call.9 = f32[8,8]{1,0} custom-call(f32[1,8]{1,0} %call.7), custom_call_target="SPMDShardToFullShape", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=7} + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hloString)); + module->mutable_config().set_use_shardy_partitioner(true); + TF_ASSERT_OK_AND_ASSIGN(bool changed, CallInliner().Run(module.get())); + VLOG(1) << module->ToString(); + // Will be inlined. + EXPECT_TRUE(changed); +} + } // namespace } // namespace xla diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 70291be472ec8..c782a9e4528a1 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -1465,7 +1465,6 @@ cc_library( "//xla/service/gpu/transforms:triton_fusion_numerics_verifier", "//xla/service/gpu/transforms:windowed_einsum_handler", "//xla/service/llvm_ir:llvm_util", - "//xla/service/spmd/shardy:shardy_call_inliner", "//xla/service/spmd:collective_permute_motion", "//xla/service:algebraic_simplifier", "//xla/service:all_gather_broadcast_reorder", diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index afb1002cc1946..fb1b8fd0d50ab 100644 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -224,7 +224,6 @@ limitations under the License. #include "xla/service/slice_sinker.h" #include "xla/service/slow_operation_alarm.h" #include "xla/service/sort_simplifier.h" -#include "xla/service/spmd/shardy/shardy_call_inliner.h" #include "xla/service/stable_sort_expander.h" #include "xla/service/stochastic_convert_decomposer.h" #include "xla/service/sub_byte_normalization.h" @@ -552,7 +551,7 @@ absl::Status RunPreSPMDPartitionerPasses(HloModule* hlo_module) { // passes. pre_spmd_pipeline.AddPass(); pre_spmd_pipeline.AddPass(); - pre_spmd_pipeline.AddPass(); + pre_spmd_pipeline.AddPass(); pre_spmd_pipeline.AddPass(); pre_spmd_pipeline.AddPass(); @@ -709,7 +708,7 @@ absl::Status RunOptimizationPasses( pipeline.AddPass(); // TODO(b/64094172): make Call work on GPU instead of inlining. - pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(); @@ -1594,7 +1593,7 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( options.key_value_store, gpu_target_config.device_description.runtime_version())); // Inline back the calls which have better performance with cuBLAS. - pipeline.AddPass(); + pipeline.AddPass(); // TODO(tdanyluk): Apply CublasPadForGemms to the cuBLAS GEMMs generated // here for possibly better cuBLAS performance. AddGemmRewriterPasses(pipeline, debug_options, gpu_version, diff --git a/xla/service/spmd/shardy/BUILD b/xla/service/spmd/shardy/BUILD index f516384ada212..74dbf1f3d445c 100644 --- a/xla/service/spmd/shardy/BUILD +++ b/xla/service/spmd/shardy/BUILD @@ -23,31 +23,6 @@ package_group( ], ) -cc_library( - name = "shardy_call_inliner", - srcs = ["shardy_call_inliner.cc"], - hdrs = ["shardy_call_inliner.h"], - deps = [ - "//xla/hlo/ir:hlo", - "//xla/service:call_inliner", - "//xla/service/spmd/shardy:constants", - "@com_google_absl//absl/strings", - ], -) - -xla_cc_test( - name = "shardy_call_inliner_test", - srcs = ["shardy_call_inliner_test.cc"], - deps = [ - ":shardy_call_inliner", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "@com_google_absl//absl/log", - "@com_google_googletest//:gtest_main", - "@tsl//tsl/platform:statusor", - ], -) - cc_library( name = "shardy_xla_pass", srcs = ["shardy_xla_pass.cc"], diff --git a/xla/service/spmd/shardy/shardy_call_inliner.cc b/xla/service/spmd/shardy/shardy_call_inliner.cc deleted file mode 100644 index 2de735c98ecbb..0000000000000 --- a/xla/service/spmd/shardy/shardy_call_inliner.cc +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/spmd/shardy/shardy_call_inliner.h" - -#include "absl/strings/match.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/call_inliner.h" -#include "xla/service/spmd/shardy/constants.h" - -namespace xla { - -bool ShardyCallInliner::IsInlineableCallOp(HloInstruction* instruction) const { - return CallInliner::IsInlineableCallOp(instruction) && - !(instruction->GetModule()->config().use_shardy_partitioner() && - (absl::StrContains(instruction->to_apply()->name(), "shmap_body") || - absl::StartsWith(instruction->to_apply()->name(), - sdy::kManualComputationBodyFuncName))); -} - -} // namespace xla diff --git a/xla/service/spmd/shardy/shardy_call_inliner.h b/xla/service/spmd/shardy/shardy_call_inliner.h deleted file mode 100644 index 9dbc52682a60a..0000000000000 --- a/xla/service/spmd/shardy/shardy_call_inliner.h +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_SPMD_SHARDY_SHARDY_CALL_INLINER_H_ -#define XLA_SERVICE_SPMD_SHARDY_SHARDY_CALL_INLINER_H_ - -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/call_inliner.h" - -namespace xla { - -// The same as CallInliner, except as part of -// go/jax-shmap -> `sdy.ManualComputationOp` importing, we require the pattern -// in MHLO: -// ``` -// %shard_arg0_0 = custom_call @Sharding(%0) -// %shard_arg0_1 = custom_call @SPMDFullToShardShape(%shard_arg0_0) -// ... -// %shard_argN_0 = custom_call @Sharding(%N) -// %shard_argN_1 = custom_call @SPMDFullToShardShape(%shard_argN_0) -// -// %shard_result0, ..., %shard_resultN = func.call @shmap_body(%shard_arg0_1, -// ..., -// %shard_argN_1) -// -// %shard_result0_0 = custom_call @Sharding(%shard_result0) -// %shard_result0_1 = custom_call @SPMDShardToFullShape(%shard_result0_0) -// ... -// %shard_resultN_0 = custom_call @Sharding(%shard_resultN) -// %shard_resultN_1 = custom_call @SPMDShardToFullShape(%shard_resultN_0) -// ``` -// We specifically match on the `func.call @shmap_body` since we want to inline -// the body of that function into the `ManualComputationOp` body. So this makes -// sure we inline all functions except for the shmap_body's when using -// Shardy. When Shardy is disabled, then we have the same behavior as -// CallInliner. -// -// TODO(bartchr): Move the logic in here into the regular XLA `CallInliner`. -// Shardy is now proven out so we should have the parent `CallInliner` handle -// this. -class ShardyCallInliner : public CallInliner { - public: - using CallInliner::CallInliner; - absl::string_view name() const override { return "shardy-call-inliner"; } - - bool IsInlineableCallOp(HloInstruction* instruction) const override; -}; - -} // namespace xla - -#endif // XLA_SERVICE_SPMD_SHARDY_SHARDY_CALL_INLINER_H_ diff --git a/xla/service/spmd/shardy/shardy_call_inliner_test.cc b/xla/service/spmd/shardy/shardy_call_inliner_test.cc deleted file mode 100644 index b2055e59d75c3..0000000000000 --- a/xla/service/spmd/shardy/shardy_call_inliner_test.cc +++ /dev/null @@ -1,115 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/spmd/shardy/shardy_call_inliner.h" - -#include -#include "absl/log/log.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/tests/hlo_test_base.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace sdy { - -using ShardyCallInlinerTest = xla::HloTestBase; - -TEST_F(ShardyCallInlinerTest, MhloToHloShmapBodyNotInlined) { - const char* const hloString = R"( - HloModule jit_f, entry_computation_layout={(f32[8,8]{1,0})->f32[8,8]{1,0}} - - %prefix_shmap_body_suffix.4 (Arg_0.5: f32[1,8]) -> f32[1,8] { - %Arg_0.5 = f32[1,8]{1,0} parameter(0) - ROOT %add.6 = f32[1,8]{1,0} add(f32[1,8]{1,0} %Arg_0.5, f32[1,8]{1,0} %Arg_0.5), metadata={source_file="-" source_line=11} - } - - ENTRY %main.10 (Arg_0.1: f32[8,8]) -> f32[8,8] { - %Arg_0.1 = f32[8,8]{1,0} parameter(0) - %custom-call.2 = f32[8,8]{1,0} custom-call(f32[8,8]{1,0} %Arg_0.1), custom_call_target="Sharding", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=3} - %custom-call.3 = f32[1,8]{1,0} custom-call(f32[8,8]{1,0} %custom-call.2), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={source_file="-" source_line=4} - %call.7 = f32[1,8]{1,0} call(f32[1,8]{1,0} %custom-call.3), to_apply=%prefix_shmap_body_suffix.4 - %custom-call.8 = f32[1,8]{1,0} custom-call(f32[1,8]{1,0} %call.7), custom_call_target="Sharding", sharding={manual}, metadata={source_file="-" source_line=6} - ROOT %custom-call.9 = f32[8,8]{1,0} custom-call(f32[1,8]{1,0} %custom-call.8), custom_call_target="SPMDShardToFullShape", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=7} - })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hloString)); - module->mutable_config().set_use_shardy_partitioner(true); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyCallInliner().Run(module.get())); - VLOG(1) << module->ToString(); - // The single call in the module is not inlined. - EXPECT_FALSE(changed); - - HloInstruction* call = FindInstruction(module.get(), xla::HloOpcode::kCall); - EXPECT_NE(call, nullptr); - EXPECT_TRUE(call->has_to_apply()); - EXPECT_EQ(call->to_apply()->name(), "prefix_shmap_body_suffix.4"); -} - -// Don't inline when the name starts with "xla.sdy.manual_computation_body". -TEST_F(ShardyCallInlinerTest, ManualComputationBodyNotInlined) { - const char* const hloString = R"( - HloModule jit_f, entry_computation_layout={(f32[8,8]{1,0})->f32[8,8]{1,0}} - - %xla.sdy.manual_computation_body.4 (Arg_0.5: f32[1,8]) -> f32[1,8] { - %Arg_0.5 = f32[1,8]{1,0} parameter(0) - ROOT %add.6 = f32[1,8]{1,0} add(f32[1,8]{1,0} %Arg_0.5, f32[1,8]{1,0} %Arg_0.5), metadata={source_file="-" source_line=11} - } - - ENTRY %main.10 (Arg_0.1: f32[8,8]) -> f32[8,8] { - %Arg_0.1 = f32[8,8]{1,0} parameter(0) - %custom-call.3 = f32[1,8]{1,0} custom-call(f32[8,8]{1,0} %Arg_0.1), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={source_file="-" source_line=4} - %call.7 = f32[1,8]{1,0} call(f32[1,8]{1,0} %custom-call.3), to_apply=%xla.sdy.manual_computation_body.4 - ROOT %custom-call.9 = f32[8,8]{1,0} custom-call(f32[1,8]{1,0} %call.7), custom_call_target="SPMDShardToFullShape", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=7} - })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hloString)); - module->mutable_config().set_use_shardy_partitioner(true); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyCallInliner().Run(module.get())); - VLOG(1) << module->ToString(); - // The single call in the module is not inlined. - EXPECT_FALSE(changed); - - HloInstruction* call = FindInstruction(module.get(), xla::HloOpcode::kCall); - EXPECT_NE(call, nullptr); - EXPECT_TRUE(call->has_to_apply()); - EXPECT_EQ(call->to_apply()->name(), "xla.sdy.manual_computation_body.4"); -} - -// Inliner only checks if the name of the function has -// "xla.sdy.manual_computation_body" a prefix, not if it contains it. -TEST_F(ShardyCallInlinerTest, ManualComputationBodyInlined) { - const char* const hloString = R"( - HloModule jit_f, entry_computation_layout={(f32[8,8]{1,0})->f32[8,8]{1,0}} - - %prefix_xla.sdy.manual_computation_body.4 (Arg_0.5: f32[1,8]) -> f32[1,8] { - %Arg_0.5 = f32[1,8]{1,0} parameter(0) - ROOT %add.6 = f32[1,8]{1,0} add(f32[1,8]{1,0} %Arg_0.5, f32[1,8]{1,0} %Arg_0.5), metadata={source_file="-" source_line=11} - } - - ENTRY %main.10 (Arg_0.1: f32[8,8]) -> f32[8,8] { - %Arg_0.1 = f32[8,8]{1,0} parameter(0) - %custom-call.3 = f32[1,8]{1,0} custom-call(f32[8,8]{1,0} %Arg_0.1), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={source_file="-" source_line=4} - %call.7 = f32[1,8]{1,0} call(f32[1,8]{1,0} %custom-call.3), to_apply=%prefix_xla.sdy.manual_computation_body.4 - ROOT %custom-call.9 = f32[8,8]{1,0} custom-call(f32[1,8]{1,0} %call.7), custom_call_target="SPMDShardToFullShape", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=7} - })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hloString)); - module->mutable_config().set_use_shardy_partitioner(true); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyCallInliner().Run(module.get())); - VLOG(1) << module->ToString(); - // Will be inlined. - EXPECT_TRUE(changed); -} - -} // namespace sdy -} // namespace xla