Skip to content

Commit

Permalink
#sdy Move shard map import pass to mhlo_round_trip.
Browse files Browse the repository at this point in the history
In a follow-up I will add a specific shard map import pass for sdy round tripping.

PiperOrigin-RevId: 674266920
  • Loading branch information
bartchr808 authored and Google-ML-Automation committed Sep 13, 2024
1 parent 67325a0 commit 8e91a0a
Show file tree
Hide file tree
Showing 12 changed files with 52 additions and 46 deletions.
2 changes: 1 addition & 1 deletion xla/service/spmd/shardy/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,10 @@ xla_cc_binary(
"//xla/service/spmd/shardy/mhlo_round_trip:mhlo_export",
"//xla/service/spmd/shardy/mhlo_round_trip:mhlo_import",
"//xla/service/spmd/shardy/mhlo_round_trip:shard_map_export",
"//xla/service/spmd/shardy/mhlo_round_trip:shard_map_import",
"//xla/service/spmd/shardy/round_trip_common:convert_sharding_custom_calls",
"//xla/service/spmd/shardy/round_trip_common:import_constants",
"//xla/service/spmd/shardy/round_trip_common:open_while_free_vars_sharding",
"//xla/service/spmd/shardy/round_trip_common:shard_map_import",
"//xla/service/spmd/shardy/sdy_round_trip:export_ops",
"//xla/service/spmd/shardy/sdy_round_trip:export_shardings",
"//xla/service/spmd/shardy/sdy_round_trip:import_shardings",
Expand Down
26 changes: 26 additions & 0 deletions xla/service/spmd/shardy/mhlo_round_trip/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ cc_library(
srcs = ["mhlo_import.cc"],
hdrs = ["mhlo_import.h"],
deps = [
":shard_map_import",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
Expand All @@ -119,3 +120,28 @@ cc_library(
"@shardy//shardy/dialect/sdy/ir:dialect",
],
)

cc_library(
name = "shard_map_import",
srcs = ["shard_map_import.cc"],
hdrs = ["shard_map_import.h"],
deps = [
"//xla:xla_data_proto_cc",
"//xla/mlir_hlo",
"//xla/mlir_hlo:mhlo_passes",
"//xla/service/spmd/shardy:constants",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:CallOpInterfaces",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@shardy//shardy/dialect/sdy/ir:dialect",
],
)
2 changes: 2 additions & 0 deletions xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ limitations under the License.
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/mlir_hlo/mhlo/transforms/passes.h"
#include "xla/service/spmd/shardy/constants.h"
#include "xla/service/spmd/shardy/mhlo_round_trip/shard_map_import.h"
#include "xla/service/spmd/shardy/round_trip_common/pipeline_passes.h"
#include "xla/translate/mhlo_to_hlo/attribute_exporter.h"
#include "xla/util.h"
Expand Down Expand Up @@ -647,6 +648,7 @@ void addMhloImportPipeline(mlir::OpPassManager& pm,
addCommonPreImportPasses(pm);
pm.addPass(createImportShardingsPass(allowPropagationToArgs,
allowPropagationToResults));
pm.addPass(createMhloRoundTripShardMapImportPass());
addCommonPostImportPasses(pm);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/service/spmd/shardy/round_trip_common/shard_map_import.h"
#include "xla/service/spmd/shardy/mhlo_round_trip/shard_map_import.h"

#include <cstdint>
#include <memory>
Expand Down Expand Up @@ -503,11 +503,13 @@ class ShardMapImportPass
}
}

StringRef getArgument() const override { return "xla-sdy-shard-map-import"; }
StringRef getArgument() const override {
return "xla-mhlo-round-trip-shard-map-import";
}

StringRef getDescription() const override {
return "Replaces a CallOp pattern unique to JAX shard_map with a "
"ManualComputationOp.";
return "Replaces a CallOp pattern unique to JAX shard_map through GSPMD "
"lowering with a ManualComputationOp.";
}

void getDependentDialects(mlir::DialectRegistry& registry) const final {
Expand All @@ -517,12 +519,12 @@ class ShardMapImportPass

} // namespace

std::unique_ptr<mlir::Pass> createShardMapImportPass() {
std::unique_ptr<mlir::Pass> createMhloRoundTripShardMapImportPass() {
return std::make_unique<ShardMapImportPass>();
}

void registerShardMapImportPass() {
mlir::registerPass(createShardMapImportPass);
void registerMhloRoundTripShardMapImportPass() {
mlir::registerPass(createMhloRoundTripShardMapImportPass);
}

} // namespace sdy
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_SHARD_MAP_IMPORT_H_
#define XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_SHARD_MAP_IMPORT_H_
#ifndef XLA_SERVICE_SPMD_SHARDY_MHLO_ROUND_TRIP_SHARD_MAP_IMPORT_H_
#define XLA_SERVICE_SPMD_SHARDY_MHLO_ROUND_TRIP_SHARD_MAP_IMPORT_H_

#include <memory>

Expand All @@ -26,12 +26,12 @@ namespace sdy {

// Creates a pass that converts the code pattern generated by JAX's shard_map
// into the `ManualComputationOp`.
std::unique_ptr<mlir::Pass> createShardMapImportPass();
std::unique_ptr<mlir::Pass> createMhloRoundTripShardMapImportPass();

// Registers the xla-sdy-shard-map-import pass.
void registerShardMapImportPass();
// Registers the xla-mhlo-round-trip-shard-map-import pass.
void registerMhloRoundTripShardMapImportPass();

} // namespace sdy
} // namespace xla

#endif // XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_SHARD_MAP_IMPORT_H_
#endif // XLA_SERVICE_SPMD_SHARDY_MHLO_ROUND_TRIP_SHARD_MAP_IMPORT_H_
26 changes: 0 additions & 26 deletions xla/service/spmd/shardy/round_trip_common/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -63,31 +63,6 @@ cc_library(
],
)

