Skip to content

Commit

Permalink
Add support for parsing enum values in `ExecuteOptions::ApplyAllOptio…
Browse files Browse the repository at this point in the history
…nOverrides`.

This is needed to support parsing the `xla_gpu_disable_async_collectives` flag, which is an enum value.

PiperOrigin-RevId: 675656859
  • Loading branch information
rohan100jain authored and Google-ML-Automation committed Sep 18, 2024
1 parent 0b5b884 commit 97b38ec
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 1 deletion.
40 changes: 40 additions & 0 deletions xla/pjrt/pjrt_executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,27 @@ absl::Status CompileOptions::ApplyOption(const std::string& key,
std::holds_alternative<double>(value)) {
reflection->SetDouble(&debug_options, xla_field, std::get<double>(value));
return absl::OkStatus();
} else if (xla_field->type() == tsl::protobuf::FieldDescriptor::TYPE_ENUM) {
if (std::holds_alternative<int64_t>(value)) {
if (xla_field->is_repeated()) {
reflection->AddEnumValue(&debug_options, xla_field,
std::get<int64_t>(value));
} else {
reflection->SetEnumValue(&debug_options, xla_field,
std::get<int64_t>(value));
}
} else {
auto enum_desc = xla_field->enum_type()->FindValueByName(
std::get<std::string>(value));
if (enum_desc != nullptr) {
if (xla_field->is_repeated()) {
reflection->AddEnum(&debug_options, xla_field, enum_desc);
} else {
reflection->SetEnum(&debug_options, xla_field, enum_desc);
}
}
}
return absl::OkStatus();
} else {
return InvalidArgument(
"While setting option %s, '%s' is not a valid %s value.", key,
Expand Down Expand Up @@ -636,6 +657,25 @@ absl::Status CompileOptions::ApplyOptionFromString(
reflection->SetBool(&debug_options, field, bvalue);
return absl::OkStatus();
}
} else if (field->type() == tsl::protobuf::FieldDescriptor::TYPE_ENUM) {
int int_value;
if (absl::SimpleAtoi(value, &int_value)) {
if (field->is_repeated()) {
reflection->AddEnumValue(&debug_options, field, int_value);
} else {
reflection->SetEnumValue(&debug_options, field, int_value);
}
return absl::OkStatus();
} else {
auto enum_desc = field->enum_type()->FindValueByName(value);
if (enum_desc != nullptr) {
if (field->is_repeated()) {
reflection->AddEnum(&debug_options, field, enum_desc);
} else {
reflection->SetEnum(&debug_options, field, enum_desc);
}
}
}
}
return InvalidArgument(
"While setting option %s, '%s' is not a valid %s value.", field->name(),
Expand Down
17 changes: 16 additions & 1 deletion xla/pjrt/pjrt_executable_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,15 +93,26 @@ TEST(ExecuteOptionsTest, SendRecvNotSupported) {
"ExecuteOptions with send/recv calbacks is not serializable"));
}

TEST(ExecuteOptionsTest, ApplyOptionsCanParseStrings) {
TEST(ExecuteOptionsTest, ApplyOptionsCanParseStringsAndEnums) {
using OptionOverride = std::variant<std::string, bool, int64_t, double>;
std::vector<std::pair<std::string, OptionOverride>> env_override_options;
env_override_options = {
{"xla_gpu_use_runtime_fusion", std::string("True")},
{"xla_gpu_graph_min_graph_size", std::string("2")},
{"xla_gpu_disable_async_collectives", std::string("2")},
{"xla_gpu_disable_async_collectives",
std::string("COLLECTIVEBROADCAST")},
{"xla_gpu_redzone_scratch_max_megabytes", std::string("3400")},
{"xla_gpu_auto_spmd_partitioning_memory_budget_ratio", 0.9},
{"xla_gpu_pgle_profile_file_or_directory_path", std::string("abc")}};
// env_override_options = {
// {"xla_gpu_use_runtime_fusion", std::string("True")},
// {"xla_gpu_graph_min_graph_size", std::string("2")},
// {"xla_gpu_redzone_scratch_max_megabytes", std::string("3400")},
// {"xla_gpu_auto_spmd_partitioning_memory_budget_ratio", 0.9},
// {"xla_gpu_pgle_profile_file_or_directory_path", std::string("abc")},
// {"xla_gpu_disable_async_collectives", 1},
// {"xla_gpu_disable_async_collectives", std::string("2")},
CompileOptions src;
src.env_option_overrides = env_override_options;
auto s = src.ApplyAllOptionOverrides();
Expand All @@ -112,6 +123,10 @@ TEST(ExecuteOptionsTest, ApplyOptionsCanParseStrings) {
EXPECT_FLOAT_EQ(
debug_options.xla_gpu_auto_spmd_partitioning_memory_budget_ratio(), 0.9);
EXPECT_EQ(debug_options.xla_gpu_pgle_profile_file_or_directory_path(), "abc");
// EXPECT_EQ(debug_options.xla_gpu_disable_async_collectives().size(), 3);
// EXPECT_EQ(debug_options.xla_gpu_disable_async_collectives()[0], 1);
// EXPECT_EQ(debug_options.xla_gpu_disable_async_collectives()[1], 2);
// EXPECT_EQ(debug_options.xla_gpu_disable_async_collectives()[2], 4);
}

TEST(CompiledMemoryStatsTest, Serialization) {
Expand Down

0 comments on commit 97b38ec

Please sign in to comment.