Skip to content

Commit

Permalink
Add a rule to algebraic simplifier to replace A - A with 0.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675732243
  • Loading branch information
fhoushmand authored and Google-ML-Automation committed Oct 8, 2024
1 parent b9b6ea9 commit 5832697
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 1 deletion.
7 changes: 7 additions & 0 deletions xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_gpu_experimental_disable_binary_libraries(false);
opts.set_xla_experimental_ignore_channel_id(false);
opts.set_xla_gpu_dot_merger_threshold_mb(32);
opts.set_xla_enable_fast_math(false);
return opts;
}

Expand Down Expand Up @@ -2003,6 +2004,12 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
int32_setter_for(&DebugOptions::set_xla_gpu_dot_merger_threshold_mb),
debug_options->xla_gpu_dot_merger_threshold_mb(),
"Dot merger pass threshold to be set in MB."));

flag_list->push_back(
tsl::Flag("xla_enable_fast_math",
bool_setter_for(&DebugOptions::set_xla_enable_fast_math),
debug_options->xla_enable_fast_math(),
"Enable optimizations that assume finite math, i.e., no NaN."));
} // NOLINT(readability/fn_size)

// Allocates flag_values and flag_objects; this function must not be called more
Expand Down
7 changes: 7 additions & 0 deletions xla/service/algebraic_simplifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2073,6 +2073,13 @@ absl::Status AlgebraicSimplifierVisitor::HandleConstant(
absl::Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) {
HloInstruction *lhs, *rhs;
CHECK(Match(sub, m::Subtract(m::Op(&lhs), m::Op(&rhs))));
// A - A => 0
if (options_.enable_fast_math() ||
ShapeUtil::ElementIsIntegral(sub->shape())) {
if (lhs == rhs) {
return ReplaceInstruction(sub, MakeScalarLike(sub, 0));
}
}
// A - 0 => A
VLOG(10) << "trying transform [A - 0 => A]: " << sub->ToString();
if (IsAll(rhs, 0) && ReplaceInstructionIfCompatible(sub, lhs)) {
Expand Down
7 changes: 7 additions & 0 deletions xla/service/algebraic_simplifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,12 @@ class AlgebraicSimplifierOptions {
return disable_dynamic_slice_to_slice_conversion_;
}

// Option to set finite math.
void set_enable_fast_math(bool enable_fast_math) {
enable_fast_math_ = enable_fast_math;
}
bool enable_fast_math() const { return enable_fast_math_; }

private:
// Metadata struct can be used to store any metadata information encapsulated
// with the AlgebraicSimplifierOptions that can be later used in an
Expand Down Expand Up @@ -347,6 +353,7 @@ class AlgebraicSimplifierOptions {
double raise_slice_and_reduce_through_dot_threshold_{2.0};
bool use_convert_constant_folding_{false};
bool disable_dynamic_slice_to_slice_conversion_{false};
bool enable_fast_math_{false};
Metadata metadata_;
};

Expand Down
35 changes: 35 additions & 0 deletions xla/service/algebraic_simplifier_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1418,6 +1418,41 @@ TEST_F(AlgebraicSimplifierTest, AddReassociateMergeBroadcastedConstants) {
m::ConstantScalar(2.0))))));
}

// With fast math enabled, a scalar f32 `p0 - p0` is expected to be rewritten
// into the scalar constant 0.
TEST_F(AlgebraicSimplifierTest, ReplaceSubtractOfEqualOperandsWithZero) {
  const char* kHloText = R"(
HloModule m
test {
p0 = f32[] parameter(0)
ROOT sub = f32[] subtract(p0, p0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));

  // The float rewrite is gated on the fast-math option, so turn it on.
  AlgebraicSimplifierOptions fast_math_options;
  fast_math_options.set_enable_fast_math(true);
  AlgebraicSimplifier simplifier(fast_math_options);
  ASSERT_TRUE(simplifier.Run(module.get()).value());

  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::ConstantScalar(0.0)));
}

// With fast math enabled, a non-scalar f32 `p0 - p0` is expected to be
// rewritten into a broadcast of the scalar constant 0.
TEST_F(AlgebraicSimplifierTest,
       ReplaceSubtractOfEqualOperandsWithBroadcastZero) {
  const char* kModuleStr = R"(
HloModule m
test {
p0 = f32[512,20] parameter(0)
ROOT sub = f32[512,20] subtract(p0, p0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  AlgebraicSimplifierOptions options;
  options.set_enable_fast_math(true);
  AlgebraicSimplifier simplifier(options);
  ASSERT_TRUE(simplifier.Run(m.get()).value());
  // Check the broadcast operand too: matching a bare m::Broadcast() would also
  // accept a broadcast of any non-zero value and so would not verify the
  // "A - A => 0" rewrite.
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::ConstantScalar(0.0))));
}

TEST_F(AlgebraicSimplifierTest, SubAddReassociateMergeConstants) {
const char* kModuleStr = R"(
HloModule m
Expand Down
5 changes: 4 additions & 1 deletion xla/xla.proto
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ message DebugOptions {
// When true, XLA:CPU uses the thunk runtime to execute compiled program.
bool xla_cpu_use_thunk_runtime = 298;

// When true, enables optimizations that ignore the possibility of NaN.
bool xla_enable_fast_math = 335;

// The number of parts to split the LLVM module into before codegen. This
// allows XLA to compile all parts in parallel, and resolve kernel symbols
// from different dynamic libraries.
Expand Down Expand Up @@ -1002,7 +1005,7 @@ message DebugOptions {
// loop by a factor of two if a collective op is present.
bool xla_gpu_enable_heuristic_pass_configuration = 332;

// Next id: 335
// Next id: 336

// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.
Expand Down

0 comments on commit 5832697

Please sign in to comment.