Skip to content

Commit

Permalink
#sdy Merge XLA CallInliner and ShardyCallInliner.
Browse files Browse the repository at this point in the history
Now that Shardy will now be fully integrated, we should delete the `ShardyCallInliner` and update `CallInliner` to look for what `ShardyCallInliner` checks for. We've had two bugs because of this thus far.

PiperOrigin-RevId: 681791276
  • Loading branch information
bartchr808 authored and Google-ML-Automation committed Oct 3, 2024
1 parent 25a0df2 commit e5c780d
Show file tree
Hide file tree
Showing 9 changed files with 109 additions and 244 deletions.
2 changes: 2 additions & 0 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1221,12 +1221,14 @@ cc_library(
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/pass:hlo_pass",
"//xla/service/spmd/shardy:constants",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
Expand Down
20 changes: 19 additions & 1 deletion xla/service/call_inliner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ limitations under the License.
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
#include "xla/hlo/ir/hlo_instruction.h"
Expand All @@ -33,6 +34,7 @@ limitations under the License.
#include "xla/service/call_graph.h"
#include "xla/service/hlo_dce.h"
#include "xla/service/hlo_domain_isolator.h"
#include "xla/service/spmd/shardy/constants.h"
#include "xla/status_macros.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
Expand Down Expand Up @@ -135,6 +137,21 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
CallInliner::InlinedInstructionMap subcomputation_hlo_to_new_hlo_;
};

// Specific inlining rules when needing to round-trip from MLIR->HLO->MLIR when
// using Shardy (github.com/openxla/shardy).
//
// - shmap_body: We don't want to inline the bodies of JAX shard maps in order
// to import them into an `sdy.ManualComputationOp`. This is for the MHLO
// round-trip pipeline
// - kManualComputationBodyFuncName: Same as shmap_body except for the SDY
// round-trip pipeline.
bool InlineUnderShardy(HloInstruction* instruction) {
return !(instruction->GetModule()->config().use_shardy_partitioner() &&
(absl::StrContains(instruction->to_apply()->name(), "shmap_body") ||
absl::StartsWith(instruction->to_apply()->name(),
sdy::kManualComputationBodyFuncName.str())));
}

} // namespace

/* static */ absl::StatusOr<CallInliner::InlinedInstructionMap>
Expand Down Expand Up @@ -186,7 +203,8 @@ CallInliner::Inline(HloInstruction* call) {
bool CallInliner::IsInlineableCallOp(HloInstruction* instruction) const {
return instruction->opcode() == HloOpcode::kCall &&
!instruction->has_backend_config() &&
!instruction->parent()->IsAsyncComputation();
!instruction->parent()->IsAsyncComputation() &&
InlineUnderShardy(instruction);
}

absl::StatusOr<bool> CallInliner::Run(
Expand Down
87 changes: 85 additions & 2 deletions xla/service/call_inliner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,14 @@ limitations under the License.
#include "xla/service/call_inliner.h"

#include <cstdint>
#include <memory>
#include <string>

#include "absl/log/log.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/utils/hlo_matchers.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/service/hlo_parser.h"
#include "xla/shape.h"
Expand Down Expand Up @@ -377,5 +376,89 @@ TEST_F(CallInlinerTest, InlineCompositeCall) {
EXPECT_TRUE((*inst)->frontend_attributes().map().empty());
}

TEST_F(CallInlinerTest, UseShardyMhloToHloShmapBodyNotInlined) {
const char* const hloString = R"(
HloModule jit_f, entry_computation_layout={(f32[8,8]{1,0})->f32[8,8]{1,0}}
%prefix_shmap_body_suffix.4 (Arg_0.5: f32[1,8]) -> f32[1,8] {
%Arg_0.5 = f32[1,8]{1,0} parameter(0)
ROOT %add.6 = f32[1,8]{1,0} add(f32[1,8]{1,0} %Arg_0.5, f32[1,8]{1,0} %Arg_0.5), metadata={source_file="-" source_line=11}
}
ENTRY %main.10 (Arg_0.1: f32[8,8]) -> f32[8,8] {
%Arg_0.1 = f32[8,8]{1,0} parameter(0)
%custom-call.2 = f32[8,8]{1,0} custom-call(f32[8,8]{1,0} %Arg_0.1), custom_call_target="Sharding", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=3}
%custom-call.3 = f32[1,8]{1,0} custom-call(f32[8,8]{1,0} %custom-call.2), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={source_file="-" source_line=4}
%call.7 = f32[1,8]{1,0} call(f32[1,8]{1,0} %custom-call.3), to_apply=%prefix_shmap_body_suffix.4
%custom-call.8 = f32[1,8]{1,0} custom-call(f32[1,8]{1,0} %call.7), custom_call_target="Sharding", sharding={manual}, metadata={source_file="-" source_line=6}
ROOT %custom-call.9 = f32[8,8]{1,0} custom-call(f32[1,8]{1,0} %custom-call.8), custom_call_target="SPMDShardToFullShape", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=7}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hloString));
module->mutable_config().set_use_shardy_partitioner(true);
TF_ASSERT_OK_AND_ASSIGN(bool changed, CallInliner().Run(module.get()));
VLOG(1) << module->ToString();
// The single call in the module is not inlined.
EXPECT_FALSE(changed);

