Skip to content

Commit

Permalink
Pow to sqrt
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Mar 10, 2024
1 parent 650d115 commit 91e6faa
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
18 changes: 18 additions & 0 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1330,6 +1330,24 @@ struct PowSimplify : public OpRewritePattern<mlir::stablehlo::PowOp> {
return success();
}

// pow(X, 0.5) -> sqrt(X)
{
DenseFPElementsAttr rhs;
if (matchPattern(op.getRhs(), m_Constant(&rhs))) {
bool allHalf = true;
for (auto v : rhs) {
if (!v.isExactlyValue(0.5)) {
allHalf = false;
break;
}
}
if (allHalf) {
rewriter.replaceOpWithNewOp<stablehlo::SqrtOp>(op, op.getLhs());
return success();
}
}
}

return failure();
}
};
Expand Down
15 changes: 15 additions & 0 deletions test/lit_tests/powtosqrt.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s

module {

func.func @main(%a : tensor<2x2xf32>) -> tensor<2x2xf32> {
%half = stablehlo.constant dense<5.000000e-01> : tensor<2x2xf32>
%pd = stablehlo.power %a, %half : tensor<2x2xf32>
return %pd : tensor<2x2xf32>
}
}

// CHECK: func.func @main(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK-NEXT: %0 = stablehlo.sqrt %arg0 : tensor<2x2xf32>
// CHECK-NEXT: return %0 : tensor<2x2xf32>
// CHECK-NEXT: }

0 comments on commit 91e6faa

Please sign in to comment.