Skip to content

Commit

Permalink
Add support for parsing enum values in `ExecuteOptions::ApplyAllOptio…
Browse files Browse the repository at this point in the history
…nOverrides`.

This is needed to support parsing the `xla_gpu_disable_async_collectives` flag, which is an enum value.

PiperOrigin-RevId: 675656859
  • Loading branch information
rohan100jain authored and Google-ML-Automation committed Sep 17, 2024
1 parent 3406c60 commit 56363ec
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 1 deletion.
46 changes: 46 additions & 0 deletions xla/pjrt/pjrt_executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,31 @@ absl::Status CompileOptions::ApplyOption(const std::string& key,
std::holds_alternative<double>(value)) {
reflection->SetDouble(&debug_options, xla_field, std::get<double>(value));
return absl::OkStatus();
} else if (xla_field->type() == tsl::protobuf::FieldDescriptor::TYPE_ENUM) {
LOG(INFO) << "Enum processing";
if (std::holds_alternative<int64_t>(value)) {
LOG(INFO) << "Int value";
if (xla_field->is_repeated()) {
reflection->AddEnumValue(&debug_options, xla_field,
std::get<int64_t>(value));
} else {
reflection->SetEnumValue(&debug_options, xla_field,
std::get<int64_t>(value));
}
} else {
LOG(INFO) << "Name: " << std::get<std::string>(value);
auto enum_desc = xla_field->enum_type()->FindValueByName(
std::get<std::string>(value));
if (enum_desc != nullptr) {
LOG(INFO) << "Found nullptr";
if (xla_field->is_repeated()) {
reflection->AddEnum(&debug_options, xla_field, enum_desc);
} else {
reflection->SetEnum(&debug_options, xla_field, enum_desc);
}
}
}
return absl::OkStatus();
} else {
return InvalidArgument(
"While setting option %s, '%s' is not a valid %s value.", key,
Expand Down Expand Up @@ -636,6 +661,27 @@ absl::Status CompileOptions::ApplyOptionFromString(
reflection->SetBool(&debug_options, field, bvalue);
return absl::OkStatus();
}
} else if (field->type() == tsl::protobuf::FieldDescriptor::TYPE_ENUM) {
int int_value;
if (absl::SimpleAtoi(value, &int_value)) {
if (field->is_repeated()) {
reflection->AddEnumValue(&debug_options, field, int_value);
} else {
reflection->SetEnumValue(&debug_options, field, int_value);
}
return absl::OkStatus();
} else {
LOG(INFO) << "Name: " << value;
auto enum_desc = field->enum_type()->FindValueByName(value);
if (enum_desc != nullptr) {
LOG(INFO) << "Found nullptr";
if (field->is_repeated()) {
reflection->AddEnum(&debug_options, field, enum_desc);
} else {
reflection->SetEnum(&debug_options, field, enum_desc);
}
}
}
}
return InvalidArgument(
"While setting option %s, '%s' is not a valid %s value.", field->name(),
Expand Down
2 changes: 1 addition & 1 deletion xla/pjrt/pjrt_executable_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ TEST(ExecuteOptionsTest, SendRecvNotSupported) {
"ExecuteOptions with send/recv calbacks is not serializable"));
}

TEST(ExecuteOptionsTest, ApplyOptionsCanParseStrings) {
TEST(ExecuteOptionsTest, ApplyOptionsCanParseStringsAndEnums) {
using OptionOverride = std::variant<std::string, bool, int64_t, double>;
std::vector<std::pair<std::string, OptionOverride>> env_override_options;
env_override_options = {
Expand Down

0 comments on commit 56363ec

Please sign in to comment.