HloInstruction* call = FindInstruction(module.get(), xla::HloOpcode::kCall);
EXPECT_NE(call, nullptr);
EXPECT_TRUE(call->has_to_apply());
EXPECT_EQ(call->to_apply()->name(), "prefix_shmap_body_suffix.4");
}

// Don't inline when the name starts with "xla.sdy.manual_computation_body".
TEST_F(CallInlinerTest, UseShardManualComputationBodyNotInlined) {
const char* const hloString = R"(
HloModule jit_f, entry_computation_layout={(f32[8,8]{1,0})->f32[8,8]{1,0}}
%xla.sdy.manual_computation_body.4 (Arg_0.5: f32[1,8]) -> f32[1,8] {
%Arg_0.5 = f32[1,8]{1,0} parameter(0)
ROOT %add.6 = f32[1,8]{1,0} add(f32[1,8]{1,0} %Arg_0.5, f32[1,8]{1,0} %Arg_0.5), metadata={source_file="-" source_line=11}
}
ENTRY %main.10 (Arg_0.1: f32[8,8]) -> f32[8,8] {
%Arg_0.1 = f32[8,8]{1,0} parameter(0)
%custom-call.3 = f32[1,8]{1,0} custom-call(f32[8,8]{1,0} %Arg_0.1), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={source_file="-" source_line=4}
%call.7 = f32[1,8]{1,0} call(f32[1,8]{1,0} %custom-call.3), to_apply=%xla.sdy.manual_computation_body.4
ROOT %custom-call.9 = f32[8,8]{1,0} custom-call(f32[1,8]{1,0} %call.7), custom_call_target="SPMDShardToFullShape", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=7}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hloString));
module->mutable_config().set_use_shardy_partitioner(true);
TF_ASSERT_OK_AND_ASSIGN(bool changed, CallInliner().Run(module.get()));
VLOG(1) << module->ToString();
// The single call in the module is not inlined.
EXPECT_FALSE(changed);

HloInstruction* call = FindInstruction(module.get(), xla::HloOpcode::kCall);
EXPECT_NE(call, nullptr);
EXPECT_TRUE(call->has_to_apply());
EXPECT_EQ(call->to_apply()->name(), "xla.sdy.manual_computation_body.4");
}

