diff --git a/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc b/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc index 22ee30067c89fc..12e99a8d1007a4 100644 --- a/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc +++ b/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc @@ -306,8 +306,12 @@ absl::StatusOr FunctionalHloRunner::CreateCompileOptions( build_options.set_process_index(task_id); build_options.set_process_count(num_nodes); build_options.set_key_value_store(kv_store); - if (raw_options.spmd_mode == SpmdMode::kUseSpmdPartitioning) { + if (raw_options.spmd_mode == SpmdMode::kUseSpmdPartitioning || + raw_options.spmd_mode == SpmdMode::kUseShardyPartitioning) { build_options.set_use_spmd_partitioning(true); + if (raw_options.spmd_mode == SpmdMode::kUseShardyPartitioning) { + build_options.set_use_shardy_partitioner(true); + } } if (!build_options.has_device_assignment() && !raw_options.num_slices.has_value()) { @@ -400,6 +404,8 @@ FunctionalHloRunner::CreateExecutableBuildOptionsFromExecutionOptions( build_options.set_num_partitions(execution_options.num_partitions()); build_options.set_use_spmd_partitioning( execution_options.use_spmd_partitioning()); + build_options.set_use_shardy_partitioner( + execution_options.use_shardy_partitioner()); build_options.set_use_auto_spmd_partitioning( execution_options.use_auto_spmd_partitioning()); build_options.set_deduplicate_hlo(execution_options.deduplicate_hlo()); diff --git a/xla/tools/multihost_hlo_runner/functional_hlo_runner.h b/xla/tools/multihost_hlo_runner/functional_hlo_runner.h index 7bfa0c1dfd88eb..79d986148b9a37 100644 --- a/xla/tools/multihost_hlo_runner/functional_hlo_runner.h +++ b/xla/tools/multihost_hlo_runner/functional_hlo_runner.h @@ -96,7 +96,11 @@ class FunctionalHloRunner { kStandardCompile }; - enum class SpmdMode { kUseSpmdPartitioning, kNotUseSpmdPartitioning }; + enum class SpmdMode : int8_t { + kUseSpmdPartitioning, // Use the GSPMD partitioner for partitioning. + kUseShardyPartitioning, // Use the Shardy partitioner for partitioning. + kNotUseSpmdPartitioning // Do not perform partitioning. + }; enum class SpmdPartitionedMode { kIsSpmdPartitionedModule,