Skip to content

Commit

Permalink
Add support for -hlo-flatten-computation-args-result for importing …
Browse files Browse the repository at this point in the history
…`entry_computation_layout`.

PiperOrigin-RevId: 671527962
  • Loading branch information
ghpvnist authored and Google-ML-Automation committed Sep 18, 2024
1 parent bfaba22 commit a681387
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 20 deletions.
77 changes: 57 additions & 20 deletions xla/translate/hlo_to_mhlo/module_attributes_importer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,34 @@ mlir::ArrayAttr ConvertCrossProgramPrefetches(

void ImportEntryComputationParameterLayoutAndTiles(
const xla::HloModule& hlo_module, mlir::ModuleOp module,
const ComputationLayout& computation_layout, mlir::Builder builder) {
const ComputationLayout& computation_layout,
bool flatten_computation_args_result, mlir::Builder builder) {
llvm::SmallVector<mlir::Attribute> parameter_layouts;
llvm::SmallVector<mlir::Attribute> parameter_tiles;
for (auto& layout : computation_layout.parameter_layouts()) {
if (layout.shape().IsTuple()) {
if (flatten_computation_args_result) {
for (auto& parameter_layout : computation_layout.parameter_layouts()) {
xla::ShapeUtil::ForEachLeafShape(
parameter_layout.shape(),
[&](const xla::Shape& subshape, const xla::ShapeIndex& index) {
std::pair<mlir::Attribute, mlir::ArrayAttr> layout_attrs =
GetLayoutAttribute(builder, subshape);
parameter_layouts.push_back(layout_attrs.first);
parameter_tiles.push_back(layout_attrs.second);
});
}
module->setAttr(kEntryComputationParameterLayouts,
builder.getArrayAttr({parameter_layouts}));
module->setAttr(kEntryComputationParameterTiles,
builder.getArrayAttr({parameter_tiles}));
return;
}

for (auto& parameter_layout : computation_layout.parameter_layouts()) {
if (parameter_layout.shape().IsTuple()) {
llvm::SmallVector<mlir::Attribute> tuple_element_parameter_layouts;
llvm::SmallVector<mlir::Attribute> tuple_element_parameter_tiles;
for (auto& tuple_element_shape : layout.shape().tuple_shapes()) {
for (auto& tuple_element_shape :
parameter_layout.shape().tuple_shapes()) {
std::pair<mlir::Attribute, mlir::Attribute> layout_attrs =
GetLayoutAttribute(builder, tuple_element_shape);
tuple_element_parameter_layouts.push_back(layout_attrs.first);
Expand All @@ -120,7 +140,7 @@ void ImportEntryComputationParameterLayoutAndTiles(
builder.getArrayAttr({tuple_element_parameter_tiles}));
} else {
std::pair<mlir::Attribute, mlir::ArrayAttr> layout_attrs =
GetLayoutAttribute(builder, layout.shape());
GetLayoutAttribute(builder, parameter_layout.shape());
parameter_layouts.push_back(layout_attrs.first);
parameter_tiles.push_back(layout_attrs.second);
}
Expand All @@ -133,10 +153,23 @@ void ImportEntryComputationParameterLayoutAndTiles(

void ImportEntryComputationResultLayoutAndTiles(
const xla::HloModule& hlo_module, mlir::ModuleOp module,
const ComputationLayout& computation_layout, mlir::Builder builder) {
const ComputationLayout& computation_layout,
bool flatten_computation_args_result, mlir::Builder builder) {
llvm::SmallVector<mlir::Attribute> result_layouts;
llvm::SmallVector<mlir::Attribute> result_tiles;
if (flatten_computation_args_result) {
xla::ShapeUtil::ForEachLeafShape(
computation_layout.result_layout().shape(),
[&](const xla::Shape& subshape, const xla::ShapeIndex& index) {
std::pair<mlir::Attribute, mlir::ArrayAttr> layout_attrs =
GetLayoutAttribute(builder, subshape);
result_layouts.push_back(layout_attrs.first);
result_tiles.push_back(layout_attrs.second);
});
return;
}

if (computation_layout.result_layout().shape().IsTuple()) {
llvm::SmallVector<mlir::Attribute> result_layouts;
llvm::SmallVector<mlir::Attribute> result_tiles;
for (auto& tuple_element_layout :
computation_layout.result_layout().shape().tuple_shapes()) {
std::pair<mlir::Attribute, mlir::Attribute> layout_attrs =
Expand All @@ -149,15 +182,16 @@ void ImportEntryComputationResultLayoutAndTiles(
builder.getArrayAttr({builder.getArrayAttr(result_layouts)}));
module->setAttr(kEntryComputationResultTiles,
builder.getArrayAttr({builder.getArrayAttr(result_tiles)}));
} else {
std::pair<mlir::Attribute, mlir::ArrayAttr> layout_attrs =
GetLayoutAttribute(builder, computation_layout.result_layout().shape(),
computation_layout.result_layout().layout());
module->setAttr(kEntryComputationResultLayout,
builder.getArrayAttr({layout_attrs.first}));
module->setAttr(kEntryComputationResultTiles,
builder.getArrayAttr({layout_attrs.second}));
return;
}

std::pair<mlir::Attribute, mlir::ArrayAttr> layout_attrs =
GetLayoutAttribute(builder, computation_layout.result_layout().shape(),
computation_layout.result_layout().layout());
module->setAttr(kEntryComputationResultLayout,
builder.getArrayAttr({layout_attrs.first}));
module->setAttr(kEntryComputationResultTiles,
builder.getArrayAttr({layout_attrs.second}));
}

} // namespace
Expand All @@ -175,6 +209,7 @@ void ImportCrossProgramPrefetches(const xla::HloModule& hlo_module,

void ImportEntryComputationLayoutAndTiles(const xla::HloModule& hlo_module,
mlir::ModuleOp module,
bool flatten_computation_args_result,
mlir::Builder builder) {
const auto& computation_layout = hlo_module.entry_computation_layout();
if (!computation_layout.LayoutIsSet()) return;
Expand All @@ -186,12 +221,14 @@ void ImportEntryComputationLayoutAndTiles(const xla::HloModule& hlo_module,
[](const ShapeLayout& shape) {
return HasCustomLayout(shape.shape());
})) {
ImportEntryComputationParameterLayoutAndTiles(hlo_module, module,
computation_layout, builder);
ImportEntryComputationParameterLayoutAndTiles(
hlo_module, module, computation_layout, flatten_computation_args_result,
builder);
}
if (HasCustomLayout(computation_layout.result_layout().shape())) {
ImportEntryComputationResultLayoutAndTiles(hlo_module, module,
computation_layout, builder);
ImportEntryComputationResultLayoutAndTiles(
hlo_module, module, computation_layout, flatten_computation_args_result,
builder);
}
}

Expand Down
1 change: 1 addition & 0 deletions xla/translate/hlo_to_mhlo/module_attributes_importer.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ void ImportCrossProgramPrefetches(const HloModule& hlo_module,

void ImportEntryComputationLayoutAndTiles(const HloModule& hlo_module,
mlir::ModuleOp module,
bool flatten_computation_args_result,
mlir::Builder builder);

void ImportFrontendAttributes(const HloModule& hlo_module,
Expand Down

0 comments on commit a681387

Please sign in to comment.