// Inliner only checks if the name of the function has
// "xla.sdy.manual_computation_body" a prefix, not if it contains it.
TEST_F(CallInlinerTest, UseShardManualComputationBodyInlined) {
const char* const hloString = R"(
HloModule jit_f, entry_computation_layout={(f32[8,8]{1,0})->f32[8,8]{1,0}}
%prefix_xla.sdy.manual_computation_body.4 (Arg_0.5: f32[1,8]) -> f32[1,8] {
%Arg_0.5 = f32[1,8]{1,0} parameter(0)
ROOT %add.6 = f32[1,8]{1,0} add(f32[1,8]{1,0} %Arg_0.5, f32[1,8]{1,0} %Arg_0.5), metadata={source_file="-" source_line=11}
}
ENTRY %main.10 (Arg_0.1: f32[8,8]) -> f32[8,8] {
%Arg_0.1 = f32[8,8]{1,0} parameter(0)
%custom-call.3 = f32[1,8]{1,0} custom-call(f32[8,8]{1,0} %Arg_0.1), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={source_file="-" source_line=4}
%call.7 = f32[1,8]{1,0} call(f32[1,8]{1,0} %custom-call.3), to_apply=%prefix_xla.sdy.manual_computation_body.4
ROOT %custom-call.9 = f32[8,8]{1,0} custom-call(f32[1,8]{1,0} %call.7), custom_call_target="SPMDShardToFullShape", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=7}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hloString));
module->mutable_config().set_use_shardy_partitioner(true);
TF_ASSERT_OK_AND_ASSIGN(bool changed, CallInliner().Run(module.get()));
VLOG(1) << module->ToString();
// Will be inlined.
EXPECT_TRUE(changed);
}

} // namespace
} // namespace xla
1 change: 0 additions & 1 deletion xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1465,7 +1465,6 @@ cc_library(
"//xla/service/gpu/transforms:triton_fusion_numerics_verifier",
"//xla/service/gpu/transforms:windowed_einsum_handler",
"//xla/service/llvm_ir:llvm_util",
"//xla/service/spmd/shardy:shardy_call_inliner",
"//xla/service/spmd:collective_permute_motion",
"//xla/service:algebraic_simplifier",
"//xla/service:all_gather_broadcast_reorder",
Expand Down
7 changes: 3 additions & 4 deletions xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,6 @@ limitations under the License.
#include "xla/service/slice_sinker.h"
#include "xla/service/slow_operation_alarm.h"
#include "xla/service/sort_simplifier.h"
#include "xla/service/spmd/shardy/shardy_call_inliner.h"
#include "xla/service/stable_sort_expander.h"
#include "xla/service/stochastic_convert_decomposer.h"
#include "xla/service/sub_byte_normalization.h"
Expand Down Expand Up @@ -552,7 +551,7 @@ absl::Status RunPreSPMDPartitionerPasses(HloModule* hlo_module) {
// passes.
pre_spmd_pipeline.AddPass<CuDnnCustomCallConverter>();
pre_spmd_pipeline.AddPass<ConvertMemoryPlacementToInternalAnnotations>();
pre_spmd_pipeline.AddPass<ShardyCallInliner>();
pre_spmd_pipeline.AddPass<CallInliner>();
pre_spmd_pipeline.AddPass<ZeroSizedHloElimination>();
pre_spmd_pipeline.AddPass<ConditionalCanonicalizer>();

Expand Down Expand Up @@ -709,7 +708,7 @@ absl::Status RunOptimizationPasses(
pipeline.AddPass<DynamicIndexSplitter>();

// TODO(b/64094172): make Call work on GPU instead of inlining.
pipeline.AddPass<ShardyCallInliner>();
pipeline.AddPass<CallInliner>();

pipeline.AddPass<StochasticConvertDecomposer>();

Expand Down Expand Up @@ -1594,7 +1593,7 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
options.key_value_store,
gpu_target_config.device_description.runtime_version()));
// Inline back the calls which have better performance with cuBLAS.
pipeline.AddPass<ShardyCallInliner>();
pipeline.AddPass<CallInliner>();
// TODO(tdanyluk): Apply CublasPadForGemms to the cuBLAS GEMMs generated
// here for possibly better cuBLAS performance.
AddGemmRewriterPasses(pipeline, debug_options, gpu_version,
Expand Down
25 changes: 0 additions & 25 deletions xla/service/spmd/shardy/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -23,31 +23,6 @@ package_group(
],
)

cc_library(
name = "shardy_call_inliner",
srcs = ["shardy_call_inliner.cc"],
hdrs = ["shardy_call_inliner.h"],
deps = [
"//xla/hlo/ir:hlo",
"//xla/service:call_inliner",
"//xla/service/spmd/shardy:constants",
"@com_google_absl//absl/strings",
],
)

xla_cc_test(
name = "shardy_call_inliner_test",
srcs = ["shardy_call_inliner_test.cc"],
deps = [
":shardy_call_inliner",
"//xla/hlo/ir:hlo",
"//xla/tests:hlo_test_base",
"@com_google_absl//absl/log",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:statusor",
],
)

cc_library(
name = "shardy_xla_pass",
srcs = ["shardy_xla_pass.cc"],
Expand Down
33 changes: 0 additions & 33 deletions xla/service/spmd/shardy/shardy_call_inliner.cc

This file was deleted.

63 changes: 0 additions & 63 deletions xla/service/spmd/shardy/shardy_call_inliner.h

This file was deleted.

Loading

0 comments on commit e5c780d

Please sign in to comment.