From 76407891d558975a2f6d87c8f00cc48b719a95b9 Mon Sep 17 00:00:00 2001 From: Rohan Jain Date: Wed, 18 Sep 2024 16:19:03 -0700 Subject: [PATCH] Add support for parsing enum values in `ExecuteOptions::ApplyAllOptionOverrides`. This is needed to support parsing the `xla_gpu_disable_async_collectives` flag, which is an enum value. PiperOrigin-RevId: 676172348 --- xla/pjrt/pjrt_executable.cc | 40 ++++++++++++++++++++++++++++++++ xla/pjrt/pjrt_executable_test.cc | 5 +++- 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/xla/pjrt/pjrt_executable.cc b/xla/pjrt/pjrt_executable.cc index 68168124b2a56..79fea67787122 100644 --- a/xla/pjrt/pjrt_executable.cc +++ b/xla/pjrt/pjrt_executable.cc @@ -586,6 +586,27 @@ absl::Status CompileOptions::ApplyOption(const std::string& key, std::holds_alternative(value)) { reflection->SetDouble(&debug_options, xla_field, std::get(value)); return absl::OkStatus(); + } else if (xla_field->type() == tsl::protobuf::FieldDescriptor::TYPE_ENUM) { + if (std::holds_alternative(value)) { + if (xla_field->is_repeated()) { + reflection->AddEnumValue(&debug_options, xla_field, + std::get(value)); + } else { + reflection->SetEnumValue(&debug_options, xla_field, + std::get(value)); + } + } else { + auto enum_desc = xla_field->enum_type()->FindValueByName( + std::get(value)); + if (enum_desc != nullptr) { + if (xla_field->is_repeated()) { + reflection->AddEnum(&debug_options, xla_field, enum_desc); + } else { + reflection->SetEnum(&debug_options, xla_field, enum_desc); + } + } + } + return absl::OkStatus(); } else { return InvalidArgument( "While setting option %s, '%s' is not a valid %s value.", key, @@ -636,6 +657,25 @@ absl::Status CompileOptions::ApplyOptionFromString( reflection->SetBool(&debug_options, field, bvalue); return absl::OkStatus(); } + } else if (field->type() == tsl::protobuf::FieldDescriptor::TYPE_ENUM) { + int int_value; + if (absl::SimpleAtoi(value, &int_value)) { + if (field->is_repeated()) { + reflection->AddEnumValue(&debug_options, field, int_value); + } else { + reflection->SetEnumValue(&debug_options, field, int_value); + } + return absl::OkStatus(); + } else { + auto enum_desc = field->enum_type()->FindValueByName(value); + if (enum_desc != nullptr) { + if (field->is_repeated()) { + reflection->AddEnum(&debug_options, field, enum_desc); + } else { + reflection->SetEnum(&debug_options, field, enum_desc); + } + } + } } return InvalidArgument( "While setting option %s, '%s' is not a valid %s value.", field->name(), diff --git a/xla/pjrt/pjrt_executable_test.cc b/xla/pjrt/pjrt_executable_test.cc index e3a0b3d09ba7f..72c0da6f04bda 100644 --- a/xla/pjrt/pjrt_executable_test.cc +++ b/xla/pjrt/pjrt_executable_test.cc @@ -93,12 +93,13 @@ TEST(ExecuteOptionsTest, SendRecvNotSupported) { "ExecuteOptions with send/recv calbacks is not serializable")); } -TEST(ExecuteOptionsTest, ApplyOptionsCanParseStrings) { +TEST(ExecuteOptionsTest, ApplyOptionsCanParseStringsAndEnums) { using OptionOverride = std::variant; std::vector> env_override_options; env_override_options = { {"xla_gpu_use_runtime_fusion", std::string("True")}, {"xla_gpu_graph_min_graph_size", std::string("2")}, + {"xla_gpu_disable_async_collectives", std::string("2")}, {"xla_gpu_redzone_scratch_max_megabytes", std::string("3400")}, {"xla_gpu_auto_spmd_partitioning_memory_budget_ratio", 0.9}, {"xla_gpu_pgle_profile_file_or_directory_path", std::string("abc")}}; @@ -112,6 +113,8 @@ TEST(ExecuteOptionsTest, ApplyOptionsCanParseStrings) { EXPECT_FLOAT_EQ( debug_options.xla_gpu_auto_spmd_partitioning_memory_budget_ratio(), 0.9); EXPECT_EQ(debug_options.xla_gpu_pgle_profile_file_or_directory_path(), "abc"); + EXPECT_EQ(debug_options.xla_gpu_disable_async_collectives().size(), 1); + EXPECT_EQ(debug_options.xla_gpu_disable_async_collectives()[0], 2); } TEST(CompiledMemoryStatsTest, Serialization) {