diff --git a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc index 8ad9d3844438e..06888b5d01360 100644 --- a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc +++ b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc @@ -83,7 +83,8 @@ absl::Status HloModuleImporter::Import(const HloModule& hlo_module) { flatten_computation_args_result_) .status()); - ImportEntryComputationLayoutAndTiles(hlo_module, module, builder_); + ImportEntryComputationLayoutAndTiles( + hlo_module, module, flatten_computation_args_result_, builder_); return absl::OkStatus(); } diff --git a/xla/translate/hlo_to_mhlo/module_attributes_importer.cc b/xla/translate/hlo_to_mhlo/module_attributes_importer.cc index a60499f1b9bdf..592a127965325 100644 --- a/xla/translate/hlo_to_mhlo/module_attributes_importer.cc +++ b/xla/translate/hlo_to_mhlo/module_attributes_importer.cc @@ -101,14 +101,34 @@ mlir::ArrayAttr ConvertCrossProgramPrefetches( void ImportEntryComputationParameterLayoutAndTiles( const xla::HloModule& hlo_module, mlir::ModuleOp module, - const ComputationLayout& computation_layout, mlir::Builder builder) { + const ComputationLayout& computation_layout, + bool flatten_computation_args_result, mlir::Builder builder) { llvm::SmallVector parameter_layouts; llvm::SmallVector parameter_tiles; - for (auto& layout : computation_layout.parameter_layouts()) { - if (layout.shape().IsTuple()) { + if (flatten_computation_args_result) { + for (auto& parameter_layout : computation_layout.parameter_layouts()) { + xla::ShapeUtil::ForEachLeafShape( + parameter_layout.shape(), + [&](const xla::Shape& subshape, const xla::ShapeIndex& index) { + std::pair layout_attrs = + GetLayoutAttribute(builder, subshape); + parameter_layouts.push_back(layout_attrs.first); + parameter_tiles.push_back(layout_attrs.second); + }); + } + module->setAttr(kEntryComputationParameterLayouts, + builder.getArrayAttr({parameter_layouts})); + module->setAttr(kEntryComputationParameterTiles, + builder.getArrayAttr({parameter_tiles})); + return; + } + + for (auto& parameter_layout : computation_layout.parameter_layouts()) { + if (parameter_layout.shape().IsTuple()) { llvm::SmallVector tuple_element_parameter_layouts; llvm::SmallVector tuple_element_parameter_tiles; - for (auto& tuple_element_shape : layout.shape().tuple_shapes()) { + for (auto& tuple_element_shape : + parameter_layout.shape().tuple_shapes()) { std::pair layout_attrs = GetLayoutAttribute(builder, tuple_element_shape); tuple_element_parameter_layouts.push_back(layout_attrs.first); @@ -120,7 +140,7 @@ void ImportEntryComputationParameterLayoutAndTiles( builder.getArrayAttr({tuple_element_parameter_tiles})); } else { std::pair layout_attrs = - GetLayoutAttribute(builder, layout.shape()); + GetLayoutAttribute(builder, parameter_layout.shape()); parameter_layouts.push_back(layout_attrs.first); parameter_tiles.push_back(layout_attrs.second); } @@ -133,10 +153,23 @@ void ImportEntryComputationParameterLayoutAndTiles( void ImportEntryComputationResultLayoutAndTiles( const xla::HloModule& hlo_module, mlir::ModuleOp module, - const ComputationLayout& computation_layout, mlir::Builder builder) { + const ComputationLayout& computation_layout, + bool flatten_computation_args_result, mlir::Builder builder) { + llvm::SmallVector result_layouts; + llvm::SmallVector result_tiles; + if (flatten_computation_args_result) { + xla::ShapeUtil::ForEachLeafShape( + computation_layout.result_layout().shape(), + [&](const xla::Shape& subshape, const xla::ShapeIndex& index) { + std::pair layout_attrs = + GetLayoutAttribute(builder, subshape); + result_layouts.push_back(layout_attrs.first); + result_tiles.push_back(layout_attrs.second); + }); + return; + } + if (computation_layout.result_layout().shape().IsTuple()) { - llvm::SmallVector result_layouts; - llvm::SmallVector result_tiles; for (auto& tuple_element_layout : computation_layout.result_layout().shape().tuple_shapes()) { std::pair layout_attrs = @@ -149,15 +182,16 @@ void ImportEntryComputationResultLayoutAndTiles( builder.getArrayAttr({builder.getArrayAttr(result_layouts)})); module->setAttr(kEntryComputationResultTiles, builder.getArrayAttr({builder.getArrayAttr(result_tiles)})); - } else { - std::pair layout_attrs = - GetLayoutAttribute(builder, computation_layout.result_layout().shape(), - computation_layout.result_layout().layout()); - module->setAttr(kEntryComputationResultLayout, - builder.getArrayAttr({layout_attrs.first})); - module->setAttr(kEntryComputationResultTiles, - builder.getArrayAttr({layout_attrs.second})); + return; } + + std::pair layout_attrs = + GetLayoutAttribute(builder, computation_layout.result_layout().shape(), + computation_layout.result_layout().layout()); + module->setAttr(kEntryComputationResultLayout, + builder.getArrayAttr({layout_attrs.first})); + module->setAttr(kEntryComputationResultTiles, + builder.getArrayAttr({layout_attrs.second})); } } // namespace @@ -175,6 +209,7 @@ void ImportCrossProgramPrefetches(const xla::HloModule& hlo_module, void ImportEntryComputationLayoutAndTiles(const xla::HloModule& hlo_module, mlir::ModuleOp module, + bool flatten_computation_args_result, mlir::Builder builder) { const auto& computation_layout = hlo_module.entry_computation_layout(); if (!computation_layout.LayoutIsSet()) return; @@ -186,12 +221,14 @@ void ImportEntryComputationLayoutAndTiles(const xla::HloModule& hlo_module, [](const ShapeLayout& shape) { return HasCustomLayout(shape.shape()); })) { - ImportEntryComputationParameterLayoutAndTiles(hlo_module, module, - computation_layout, builder); + ImportEntryComputationParameterLayoutAndTiles( + hlo_module, module, computation_layout, flatten_computation_args_result, + builder); } if (HasCustomLayout(computation_layout.result_layout().shape())) { - ImportEntryComputationResultLayoutAndTiles(hlo_module, module, - computation_layout, builder); + ImportEntryComputationResultLayoutAndTiles( + hlo_module, module, computation_layout, flatten_computation_args_result, + builder); } } diff --git a/xla/translate/hlo_to_mhlo/module_attributes_importer.h b/xla/translate/hlo_to_mhlo/module_attributes_importer.h index bd4580e5d315a..945db534c86a0 100644 --- a/xla/translate/hlo_to_mhlo/module_attributes_importer.h +++ b/xla/translate/hlo_to_mhlo/module_attributes_importer.h @@ -32,6 +32,7 @@ void ImportCrossProgramPrefetches(const HloModule& hlo_module, void ImportEntryComputationLayoutAndTiles(const HloModule& hlo_module, mlir::ModuleOp module, + bool flatten_computation_args_result, mlir::Builder builder); void ImportFrontendAttributes(const HloModule& hlo_module,