From 2cac71b5414940f729248185af7e4922dc2b7159 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Thu, 5 Sep 2024 15:15:40 -0700 Subject: [PATCH] Remove dependency on `-hlo-import-all-computations` flag when exporting `entry_computation_layout`. PiperOrigin-RevId: 671527962 --- xla/translate/hlo_to_mhlo/hlo_module_importer.cc | 7 ++++--- .../hlo_to_mhlo/tests/import_entry_computation_layout.hlo | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc index 8ad9d3844438e7..a2b348c5c361a4 100644 --- a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc +++ b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc @@ -55,6 +55,8 @@ absl::Status HloModuleImporter::Import(const HloModule& hlo_module) { ImportCrossProgramPrefetches(hlo_module, module, flatten_computation_args_result_, builder_); + if (!flatten_computation_args_result_) + ImportEntryComputationLayoutAndTiles(hlo_module, module, builder_); ImportFrontendAttributes(hlo_module, module, builder_); ImportInputOutputAlias(hlo_module, module, builder_); ImportIsDynamic(hlo_module, module, builder_); @@ -71,7 +73,7 @@ absl::Status HloModuleImporter::Import(const HloModule& hlo_module) { return HloFunctionImporter::ImportAsFunc( *hlo_module.entry_computation(), symbol_table_, &function_map_, &builder_, - /*is_main*/ true, flatten_computation_args_result_) + /*is_main=*/true, flatten_computation_args_result_) .status(); auto* module_entry_computation = hlo_module.entry_computation(); @@ -79,11 +81,10 @@ absl::Status HloModuleImporter::Import(const HloModule& hlo_module) { TF_RETURN_IF_ERROR(HloFunctionImporter::ImportAsFunc( *computation, symbol_table_, &function_map_, &builder_, - /*is_main*/ computation == module_entry_computation, + /*is_main=*/computation == module_entry_computation, flatten_computation_args_result_) .status()); - ImportEntryComputationLayoutAndTiles(hlo_module, module, builder_); return absl::OkStatus(); } diff --git a/xla/translate/hlo_to_mhlo/tests/import_entry_computation_layout.hlo b/xla/translate/hlo_to_mhlo/tests/import_entry_computation_layout.hlo index 6663e052502cce..e279a2754bc0e4 100644 --- a/xla/translate/hlo_to_mhlo/tests/import_entry_computation_layout.hlo +++ b/xla/translate/hlo_to_mhlo/tests/import_entry_computation_layout.hlo @@ -1,4 +1,4 @@ -// RUN: xla-translate -hlo-text-to-mlir-hlo -with-layouts -print-layouts -hlo-import-all-computations -split-input-file -verify-diagnostics %s -o - | FileCheck %s +// RUN: xla-translate -hlo-text-to-mlir-hlo -with-layouts -print-layouts -split-input-file -verify-diagnostics %s -o - | FileCheck %s HloModule entry, entry_computation_layout={( f32[2,3,4]{0,1,2:T(2,128)},