diff --git a/xla/service/spmd/shardy/BUILD b/xla/service/spmd/shardy/BUILD index 538d947f102e4..ce78ced339ab7 100644 --- a/xla/service/spmd/shardy/BUILD +++ b/xla/service/spmd/shardy/BUILD @@ -100,8 +100,10 @@ cc_library( deps = [ ":constants", "//xla/mlir_hlo", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncExtensions", "@llvm-project//mlir:IR", @@ -152,6 +154,7 @@ xla_cc_binary( "//xla/service/spmd/shardy/sdy_round_trip:import_shardings", "//xla/service/spmd/shardy/sdy_round_trip:pipelines", "//xla/service/spmd/shardy/sdy_round_trip:shard_map_export", + "//xla/service/spmd/shardy/sdy_round_trip:shard_map_import", "//xla/service/spmd/shardy/sdy_round_trip/test_utils:mhlo_to_hlo_to_mhlo", "//xla/service/spmd/shardy/sdy_round_trip/test_utils:testing_pipeline", "@llvm-project//mlir:AllPassesAndDialects", diff --git a/xla/service/spmd/shardy/sdy_opt_main.cc b/xla/service/spmd/shardy/sdy_opt_main.cc index 803f78b2cec7c..892dd9b66a085 100644 --- a/xla/service/spmd/shardy/sdy_opt_main.cc +++ b/xla/service/spmd/shardy/sdy_opt_main.cc @@ -37,6 +37,7 @@ limitations under the License. #include "xla/service/spmd/shardy/sdy_round_trip/import_shardings.h" #include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h" #include "xla/service/spmd/shardy/sdy_round_trip/shard_map_export.h" +#include "xla/service/spmd/shardy/sdy_round_trip/shard_map_import.h" #include "xla/service/spmd/shardy/sdy_round_trip/test_utils/mhlo_to_hlo_to_mhlo.h" #include "xla/service/spmd/shardy/sdy_round_trip/test_utils/testing_pipeline.h" @@ -70,6 +71,7 @@ int main(int argc, char** argv) { xla::sdy::registerSdyRoundTripExportOpsPass(); xla::sdy::registerSdyRoundTripExportPipeline(); xla::sdy::registerSdyRoundTripShardMapExportPass(); + xla::sdy::registerSdyRoundTripShardMapImportPass(); xla::sdy::registerSdyRoundTripImportPipeline(); xla::sdy::registerSdyRoundTripTestingPipeline(); diff --git a/xla/service/spmd/shardy/sdy_round_trip/BUILD b/xla/service/spmd/shardy/sdy_round_trip/BUILD index 0bd38a5de406d..1f9ae760bf877 100644 --- a/xla/service/spmd/shardy/sdy_round_trip/BUILD +++ b/xla/service/spmd/shardy/sdy_round_trip/BUILD @@ -90,6 +90,27 @@ cc_library( ], ) +cc_library( + name = "shard_map_import", + srcs = ["shard_map_import.cc"], + hdrs = ["shard_map_import.h"], + deps = [ + "//xla/mlir_hlo", + "//xla/service/spmd/shardy:constants", + "//xla/service/spmd/shardy:utils", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@shardy//shardy/dialect/sdy/ir:dialect", + "@stablehlo//:stablehlo_ops", + ], +) + cc_library( name = "pipelines", srcs = ["pipelines.cc"], diff --git a/xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc b/xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc index cb8febef2e0fe..eb11cc53f1c45 100644 --- a/xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc +++ b/xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc @@ -75,26 +75,6 @@ using ::mlir::sdy::MeshAttr; using ::mlir::sdy::TensorShardingAttr; using ::mlir::sdy::TensorShardingPerValueAttr; -// Parses `stringAttr` to an attribute of type `AttrTy`. -// -// NOTE: assumes `stringAttr` is of type `StringAttr`. -template -AttrTy parseStringAttr(Attribute stringAttr) { - std::string value; - std::string error; - CHECK(absl::CUnescape(mlir::cast(stringAttr).getValue(), &value, - &error)) - << error; - return mlir::cast( - mlir::parseAttribute(value, stringAttr.getContext())); -} - -// Parses `attrName` from `dictAttr` to an attribute of type `AttrTy`. -template -AttrTy parseStringAttr(DictionaryAttr dictAttr, llvm::StringRef attrName) { - return parseStringAttr(dictAttr.get(attrName)); -} - // Builds the shardings coming from Shardy previously. This means // the module was exported from Shardy and we are now round-tripping back. // This should happen after the meshes were created from the `ModuleOp` attrs diff --git a/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.cc b/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.cc new file mode 100644 index 0000000000000..be663e9c02c69 --- /dev/null +++ b/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.cc @@ -0,0 +1,164 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/spmd/shardy/sdy_round_trip/shard_map_import.h" + +#include +#include + +#include "absl/log/check.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/ErrorHandling.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" +#include "mlir/Transforms/DialectConversion.h" +#include "shardy/dialect/sdy/ir/dialect.h" +#include "shardy/dialect/sdy/ir/utils.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "xla/service/spmd/shardy/constants.h" +#include "xla/service/spmd/shardy/utils.h" + +namespace xla { +namespace sdy { + +namespace { + +using ::mlir::MLIRContext; +using ::mlir::ModuleOp; +using ::mlir::OpConversionPattern; +using ::mlir::StringRef; +using ::mlir::SymbolTable; +using ::mlir::func::FuncOp; +using ::mlir::stablehlo::CustomCallOp; + +namespace sdy = ::mlir::sdy; + +// Converts `CustomCallOp`s called `@xla.sdy.ManualComputation` with in/out +// shardings and manual axes as frontend attrs to `ManualComputationOp`s. +class ManualComputationPattern : public OpConversionPattern { + public: + explicit ManualComputationPattern(MLIRContext* context, + const SymbolTable& symbolTable) + : OpConversionPattern(context), symbolTable(symbolTable) {} + + mlir::LogicalResult matchAndRewrite( + CustomCallOp customCallOp, OpAdaptor adaptor, + mlir::ConversionPatternRewriter& rewriter) const override { + if (customCallOp.getCallTargetName() != + kManualComputationCustomCallTargetName) { + return mlir::failure(); + } + + CHECK_EQ(customCallOp.getCalledComputations().size(), 1); + auto shmapBodyFunc = + symbolTable.lookup((*customCallOp.getCalledComputations() + .getAsRange() + .begin()) + .getValue()); + if (shmapBodyFunc.empty()) { + return customCallOp->emitOpError( + "expected a unique FuncOp per " + "@xla.sdy.ManualComputation custom call. Were " + "functions maybe somehow shared/de-duped between " + "two ManualComputations?"); + } + + mlir::DictionaryAttr frontendAttrs = getFrontendAttrs(customCallOp); + CHECK(frontendAttrs); + auto manualComputationOp = + rewriter.replaceOpWithNewOp( + customCallOp, customCallOp->getResultTypes(), + customCallOp->getOperands(), + parseStringAttr(frontendAttrs, + kInShardings), + parseStringAttr(frontendAttrs, + kOutShardings), + parseStringAttr(frontendAttrs, kManualAxes)); + sdy::inlineRegionAndConvertTerminatorOp( + shmapBodyFunc.getBody(), manualComputationOp.getRegion(), rewriter); + rewriter.eraseOp(shmapBodyFunc); + return mlir::success(); + } + + private: + const SymbolTable& symbolTable; +}; + +class SdyRoundTripShardMapImportPass + : public mlir::PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SdyRoundTripShardMapImportPass) + + private: + void runOnOperation() final { + ModuleOp module = getOperation(); + mlir::SymbolTableCollection symbolTableCollection; + SymbolTable& symbolTable = symbolTableCollection.getSymbolTable(module); + MLIRContext& context = getContext(); + mlir::ConversionTarget target(context); + target.addDynamicallyLegalOp([](CustomCallOp op) { + return op.getCallTargetName() != kManualComputationCustomCallTargetName; + }); + target.addLegalOp(); + mlir::RewritePatternSet patterns(&context); + patterns.add(&context, symbolTable); + if (mlir::failed(mlir::applyPartialConversion(module, target, + std::move(patterns)))) { + signalPassFailure(); + } + } + + StringRef getArgument() const override { + return "xla-sdy-round-trip-shard-map-import"; + } + + StringRef getDescription() const override { + return "converts CustomCalls called @xla.sdy.manual_computation_body " + "with in/out shardings and manual axes as frontend attrs to a " + "`ManualComputationOp`"; + } + void getDependentDialects(mlir::DialectRegistry& registry) const final { + registry.insert(); + } +}; + +} // namespace + +void registerSdyRoundTripShardMapImportPass() { + mlir::registerPass(createSdyRoundTripShardMapImportPass); +} + +std::unique_ptr createSdyRoundTripShardMapImportPass() { + return std::make_unique(); +} + +} // namespace sdy +} // namespace xla diff --git a/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.h b/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.h new file mode 100644 index 0000000000000..1520c8baa663f --- /dev/null +++ b/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.h @@ -0,0 +1,37 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_SHARD_MAP_IMPORT_H_ +#define XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_SHARD_MAP_IMPORT_H_ + +#include + +#include "mlir/Pass/Pass.h" + +namespace xla { +namespace sdy { + +// Creates the pass that converts a `CustomCallOp` called +// `kManualComputationBodyFuncName` with in/out shardings and manual +// axes as frontend attrs to a `ManualComputationOp`. +std::unique_ptr createSdyRoundTripShardMapImportPass(); + +// Registers the xla-sdy-round-trip-shard-map-import pass. +void registerSdyRoundTripShardMapImportPass(); + +} // namespace sdy +} // namespace xla + +#endif // XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_SHARD_MAP_IMPORT_H_ diff --git a/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import.mlir b/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import.mlir new file mode 100644 index 0000000000000..3ed637d7273ff --- /dev/null +++ b/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import.mlir @@ -0,0 +1,139 @@ +// RUN: sdy_opt %s -xla-sdy-round-trip-shard-map-import 2>&1 | FileCheck %s + +sdy.mesh @mesh_0 = <["a"=4, "b"=2]> +sdy.mesh @mesh_1 = <["a"=2, "b"=2, "c"=2, "d"=2]> + +func.func @single_manual_comp(%arg0: tensor<8x16xf32>, %arg1: tensor<16x32xf32>) -> (tensor<8x32xf32>) { + // CHECK-NOT: xla.sdy.ManualComputation + // CHECK: %[[MAN_COMP:.*]] = sdy.manual_computation(%arg0, %arg1) + // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_0, [{"a"}, {"b"}]>, <@mesh_0, [{"b"}, {}], replicated={"a"}>] + // CHECK-SAME{LITERAL}: out_shardings=[<@mesh_0, [{"a"}, {}], replicated={"b"}>] + // CHECK-SAME{LITERAL}: manual_axes={"a", "b"} + // CHECK-SAME: (%arg2: tensor<2x8xf32>, %arg3: tensor<8x32xf32>) { + // CHECK-NEXT: %[[ADD_0:.*]] = stablehlo.add %arg2, %arg2 : tensor<2x8xf32> + // CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %[[ADD_0]], %arg3 : (tensor<2x8xf32>, tensor<8x32xf32>) -> tensor<2x32xf32> + // CHECK-NEXT: %[[REDUCE:.*]] = "stablehlo.all_reduce"(%[[DOT]]) + // CHECK-NEXT: ^bb0(%arg4: tensor, %arg5: tensor): + // CHECK-NEXT: %[[ADD_1:.*]] = stablehlo.add %arg4, %arg5 : tensor + // CHECK-NEXT: stablehlo.return %[[ADD_1]] : tensor + // CHECK-NEXT: }) : (tensor<2x32xf32>) -> tensor<2x32xf32> + // CHECK-NEXT: sdy.return %[[REDUCE]] : tensor<2x32xf32> + // CHECK-NEXT: } : (tensor<8x16xf32>, tensor<16x32xf32>) -> tensor<8x32xf32> + // CHECK-NEXT: return %[[MAN_COMP]] : tensor<8x32xf32> + %0 = stablehlo.custom_call @xla.sdy.ManualComputation(%arg0, %arg1) {called_computations = [@xla.sdy.manual_computation_body], mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {\\\22b\\\22}]>, <@mesh_0, [{\\\22b\\\22}, {}], replicated={\\\22a\\\22}>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}], replicated={\\\22b\\\22}>]>"}} : (tensor<8x16xf32>, tensor<16x32xf32>) -> tensor<8x32xf32> + return %0 : tensor<8x32xf32> +} + +func.func @manual_comp_using_another(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { + // CHECK-NOT: xla.sdy.ManualComputation + // CHECK: %[[MAN_COMP_0:.*]] = sdy.manual_computation(%arg0) + // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_0, [{"a"}, {}]>] + // CHECK-SAME{LITERAL}: out_shardings=[<@mesh_0, [{"a"}, {}]>] + // CHECK-SAME{LITERAL}: manual_axes={"a"} + // CHECK-SAME: (%arg1: tensor<2x8xf32>) { + // CHECK-NEXT: sdy.return %arg1 : tensor<2x8xf32> + // CHECK-NEXT: } : (tensor<8x8xf32>) -> tensor<8x8xf32> + // CHECK-NEXT: %[[MAN_COMP_1:.*]] = sdy.manual_computation(%[[MAN_COMP_0]]) + // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_0, [{}, {"b"}]>] + // CHECK-SAME{LITERAL}: out_shardings=[<@mesh_0, [{}, {"b"}]>] + // CHECK-SAME{LITERAL}: manual_axes={"b"} + // CHECK-SAME: (%arg1: tensor<8x4xf32>) { + // CHECK-NEXT: sdy.return %arg1 : tensor<8x4xf32> + // CHECK-NEXT: } : (tensor<8x8xf32>) -> tensor<8x8xf32> + // CHECK-NEXT: return %[[MAN_COMP_1]] : tensor<8x8xf32> + %0 = stablehlo.custom_call @xla.sdy.ManualComputation(%arg0) {called_computations = [@xla.sdy.manual_computation_body_0], mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}]>]>"}} : (tensor<8x8xf32>) -> tensor<8x8xf32> + %1 = stablehlo.custom_call @xla.sdy.ManualComputation(%0) {called_computations = [@xla.sdy.manual_computation_body_1], mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{}, {\\\22b\\\22}]>]>"}} : (tensor<8x8xf32>) -> tensor<8x8xf32> + return %1 : tensor<8x8xf32> +} + +// CHECK-NOT: func @xla.sdy.manual_computation_body_3( +func.func @xla.sdy.manual_computation_body_3(%arg0: tensor<2x8xf32>) -> tensor<2x8xf32> { + %0 = stablehlo.custom_call @xla.sdy.ManualComputation(%arg0) {called_computations = [@xla.sdy.manual_computation_body_2], mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>"}} : (tensor<2x8xf32>) -> tensor<2x8xf32> + return %0 : tensor<2x8xf32> +} + +// CHECK-NOT: func @xla.sdy.manual_computation_body_2( +func.func @xla.sdy.manual_computation_body_2(%arg0: tensor<2x4xf32>) -> tensor<2x4xf32> { + %0 = stablehlo.multiply %arg0, %arg0 : tensor<2x4xf32> + return %0 : tensor<2x4xf32> +} + +func.func @nested_shmaps(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { + // CHECK-NOT: xla.sdy.ManualComputation + // CHECK: %[[MAN_COMP_0:.*]] = sdy.manual_computation(%arg0) + // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_1, [{"a"}, {}]>] + // CHECK-SAME{LITERAL}: out_shardings=[<@mesh_1, [{"a"}, {}]>] + // CHECK-SAME{LITERAL}: manual_axes={"a"} + // CHECK-SAME: (%arg1: tensor<2x8xf32>) { + // CHECK-NEXT: %[[MAN_COMP_1:.*]] = sdy.manual_computation(%arg1) + // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_1, [{}, {"b"}]>] + // CHECK-SAME{LITERAL}: out_shardings=[<@mesh_1, [{}, {"b"}]>] + // CHECK-SAME{LITERAL}: manual_axes={"b"} + // CHECK-SAME: (%arg2: tensor<2x4xf32>) { + // CHECK-NEXT: %[[MULT:.*]] = stablehlo.multiply %arg2, %arg2 : tensor<2x4xf32> + // CHECK-NEXT: sdy.return %[[MULT]] : tensor<2x4xf32> + // CHECK-NEXT: } : (tensor<2x8xf32>) -> tensor<2x8xf32> + // CHECK-NEXT: sdy.return %[[MAN_COMP_1]] : tensor<2x8xf32> + // CHECK-NEXT: } : (tensor<4x8xf32>) -> tensor<4x8xf32> + // CHECK-NEXT: return %[[MAN_COMP_0]] : tensor<4x8xf32> + %0 = stablehlo.custom_call @xla.sdy.ManualComputation(%arg0) {called_computations = [@xla.sdy.manual_computation_body_3], mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>"}} : (tensor<4x8xf32>) -> tensor<4x8xf32> + return %0 : tensor<4x8xf32> +} + +func.func @nested_shmaps_extra_op(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { + // CHECK-NOT: xla.sdy.ManualComputation + // CHECK: %[[MAN_COMP_0:.*]] = sdy.manual_computation(%arg0) + // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_1, [{"a"}, {}]>] + // CHECK-SAME{LITERAL}: out_shardings=[<@mesh_1, [{"a"}, {}]>] + // CHECK-SAME{LITERAL}: manual_axes={"a"} + // CHECK-SAME: (%arg1: tensor<2x8xf32>) { + // CHECK-NEXT: %[[MAN_COMP_1:.*]] = sdy.manual_computation(%arg1) + // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_1, [{}, {"b"}]>] + // CHECK-SAME{LITERAL}: out_shardings=[<@mesh_1, [{}, {"b"}]>] + // CHECK-SAME{LITERAL}: manual_axes={"b"} + // CHECK-SAME: (%arg2: tensor<2x4xf32>) { + // CHECK-NEXT: %[[MULT:.*]] = stablehlo.multiply %arg2, %arg2 : tensor<2x4xf32> + // CHECK-NEXT: sdy.return %[[MULT]] : tensor<2x4xf32> + // CHECK-NEXT: } : (tensor<2x8xf32>) -> tensor<2x8xf32> + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[MAN_COMP_1]], %[[MAN_COMP_1]] : tensor<2x8xf32> + // CHECK-NEXT: sdy.return %[[ADD]] : tensor<2x8xf32> + // CHECK-NEXT: } : (tensor<4x8xf32>) -> tensor<4x8xf32> + // CHECK-NEXT: return %[[MAN_COMP_0]] : tensor<4x8xf32> + %0 = stablehlo.custom_call @xla.sdy.ManualComputation(%arg0) {called_computations = [@xla.sdy.manual_computation_body_5], mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>"}} : (tensor<4x8xf32>) -> tensor<4x8xf32> + return %0 : tensor<4x8xf32> +} + +// CHECK-NOT: func @xla.sdy.manual_computation_body( +func.func @xla.sdy.manual_computation_body(%arg0: tensor<2x8xf32>, %arg1: tensor<8x32xf32>) -> tensor<2x32xf32> { + %0 = stablehlo.add %arg0, %arg0 : tensor<2x8xf32> + %1 = stablehlo.dot %0, %arg1 : (tensor<2x8xf32>, tensor<8x32xf32>) -> tensor<2x32xf32> + %2 = "stablehlo.all_reduce"(%1) <{replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>}> ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %3 = stablehlo.add %arg2, %arg3 : tensor + stablehlo.return %3 : tensor + }) : (tensor<2x32xf32>) -> tensor<2x32xf32> + return %2 : tensor<2x32xf32> +} + +// CHECK-NOT: func @xla.sdy.manual_computation_body_0( +func.func @xla.sdy.manual_computation_body_0(%arg0: tensor<2x8xf32>) -> tensor<2x8xf32> { + return %arg0 : tensor<2x8xf32> +} + +// CHECK-NOT: func @xla.sdy.manual_computation_body_1( +func.func @xla.sdy.manual_computation_body_1(%arg0: tensor<8x4xf32>) -> tensor<8x4xf32> { + return %arg0 : tensor<8x4xf32> +} + +// CHECK-NOT: func @xla.sdy.manual_computation_body_4( +func.func @xla.sdy.manual_computation_body_4(%arg0: tensor<2x4xf32>) -> tensor<2x4xf32> { + %0 = stablehlo.multiply %arg0, %arg0 : tensor<2x4xf32> + return %0 : tensor<2x4xf32> +} + +// CHECK-NOT: func @xla.sdy.manual_computation_body_5( +func.func @xla.sdy.manual_computation_body_5(%arg0: tensor<2x8xf32>) -> tensor<2x8xf32> { + %0 = stablehlo.custom_call @xla.sdy.ManualComputation(%arg0) {called_computations = [@xla.sdy.manual_computation_body_4], mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>"}} : (tensor<2x8xf32>) -> tensor<2x8xf32> + %1 = stablehlo.add %0, %0 : tensor<2x8xf32> + return %1 : tensor<2x8xf32> +} diff --git a/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import_failure.mlir b/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import_failure.mlir new file mode 100644 index 0000000000000..0d75aa379aafa --- /dev/null +++ b/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import_failure.mlir @@ -0,0 +1,15 @@ +// RUN: sdy_opt %s -xla-sdy-round-trip-shard-map-import -split-input-file -verify-diagnostics + +sdy.mesh @mesh = <["a"=2]> + +func.func @using_same_body_func(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { + %0 = stablehlo.custom_call @xla.sdy.ManualComputation(%arg0) {called_computations = [@xla.sdy.manual_computation_body_0], mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}, {}]>]>"}} : (tensor<8x8xf32>) -> tensor<8x8xf32> + // expected-error @+2 {{'stablehlo.custom_call' op expected a unique FuncOp per @xla.sdy.ManualComputation custom call}} + // expected-error @+1 {{failed to legalize operation 'stablehlo.custom_call'}} + %1 = stablehlo.custom_call @xla.sdy.ManualComputation(%0) {called_computations = [@xla.sdy.manual_computation_body_0], mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}, {}]>]>"}} : (tensor<8x8xf32>) -> tensor<8x8xf32> + return %1 : tensor<8x8xf32> +} + +func.func @xla.sdy.manual_computation_body_0(%arg0: tensor<2x8xf32>) -> tensor<2x8xf32> { + return %arg0 : tensor<2x8xf32> +} diff --git a/xla/service/spmd/shardy/utils.h b/xla/service/spmd/shardy/utils.h index f707ea16ee961..80194b3ca04c4 100644 --- a/xla/service/spmd/shardy/utils.h +++ b/xla/service/spmd/shardy/utils.h @@ -18,6 +18,9 @@ limitations under the License. #include +#include "absl/log/check.h" +#include "absl/strings/escaping.h" +#include "mlir/AsmParser/AsmParser.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -58,6 +61,27 @@ void removeFrontendAttribute(mlir::func::FuncOp funcOp, void loadAllRequiredDialects(mlir::MLIRContext* context); +// Parses `stringAttr` to an attribute of type `AttrTy`. +// +// NOTE: assumes `stringAttr` is of type `StringAttr`. +template +AttrTy parseStringAttr(mlir::Attribute stringAttr) { + std::string value; + std::string error; + CHECK(absl::CUnescape(mlir::cast(stringAttr).getValue(), + &value, &error)) + << error; + return mlir::cast( + mlir::parseAttribute(value, stringAttr.getContext())); +} + +// Parses `attrName` from `dictAttr` to an attribute of type `AttrTy`. +template +AttrTy parseStringAttr(mlir::DictionaryAttr dictAttr, + llvm::StringRef attrName) { + return parseStringAttr(dictAttr.get(attrName)); +} + } // namespace sdy } // namespace xla