cc_library(
name = "shard_map_import",
srcs = ["shard_map_import.cc"],
hdrs = ["shard_map_import.h"],
deps = [
"//xla:xla_data_proto_cc",
"//xla/mlir_hlo",
"//xla/mlir_hlo:mhlo_passes",
"//xla/service/spmd/shardy:constants",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:CallOpInterfaces",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@shardy//shardy/dialect/sdy/ir:dialect",
],
)

cc_library(
name = "pipeline_passes",
srcs = ["pipeline_passes.cc"],
Expand All @@ -96,7 +71,6 @@ cc_library(
":convert_sharding_custom_calls",
":import_constants",
":open_while_free_vars_sharding",
":shard_map_import",
"//xla/mlir_hlo:mhlo_passes",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:Pass",
Expand Down
2 changes: 0 additions & 2 deletions xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ limitations under the License.
#include "xla/service/spmd/shardy/round_trip_common/convert_sharding_custom_calls.h"
#include "xla/service/spmd/shardy/round_trip_common/import_constants.h"
#include "xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.h"
#include "xla/service/spmd/shardy/round_trip_common/shard_map_import.h"

namespace xla {
namespace sdy {
Expand Down Expand Up @@ -52,7 +51,6 @@ void addCommonPreImportPasses(mlir::OpPassManager& pm) {
}

void addCommonPostImportPasses(mlir::OpPassManager& pm) {
pm.addPass(createShardMapImportPass());
pm.addPass(createConvertShardingCustomCallsPass());
pm.addNestedPass<FuncOp>(createOpenWhileFreeVarsShardingPass());
}
Expand Down
4 changes: 2 additions & 2 deletions xla/service/spmd/shardy/sdy_opt_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ limitations under the License.
#include "xla/service/spmd/shardy/mhlo_round_trip/mhlo_export.h"
#include "xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.h"
#include "xla/service/spmd/shardy/mhlo_round_trip/shard_map_export.h"
#include "xla/service/spmd/shardy/mhlo_round_trip/shard_map_import.h"
#include "xla/service/spmd/shardy/round_trip_common/convert_sharding_custom_calls.h"
#include "xla/service/spmd/shardy/round_trip_common/import_constants.h"
#include "xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.h"
#include "xla/service/spmd/shardy/round_trip_common/shard_map_import.h"
#include "xla/service/spmd/shardy/sdy_round_trip/export_ops.h"
#include "xla/service/spmd/shardy/sdy_round_trip/export_shardings.h"
#include "xla/service/spmd/shardy/sdy_round_trip/import_shardings.h"
Expand All @@ -53,7 +53,7 @@ int main(int argc, char** argv) {

xla::sdy::registerMhloImportPipeline();
xla::sdy::registerMhloImportShardingsPass();
xla::sdy::registerShardMapImportPass();
xla::sdy::registerMhloRoundTripShardMapImportPass();
xla::sdy::registerConvertShardingCustomCallsPass();
xla::sdy::registerOpenWhileFreeVarsShardingPass();
xla::sdy::registerImportConstantsPass();
Expand Down
1 change: 1 addition & 0 deletions xla/service/spmd/shardy/sdy_round_trip/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ cc_library(
":import_shardings",
"//xla/service:hlo_proto_cc",
"//xla/service/spmd/shardy/mhlo_round_trip:export_shardings",
"//xla/service/spmd/shardy/mhlo_round_trip:shard_map_import",
"//xla/service/spmd/shardy/round_trip_common:pipeline_passes",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
Expand Down
3 changes: 3 additions & 0 deletions xla/service/spmd/shardy/sdy_round_trip/pipelines.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License.
#include "mlir/Support/LLVM.h"
#include "xla/service/hlo.pb.h"
#include "xla/service/spmd/shardy/mhlo_round_trip/export_shardings.h"
#include "xla/service/spmd/shardy/mhlo_round_trip/shard_map_import.h"
#include "xla/service/spmd/shardy/round_trip_common/pipeline_passes.h"
#include "xla/service/spmd/shardy/sdy_round_trip/export_ops.h"
#include "xla/service/spmd/shardy/sdy_round_trip/export_shardings.h"
Expand Down Expand Up @@ -49,6 +50,8 @@ void addSdyRoundTripExportPipeline(mlir::OpPassManager& pm) {
void addSdyRoundTripImportPipeline(mlir::OpPassManager& pm) {
addCommonPreImportPasses(pm);
pm.addPass(createSdyRoundTripImportShardingsPass());
// TODO(bartchr): replace with an sdy round trip shard map pass.
pm.addPass(createMhloRoundTripShardMapImportPass());
addCommonPostImportPasses(pm);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: sdy_opt %s -xla-sdy-shard-map-import 2>&1 | FileCheck %s
// RUN: sdy_opt %s -xla-mhlo-round-trip-shard-map-import 2>&1 | FileCheck %s

sdy.mesh @mesh_0 = <["a"=4]>
sdy.mesh @mesh_1 = <["a"=4, "b"=2]>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: sdy_opt %s -xla-sdy-shard-map-import -split-input-file -verify-diagnostics
// RUN: sdy_opt %s -xla-mhlo-round-trip-shard-map-import -split-input-file -verify-diagnostics

sdy.mesh @mesh_1 = <["a"=4, "b"=2]>
sdy.mesh @mesh_2 = <["a"=4, "b"=2, "c"=3]>
Expand Down

0 comments on commit 8e91a0a

Please sign in to comment.