diff --git a/xla/pjrt/pjrt_executable.cc b/xla/pjrt/pjrt_executable.cc index 68168124b2a568..79fea677871222 100644 --- a/xla/pjrt/pjrt_executable.cc +++ b/xla/pjrt/pjrt_executable.cc @@ -586,6 +586,27 @@ absl::Status CompileOptions::ApplyOption(const std::string& key, std::holds_alternative(value)) { reflection->SetDouble(&debug_options, xla_field, std::get(value)); return absl::OkStatus(); + } else if (xla_field->type() == tsl::protobuf::FieldDescriptor::TYPE_ENUM) { + if (std::holds_alternative(value)) { + if (xla_field->is_repeated()) { + reflection->AddEnumValue(&debug_options, xla_field, + std::get(value)); + } else { + reflection->SetEnumValue(&debug_options, xla_field, + std::get(value)); + } + } else { + auto enum_desc = xla_field->enum_type()->FindValueByName( + std::get(value)); + if (enum_desc != nullptr) { + if (xla_field->is_repeated()) { + reflection->AddEnum(&debug_options, xla_field, enum_desc); + } else { + reflection->SetEnum(&debug_options, xla_field, enum_desc); + } + } + } + return absl::OkStatus(); } else { return InvalidArgument( "While setting option %s, '%s' is not a valid %s value.", key, @@ -636,6 +657,25 @@ absl::Status CompileOptions::ApplyOptionFromString( reflection->SetBool(&debug_options, field, bvalue); return absl::OkStatus(); } + } else if (field->type() == tsl::protobuf::FieldDescriptor::TYPE_ENUM) { + int int_value; + if (absl::SimpleAtoi(value, &int_value)) { + if (field->is_repeated()) { + reflection->AddEnumValue(&debug_options, field, int_value); + } else { + reflection->SetEnumValue(&debug_options, field, int_value); + } + return absl::OkStatus(); + } else { + auto enum_desc = field->enum_type()->FindValueByName(value); + if (enum_desc != nullptr) { + if (field->is_repeated()) { + reflection->AddEnum(&debug_options, field, enum_desc); + } else { + reflection->SetEnum(&debug_options, field, enum_desc); + } + } + } } return InvalidArgument( "While setting option %s, '%s' is not a valid %s value.", field->name(), diff --git a/xla/pjrt/pjrt_executable_test.cc b/xla/pjrt/pjrt_executable_test.cc index e3a0b3d09ba7f1..8b8653074ea480 100644 --- a/xla/pjrt/pjrt_executable_test.cc +++ b/xla/pjrt/pjrt_executable_test.cc @@ -93,15 +93,26 @@ TEST(ExecuteOptionsTest, SendRecvNotSupported) { "ExecuteOptions with send/recv calbacks is not serializable")); } -TEST(ExecuteOptionsTest, ApplyOptionsCanParseStrings) { +TEST(ExecuteOptionsTest, ApplyOptionsCanParseStringsAndEnums) { using OptionOverride = std::variant; std::vector> env_override_options; env_override_options = { {"xla_gpu_use_runtime_fusion", std::string("True")}, {"xla_gpu_graph_min_graph_size", std::string("2")}, + {"xla_gpu_disable_async_collectives", std::string("2")}, + {"xla_gpu_disable_async_collectives", + std::string("COLLECTIVEBROADCAST")}, {"xla_gpu_redzone_scratch_max_megabytes", std::string("3400")}, {"xla_gpu_auto_spmd_partitioning_memory_budget_ratio", 0.9}, {"xla_gpu_pgle_profile_file_or_directory_path", std::string("abc")}}; + // env_override_options = { + // {"xla_gpu_use_runtime_fusion", std::string("True")}, + // {"xla_gpu_graph_min_graph_size", std::string("2")}, + // {"xla_gpu_redzone_scratch_max_megabytes", std::string("3400")}, + // {"xla_gpu_auto_spmd_partitioning_memory_budget_ratio", 0.9}, + // {"xla_gpu_pgle_profile_file_or_directory_path", std::string("abc")}, + // {"xla_gpu_disable_async_collectives", 1}, + // {"xla_gpu_disable_async_collectives", std::string("2")}, CompileOptions src; src.env_option_overrides = env_override_options; auto s = src.ApplyAllOptionOverrides(); @@ -112,6 +123,10 @@ TEST(ExecuteOptionsTest, ApplyOptionsCanParseStrings) { EXPECT_FLOAT_EQ( debug_options.xla_gpu_auto_spmd_partitioning_memory_budget_ratio(), 0.9); EXPECT_EQ(debug_options.xla_gpu_pgle_profile_file_or_directory_path(), "abc"); + // EXPECT_EQ(debug_options.xla_gpu_disable_async_collectives().size(), 3); + // EXPECT_EQ(debug_options.xla_gpu_disable_async_collectives()[0], 1); + // EXPECT_EQ(debug_options.xla_gpu_disable_async_collectives()[1], 2); + // EXPECT_EQ(debug_options.xla_gpu_disable_async_collectives()[2], 4); } TEST(CompiledMemoryStatsTest, Serialization) {