diff --git a/xla/pjrt/pjrt_executable.cc b/xla/pjrt/pjrt_executable.cc index 68168124b2a568..2e443bb486954a 100644 --- a/xla/pjrt/pjrt_executable.cc +++ b/xla/pjrt/pjrt_executable.cc @@ -586,6 +586,31 @@ absl::Status CompileOptions::ApplyOption(const std::string& key, std::holds_alternative(value)) { reflection->SetDouble(&debug_options, xla_field, std::get(value)); return absl::OkStatus(); + } else if (xla_field->type() == tsl::protobuf::FieldDescriptor::TYPE_ENUM) { + LOG(INFO) << "Enum processing"; + if (std::holds_alternative(value)) { + LOG(INFO) << "Int value"; + if (xla_field->is_repeated()) { + reflection->AddEnumValue(&debug_options, xla_field, + std::get(value)); + } else { + reflection->SetEnumValue(&debug_options, xla_field, + std::get(value)); + } + } else { + LOG(INFO) << "Name: " << std::get(value); + auto enum_desc = xla_field->enum_type()->FindValueByName( + std::get(value)); + if (enum_desc != nullptr) { + LOG(INFO) << "Found nullptr"; + if (xla_field->is_repeated()) { + reflection->AddEnum(&debug_options, xla_field, enum_desc); + } else { + reflection->SetEnum(&debug_options, xla_field, enum_desc); + } + } + } + return absl::OkStatus(); } else { return InvalidArgument( "While setting option %s, '%s' is not a valid %s value.", key, @@ -636,6 +661,27 @@ absl::Status CompileOptions::ApplyOptionFromString( reflection->SetBool(&debug_options, field, bvalue); return absl::OkStatus(); } + } else if (field->type() == tsl::protobuf::FieldDescriptor::TYPE_ENUM) { + int int_value; + if (absl::SimpleAtoi(value, &int_value)) { + if (field->is_repeated()) { + reflection->AddEnumValue(&debug_options, field, int_value); + } else { + reflection->SetEnumValue(&debug_options, field, int_value); + } + return absl::OkStatus(); + } else { + LOG(INFO) << "Name: " << value; + auto enum_desc = field->enum_type()->FindValueByName(value); + if (enum_desc != nullptr) { + LOG(INFO) << "Found nullptr"; + if (field->is_repeated()) { + reflection->AddEnum(&debug_options, field, enum_desc); + } else { + reflection->SetEnum(&debug_options, field, enum_desc); + } + } + } } return InvalidArgument( "While setting option %s, '%s' is not a valid %s value.", field->name(), diff --git a/xla/pjrt/pjrt_executable_test.cc b/xla/pjrt/pjrt_executable_test.cc index e3a0b3d09ba7f1..868ca80dec480c 100644 --- a/xla/pjrt/pjrt_executable_test.cc +++ b/xla/pjrt/pjrt_executable_test.cc @@ -93,7 +93,7 @@ TEST(ExecuteOptionsTest, SendRecvNotSupported) { "ExecuteOptions with send/recv calbacks is not serializable")); } -TEST(ExecuteOptionsTest, ApplyOptionsCanParseStrings) { +TEST(ExecuteOptionsTest, ApplyOptionsCanParseStringsAndEnums) { using OptionOverride = std::variant; std::vector> env_override_options; env_override_options = {