From dd5044d684714a40b848185e4698eab05bc80a18 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 4 Apr 2024 05:45:14 -0700
Subject: [PATCH] Add sharding to added tuple for single func result

PiperOrigin-RevId: 621830460
---
 xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc  | 52 ++++++++++++-------
 xla/translate/mhlo_to_hlo/tests/sharding.mlir | 20 ++++++-
 2 files changed, 50 insertions(+), 22 deletions(-)

diff --git a/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
index 8911c7864ac83..3acf78148b834 100644
--- a/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
+++ b/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
@@ -70,6 +70,7 @@ limitations under the License.
 #include "xla/hlo/ir/dynamic_parameter_binding.h"
 #include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_sharding.h"
 #include "xla/literal.h"
 #include "xla/literal_util.h"
 #include "xla/mlir/utils/error_util.h"
@@ -639,6 +640,30 @@ static void ExtractShardingsFromFunction(
     (*ret_shardings)[i] = xla::ConvertSharding(sharding.getValue());
 }
 
+// Creates a tuple sharding with the given shardings if at least one is present.
+//
+// Adds replicated shardings for any missing tuple shardings.
+std::optional<xla::OpSharding> CreateTupleSharding(
+    llvm::ArrayRef<std::optional<xla::OpSharding>> tuple_shardings) {
+  if (tuple_shardings.empty() ||
+      !SomeOptionalShardingsAreSet(tuple_shardings)) {
+    return std::nullopt;
+  }
+  xla::OpSharding sharding;
+  sharding.set_type(xla::OpSharding::TUPLE);
+  for (const std::optional<xla::OpSharding>& tuple_sharding : tuple_shardings) {
+    if (tuple_sharding) {
+      *sharding.add_tuple_shardings() = *tuple_sharding;
+    } else {
+      xla::OpSharding fallback_sharding;
+      fallback_sharding.set_type(xla::OpSharding::REPLICATED);
+      *sharding.add_tuple_shardings() = fallback_sharding;
+    }
+  }
+
+  return sharding;
+}
+
 namespace mlir {
 namespace {
 class ConvertToHloModule {
@@ -3127,8 +3152,8 @@ LogicalResult ConvertToHloModule::Lower(
     // Construct the return value for the function. If there is a single value
     // returned, then return it directly, else create a tuple and return.
     unsigned num_return_values = inst->getNumOperands();
-    const bool has_ret_shardings =
-        !ret_shardings.empty() && SomeOptionalShardingsAreSet(ret_shardings);
+    std::optional<xla::OpSharding> ret_tuple_sharding =
+        CreateTupleSharding(ret_shardings);
     if ((return_tuple_ && is_entry_function) || num_return_values != 1) {
       std::vector<xla::XlaOp> returns(num_return_values);
       for (OpOperand& ret : inst->getOpOperands()) {
@@ -3138,7 +3163,7 @@ LogicalResult ConvertToHloModule::Lower(
           return failure();
 
         returns[index] = operand;
-        if (!is_entry_function || !has_ret_shardings) continue;
+        if (!is_entry_function || !ret_tuple_sharding) continue;
 
         xla::Shape return_shape = xla::TypeToShape(ret.get().getType());
         absl::StatusOr<xla::XlaOp> reshape =
@@ -3152,29 +3177,16 @@
         returns[index] = reshape.value();
       }
 
-      if (has_ret_shardings) {
-        xla::OpSharding sharding;
-        sharding.set_type(xla::OpSharding::TUPLE);
-        for (auto& ret_sharding : ret_shardings)
-          if (ret_sharding) {
-            *sharding.add_tuple_shardings() = *ret_sharding;
-          } else {
-            xla::OpSharding fallback_sharding;
-            fallback_sharding.set_type(xla::OpSharding::REPLICATED);
-            *sharding.add_tuple_shardings() = fallback_sharding;
-          }
-
-        builder->SetSharding(sharding);
-      }
-
+      xla::XlaScopedShardingAssignment scoped_sharding(builder,
+                                                       ret_tuple_sharding);
       *return_value = xla::Tuple(builder, returns);
-      builder->ClearSharding();
     } else if (num_return_values == 1) {
       xla::XlaOp operand;
       if (failed(GetXlaOp(inst->getOperand(0), value_map, &operand, inst)))
         return failure();
 
-      if (has_ret_shardings) {
+      if (ret_tuple_sharding) {
+        builder->SetSharding(*ret_tuple_sharding);
         auto tuple = Tuple(builder, {operand});
         builder->SetSharding(*ret_shardings[0]);
         *return_value = GetTupleElement(tuple, 0);
diff --git a/xla/translate/mhlo_to_hlo/tests/sharding.mlir b/xla/translate/mhlo_to_hlo/tests/sharding.mlir
index a4a2650d61b69..76512c47537af 100644
--- a/xla/translate/mhlo_to_hlo/tests/sharding.mlir
+++ b/xla/translate/mhlo_to_hlo/tests/sharding.mlir
@@ -1,18 +1,34 @@
 // RUN: xla-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s
 
 // CHECK-LABEL: ENTRY %main.{{.*}} (Arg_0.1: f32[], Arg_1.2: f32[4]) -> f32[4,4]
-func.func public @main(%arg0: tensor<f32> {mhlo.sharding = ""}, %arg1: tensor<4xf32> {mhlo.sharding = "\08\03\1A\01\02\22\02\00\01"}) -> (tensor<4x4xf32> {mhlo.sharding = "\08\03\1A\02\02\01\22\02\00\01"}) {
+func.func public @main(%arg0: tensor<f32> {mhlo.sharding = ""}, %arg1: tensor<4xf32> {mhlo.sharding = "{devices=[1,2,1]0,1} "}) -> (tensor<4x4xf32> {mhlo.sharding = "{devices=[2,1]0,1}"}) {
   // CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1), sharding={devices=[2]0,1}
   // CHECK-NEXT: %Arg_0.1 = f32[] parameter(0), sharding={replicated}
   %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<4xf32>
   %1 = mhlo.multiply %arg1, %0 : tensor<4xf32>
   %2 = "mhlo.broadcast_in_dim"(%1) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4x4xf32>
-  // CHECK: ROOT {{.*}}, sharding={devices=[2,1]0,1}
+  // CHECK: %tuple.6 = (f32[4,4]) tuple(f32[4,4] %broadcast.5), sharding={{\{}}{devices=[2,1]0,1}}
+  // CHECK-NEXT: ROOT %get-tuple-element.7 = f32[4,4] get-tuple-element((f32[4,4]) %tuple.6), index=0, sharding={devices=[2,1]0,1}
   func.return %2 : tensor<4x4xf32>
 }
 
 // -----
 
+// CHECK-LABEL: ENTRY %main.{{.*}} (Arg_0.1: f32[], Arg_1.2: f32[4]) -> (f32[4,4]) {
+func.func public @main(%arg0: tensor<f32> {mhlo.sharding = ""}, %arg1: tensor<4xf32> {mhlo.sharding = "{devices=[1,2,1]0,1} "}) -> (tuple<tensor<4x4xf32>> {mhlo.sharding = "{{devices=[2,1]0,1}}"}) {
+  // CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1), sharding={devices=[2]0,1}
+  // CHECK-NEXT: %Arg_0.1 = f32[] parameter(0), sharding={replicated}
+  %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<4xf32>
+  %1 = mhlo.multiply %arg1, %0 : tensor<4xf32>
+  %2 = "mhlo.broadcast_in_dim"(%1) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4x4xf32>
+  %3 = mhlo.tuple %2 : tuple<tensor<4x4xf32>>
+  // CHECK: %tuple.7 = (f32[4,4]) tuple(f32[4,4] %broadcast.5), sharding={{\{}}{{\{}}{devices=[2,1]0,1}}}
+  // CHECK-NEXT: ROOT %get-tuple-element.8 = f32[4,4] get-tuple-element((f32[4,4]) %tuple.9), index=0, sharding={{\{}}{devices=[2,1]0,1}}
+  func.return %3 : tuple<tensor<4x4xf32>>
+}
+
+// -----
+
 // CHECK-LABEL: ENTRY %main.{{.*}} ({{[^,]*}}: f32[5,8,128]) -> f32[5,8,128]
 func.func @main(%arg0: tensor<5x8x128xf32> {mhlo.sharding = "\08\03\1A\03\01\02\01\22\02\00\01"}) -> (tensor<5x8x128xf32> {mhlo.sharding = "\08\03\1A\03\01\02\01\22\02\00\01"}) {
   // CHECK-NEXT: %Arg_0.1 = f32[5,8,128] parameter(0), sharding={devices=[1,2,1]0,1}