Add sharding to added tuple for single func result
PiperOrigin-RevId: 621830460
tensorflower-gardener authored and copybara-github committed Apr 4, 2024
1 parent 22257bf commit dd5044d
Showing 2 changed files with 50 additions and 22 deletions.
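This commit makes the MHLO-to-HLO exporter attach a sharding to the tuple it synthesizes around function results, including the one-element tuple created for a single result, which was previously emitted without one. The new CreateTupleSharding helper (first file below) builds an xla::OpSharding of type TUPLE, filling in REPLICATED for any result that has no annotation. A minimal standalone sketch of the proto this produces for one tiled and one unannotated result; the proto setters are the generated API from xla_data.proto, but main() and its values are illustrative only, not part of this commit:

#include <iostream>

#include "xla/xla_data.pb.h"

int main() {
  // Element 0: an explicit tiled sharding, e.g. {devices=[2,1]0,1}.
  xla::OpSharding tiled;
  tiled.set_type(xla::OpSharding::OTHER);
  tiled.add_tile_assignment_dimensions(2);
  tiled.add_tile_assignment_dimensions(1);
  tiled.add_tile_assignment_devices(0);
  tiled.add_tile_assignment_devices(1);

  // Element 1: a result without an annotation gets a REPLICATED fallback.
  xla::OpSharding fallback;
  fallback.set_type(xla::OpSharding::REPLICATED);

  // The tuple sharding wraps the per-result shardings in order.
  xla::OpSharding tuple_sharding;
  tuple_sharding.set_type(xla::OpSharding::TUPLE);
  *tuple_sharding.add_tuple_shardings() = tiled;
  *tuple_sharding.add_tuple_shardings() = fallback;

  std::cout << tuple_sharding.DebugString();  // type: TUPLE, two elements
  return 0;
}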
52 changes: 32 additions & 20 deletions xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
@@ -70,6 +70,7 @@ limitations under the License.
 #include "xla/hlo/ir/dynamic_parameter_binding.h"
 #include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_sharding.h"
 #include "xla/literal.h"
 #include "xla/literal_util.h"
 #include "xla/mlir/utils/error_util.h"
@@ -639,6 +640,30 @@ static void ExtractShardingsFromFunction(
     (*ret_shardings)[i] = xla::ConvertSharding(sharding.getValue());
 }
 
+// Creates a tuple sharding with the given shardings if at least one is present.
+//
+// Adds replicated shardings for any missing tuple shardings.
+std::optional<xla::OpSharding> CreateTupleSharding(
+    llvm::ArrayRef<std::optional<xla::OpSharding>> tuple_shardings) {
+  if (tuple_shardings.empty() ||
+      !SomeOptionalShardingsAreSet(tuple_shardings)) {
+    return std::nullopt;
+  }
+  xla::OpSharding sharding;
+  sharding.set_type(xla::OpSharding::TUPLE);
+  for (const std::optional<xla::OpSharding>& tuple_sharding : tuple_shardings) {
+    if (tuple_sharding) {
+      *sharding.add_tuple_shardings() = *tuple_sharding;
+    } else {
+      xla::OpSharding fallback_sharding;
+      fallback_sharding.set_type(xla::OpSharding::REPLICATED);
+      *sharding.add_tuple_shardings() = fallback_sharding;
+    }
+  }
+
+  return sharding;
+}
+
 namespace mlir {
 namespace {
 class ConvertToHloModule {
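A hedged usage sketch for the helper added above; CreateTupleSharding is the function from this hunk, while Demo() and its inputs are illustrative scaffolding:

#include <optional>
#include <vector>

#include "llvm/ADT/ArrayRef.h"
#include "xla/xla_data.pb.h"

std::optional<xla::OpSharding> Demo() {
  xla::OpSharding replicated;
  replicated.set_type(xla::OpSharding::REPLICATED);
  // One annotated result, one result without a sharding.
  std::vector<std::optional<xla::OpSharding>> ret_shardings = {replicated,
                                                               std::nullopt};
  // Yields a TUPLE sharding with two tuple_shardings entries: the explicit
  // REPLICATED one and a REPLICATED fallback for the std::nullopt entry.
  // Returns std::nullopt instead when no entry is set at all.
  return CreateTupleSharding(ret_shardings);  // declared in the hunk above
}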
@@ -3127,8 +3152,8 @@ LogicalResult ConvertToHloModule::Lower(
     // Construct the return value for the function. If there is a single value
     // returned, then return it directly, else create a tuple and return.
     unsigned num_return_values = inst->getNumOperands();
-    const bool has_ret_shardings =
-        !ret_shardings.empty() && SomeOptionalShardingsAreSet(ret_shardings);
+    std::optional<xla::OpSharding> ret_tuple_sharding =
+        CreateTupleSharding(ret_shardings);
     if ((return_tuple_ && is_entry_function) || num_return_values != 1) {
       std::vector<xla::XlaOp> returns(num_return_values);
       for (OpOperand& ret : inst->getOpOperands()) {
@@ -3138,7 +3163,7 @@ LogicalResult ConvertToHloModule::Lower(
           return failure();
 
         returns[index] = operand;
-        if (!is_entry_function || !has_ret_shardings) continue;
+        if (!is_entry_function || !ret_tuple_sharding) continue;
 
         xla::Shape return_shape = xla::TypeToShape(ret.get().getType());
         absl::StatusOr<xla::XlaOp> reshape =
@@ -3152,29 +3177,16 @@ LogicalResult ConvertToHloModule::Lower(
         returns[index] = reshape.value();
       }
 
-      if (has_ret_shardings) {
-        xla::OpSharding sharding;
-        sharding.set_type(xla::OpSharding::TUPLE);
-        for (auto& ret_sharding : ret_shardings)
-          if (ret_sharding) {
-            *sharding.add_tuple_shardings() = *ret_sharding;
-          } else {
-            xla::OpSharding fallback_sharding;
-            fallback_sharding.set_type(xla::OpSharding::REPLICATED);
-            *sharding.add_tuple_shardings() = fallback_sharding;
-          }
-
-        builder->SetSharding(sharding);
-      }
-
+      xla::XlaScopedShardingAssignment scoped_sharding(builder,
+                                                       ret_tuple_sharding);
       *return_value = xla::Tuple(builder, returns);
-      builder->ClearSharding();
     } else if (num_return_values == 1) {
       xla::XlaOp operand;
       if (failed(GetXlaOp(inst->getOperand(0), value_map, &operand, inst)))
         return failure();
 
-      if (has_ret_shardings) {
+      if (ret_tuple_sharding) {
+        builder->SetSharding(*ret_tuple_sharding);
         auto tuple = Tuple(builder, {operand});
         builder->SetSharding(*ret_shardings[0]);
         *return_value = GetTupleElement(tuple, 0);
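Note how the multi-result path in the last hunk swaps the manual SetSharding/ClearSharding pair for xla::XlaScopedShardingAssignment, an RAII guard that saves the builder's current sharding, applies the given one (or clears it when std::nullopt), and restores the saved sharding on destruction. A sketch of the pattern, assuming the builder header location of this era (xla/client/xla_builder.h); BuildReturnTuple is an illustrative wrapper, not the commit's code:

#include <optional>
#include <vector>

#include "xla/client/xla_builder.h"  // XlaBuilder, XlaScopedShardingAssignment

xla::XlaOp BuildReturnTuple(
    xla::XlaBuilder* builder, const std::vector<xla::XlaOp>& returns,
    const std::optional<xla::OpSharding>& ret_tuple_sharding) {
  // Saves the current sharding, applies ret_tuple_sharding for this scope
  // (clearing it when std::nullopt), and restores the saved sharding at the
  // closing brace.
  xla::XlaScopedShardingAssignment scoped_sharding(builder,
                                                   ret_tuple_sharding);
  return xla::Tuple(builder, returns);  // carries ret_tuple_sharding, if any
}

The RAII form keeps the sharding correctly scoped even on early returns, which the removed explicit ClearSharding() call did not guarantee.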
20 changes: 18 additions & 2 deletions xla/translate/mhlo_to_hlo/tests/sharding.mlir
@@ -1,18 +1,34 @@
@@ -1,18 +1,34 @@
 // RUN: xla-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s
 
 // CHECK-LABEL: ENTRY %main.{{.*}} (Arg_0.1: f32[], Arg_1.2: f32[4]) -> f32[4,4]
-func.func public @main(%arg0: tensor<f32> {mhlo.sharding = ""}, %arg1: tensor<4xf32> {mhlo.sharding = "\08\03\1A\01\02\22\02\00\01"}) -> (tensor<4x4xf32> {mhlo.sharding = "\08\03\1A\02\02\01\22\02\00\01"}) {
+func.func public @main(%arg0: tensor<f32> {mhlo.sharding = ""}, %arg1: tensor<4xf32> {mhlo.sharding = "{devices=[1,2,1]0,1} "}) -> (tensor<4x4xf32> {mhlo.sharding = "{devices=[2,1]0,1}"}) {
 // CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1), sharding={devices=[2]0,1}
 // CHECK-NEXT: %Arg_0.1 = f32[] parameter(0), sharding={replicated}
 %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<4xf32>
 %1 = mhlo.multiply %arg1, %0 : tensor<4xf32>
 %2 = "mhlo.broadcast_in_dim"(%1) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4x4xf32>
-// CHECK: ROOT {{.*}}, sharding={devices=[2,1]0,1}
+// CHECK: %tuple.6 = (f32[4,4]) tuple(f32[4,4] %broadcast.5), sharding={{\{}}{devices=[2,1]0,1}}
+// CHECK-NEXT: ROOT %get-tuple-element.7 = f32[4,4] get-tuple-element((f32[4,4]) %tuple.6), index=0, sharding={devices=[2,1]0,1}
 func.return %2 : tensor<4x4xf32>
 }
 
 // -----
+
+// CHECK-LABEL: ENTRY %main.{{.*}} (Arg_0.1: f32[], Arg_1.2: f32[4]) -> (f32[4,4]) {
+func.func public @main(%arg0: tensor<f32> {mhlo.sharding = ""}, %arg1: tensor<4xf32> {mhlo.sharding = "{devices=[1,2,1]0,1} "}) -> (tuple<tensor<4x4xf32>> {mhlo.sharding = "{{devices=[2,1]0,1}}"}) {
+// CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1), sharding={devices=[2]0,1}
+// CHECK-NEXT: %Arg_0.1 = f32[] parameter(0), sharding={replicated}
+%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<4xf32>
+%1 = mhlo.multiply %arg1, %0 : tensor<4xf32>
+%2 = "mhlo.broadcast_in_dim"(%1) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4x4xf32>
+%3 = mhlo.tuple %2 : tuple<tensor<4x4xf32>>
+// CHECK: %tuple.7 = (f32[4,4]) tuple(f32[4,4] %broadcast.5), sharding={{\{}}{{\{}}{devices=[2,1]0,1}}}
+// CHECK-NEXT: ROOT %get-tuple-element.8 = f32[4,4] get-tuple-element((f32[4,4]) %tuple.7), index=0, sharding={{\{}}{devices=[2,1]0,1}}
+func.return %3 : tuple<tensor<4x4xf32>>
+}
+
+// -----
 
 // CHECK-LABEL: ENTRY %main.{{.*}} ({{[^,]*}}: f32[5,8,128]) -> f32[5,8,128]
 func.func @main(%arg0: tensor<5x8x128xf32> {mhlo.sharding = "\08\03\1A\03\01\02\01\22\02\00\01"}) -> (tensor<5x8x128xf32> {mhlo.sharding = "\08\03\1A\03\01\02\01\22\02\00\01"}) {
 // CHECK-NEXT: %Arg_0.1 = f32[5,8,128] parameter(0), sharding={devices=[1,2,1]0,1}
