Skip to content

Commit

Permalink
Move BatchedGatherScatterNormalizer from pre-SPMD to post-SPMD.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 695816113
  • Loading branch information
bixia1 authored and Google-ML-Automation committed Nov 12, 2024
1 parent f173440 commit 210667d
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 9 deletions.
7 changes: 1 addition & 6 deletions xla/service/cpu/cpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -477,11 +477,6 @@ void AddHloVerifier(HloPassPipeline* pipeline, HloVerifierOpts&& opts = {},
absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn(
HloModule* module, bool is_aot_compile,
LLVMTargetMachineFeatures* target_machine_features, bool is_mlir_compile) {
HloPassPipeline pre_sharding_pipeline("pre-spmd-pipeline");
// TODO(b/359982037): Run BatchedGatherScatterNormalizer after partitioning.
pre_sharding_pipeline.AddPass<BatchedGatherScatterNormalizer>();
TF_RETURN_IF_ERROR(pre_sharding_pipeline.Run(module).status());

const int64_t num_partitions = module->config().num_partitions();
if (num_partitions > 1) {
if (!module->config().use_spmd_partitioning()) {
Expand Down Expand Up @@ -527,7 +522,7 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn(
}
HloPassPipeline pipeline("HLO passes through layout assignment");
AddHloVerifier(&pipeline);

pipeline.AddPass<BatchedGatherScatterNormalizer>();
pipeline.AddPass<ResultCaster>();
pipeline.AddPass<OperandUpcaster>();

Expand Down
4 changes: 1 addition & 3 deletions xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -558,9 +558,6 @@ AlgebraicSimplifierOptions LayoutInsensitiveAlgebraicSimplifierOptions(

absl::Status RunPreSPMDPartitionerPasses(HloModule* hlo_module) {
HloPassPipeline pre_spmd_pipeline("pre-spmd-partitioner");
// TODO(b/359982037): Run BatchedGatherScatterNormalizer after partitioning.

pre_spmd_pipeline.AddPass<BatchedGatherScatterNormalizer>();
// Run some IR cleanup passes before running the SPMD partitioning
// passes.
pre_spmd_pipeline.AddPass<CuDnnCustomCallConverter>();
Expand Down Expand Up @@ -665,6 +662,7 @@ absl::Status RunOptimizationPasses(
HloPassPipeline pipeline("optimization");
AddHloVerifier(&pipeline,
!debug_options.xla_experimental_ignore_channel_id());
pipeline.AddPass<BatchedGatherScatterNormalizer>();
if (debug_options.xla_gpu_multi_streamed_windowed_einsum()) {
pipeline.AddPass<WindowedEinsumHandler>();
}
Expand Down

0 comments on commit 210667d

Please sign in to comment.