From 56363ec3c50922d2ed1f9f234a7a566841d6f716 Mon Sep 17 00:00:00 2001 From: Rohan Jain Date: Tue, 17 Sep 2024 12:03:04 -0700 Subject: [PATCH] Add support for parsing enum values in `ExecuteOptions::ApplyAllOptionOverrides`. This is needed to support parsing the `xla_gpu_disable_async_collectives` flag, which is an enum value. PiperOrigin-RevId: 675656859 --- xla/pjrt/pjrt_executable.cc | 46 ++++++++++++++++++++++++++++++++ xla/pjrt/pjrt_executable_test.cc | 2 +- 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/xla/pjrt/pjrt_executable.cc b/xla/pjrt/pjrt_executable.cc index 68168124b2a568..2e443bb486954a 100644 --- a/xla/pjrt/pjrt_executable.cc +++ b/xla/pjrt/pjrt_executable.cc @@ -586,6 +586,31 @@ absl::Status CompileOptions::ApplyOption(const std::string& key, std::holds_alternative(value)) { reflection->SetDouble(&debug_options, xla_field, std::get(value)); return absl::OkStatus(); + } else if (xla_field->type() == tsl::protobuf::FieldDescriptor::TYPE_ENUM) { + LOG(INFO) << "Enum processing"; + if (std::holds_alternative(value)) { + LOG(INFO) << "Int value"; + if (xla_field->is_repeated()) { + reflection->AddEnumValue(&debug_options, xla_field, + std::get(value)); + } else { + reflection->SetEnumValue(&debug_options, xla_field, + std::get(value)); + } + } else { + LOG(INFO) << "Name: " << std::get(value); + auto enum_desc = xla_field->enum_type()->FindValueByName( + std::get(value)); + if (enum_desc != nullptr) { + LOG(INFO) << "Found nullptr"; + if (xla_field->is_repeated()) { + reflection->AddEnum(&debug_options, xla_field, enum_desc); + } else { + reflection->SetEnum(&debug_options, xla_field, enum_desc); + } + } + } + return absl::OkStatus(); } else { return InvalidArgument( "While setting option %s, '%s' is not a valid %s value.", key, @@ -636,6 +661,27 @@ absl::Status CompileOptions::ApplyOptionFromString( reflection->SetBool(&debug_options, field, bvalue); return absl::OkStatus(); } + } else if (field->type() == tsl::protobuf::FieldDescriptor::TYPE_ENUM) { + int int_value; + if (absl::SimpleAtoi(value, &int_value)) { + if (field->is_repeated()) { + reflection->AddEnumValue(&debug_options, field, int_value); + } else { + reflection->SetEnumValue(&debug_options, field, int_value); + } + return absl::OkStatus(); + } else { + LOG(INFO) << "Name: " << value; + auto enum_desc = field->enum_type()->FindValueByName(value); + if (enum_desc != nullptr) { + LOG(INFO) << "Found nullptr"; + if (field->is_repeated()) { + reflection->AddEnum(&debug_options, field, enum_desc); + } else { + reflection->SetEnum(&debug_options, field, enum_desc); + } + } + } } return InvalidArgument( "While setting option %s, '%s' is not a valid %s value.", field->name(), diff --git a/xla/pjrt/pjrt_executable_test.cc b/xla/pjrt/pjrt_executable_test.cc index e3a0b3d09ba7f1..868ca80dec480c 100644 --- a/xla/pjrt/pjrt_executable_test.cc +++ b/xla/pjrt/pjrt_executable_test.cc @@ -93,7 +93,7 @@ TEST(ExecuteOptionsTest, SendRecvNotSupported) { "ExecuteOptions with send/recv calbacks is not serializable")); } -TEST(ExecuteOptionsTest, ApplyOptionsCanParseStrings) { +TEST(ExecuteOptionsTest, ApplyOptionsCanParseStringsAndEnums) { using OptionOverride = std::variant; std::vector> env_override_options; env_override_options = {