diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index ff6a4f50a..59201a5ae 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -5911,6 +5911,64 @@ struct CompareOpCanon final : OpRewritePattern { } }; +struct SelectOpUsedWithinIf final + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::SelectOp op, + PatternRewriter &rewriter) const override { + Value pred = op.getPred(); + Value result = op.getResult(); + + if (pred.getType().cast().getShape().size() != 0) + return failure(); + + auto block = op->getBlock(); + + bool anyModified = false; + + rewriter.replaceUsesWithIf(result, op.getOnTrue(), [&](auto &use) { + Operation *user = use.getOwner(); + if (user->getBlock() == block) + return false; + + Operation *p = user->getParentOp(); + while (p && p != op) { + if (auto ifOp = dyn_cast(p)) { + if (ifOp.getPred() == pred && + ifOp.getTrueBranch().isAncestor(user->getParentRegion())) { + anyModified = true; + return true; + } + } + p = p->getParentOp(); + } + return false; + }); + + rewriter.replaceUsesWithIf(result, op.getOnFalse(), [&](auto &use) { + Operation *user = use.getOwner(); + if (user->getBlock() == block) + return false; + + Operation *p = user->getParentOp(); + while (p && p != op) { + if (auto ifOp = dyn_cast(p)) { + if (ifOp.getPred() == pred && + ifOp.getFalseBranch().isAncestor(user->getParentRegion())) { + anyModified = true; + return true; + } + } + p = p->getParentOp(); + } + return false; + }); + + return success(anyModified); + } +}; + struct SelectOpCanon final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; size_t max_constant_expansion; @@ -7068,17 +7126,18 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase { patterns.add(context); } - patterns - .add(context); + patterns.add( + context); patterns.add(max_constant_expansion, context, PatternBenefit(65000)); patterns.add(max_constant_expansion, context, diff --git a/src/enzyme_ad/jax/TransformOps/TransformOps.td b/src/enzyme_ad/jax/TransformOps/TransformOps.td index c45ecbd82..4f0dcfce1 100644 --- a/src/enzyme_ad/jax/TransformOps/TransformOps.td +++ b/src/enzyme_ad/jax/TransformOps/TransformOps.td @@ -631,6 +631,11 @@ def IfToSelect : EnzymeHLOPatternOp< let patterns = ["IfToSelect"]; } +def SelectOpUsedWithinIf : EnzymeHLOPatternOp< + "select_op_used_within_if"> { + let patterns = ["SelectOpUsedWithinIf"]; +} + def ZeroExtentTensorCanonPatterns : EnzymeHLOPatternOp< "zero_extent_tensor_canon"> { let patterns = ["ZeroExtentTensorCanon"]; diff --git a/src/enzyme_ad/jax/primitives.py b/src/enzyme_ad/jax/primitives.py index 05e5f3d46..f8ca85332 100644 --- a/src/enzyme_ad/jax/primitives.py +++ b/src/enzyme_ad/jax/primitives.py @@ -181,6 +181,7 @@ def hlo_opts(): reduce_concat<1>; slice_concat<1>; concat_slice<1>; +select_op_used_within_if<1>; replace_neg_add_with_subtract<16>; bin_broadcast_splat_add<1>; diff --git a/test/lit_tests/diffrules/stablehlo/if_remove.mlir b/test/lit_tests/diffrules/stablehlo/if_remove.mlir index d624e5f1a..3ddaa3a02 100644 --- a/test/lit_tests/diffrules/stablehlo/if_remove.mlir +++ b/test/lit_tests/diffrules/stablehlo/if_remove.mlir @@ -18,13 +18,12 @@ module { // REVERSE: func.func @main(%arg0: tensor<10xf32>, %arg1: tensor, %arg2: tensor<10xf32>) -> tensor<10xf32> { // REVERSE-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<10xf32> -// REVERSE-NEXT: %0 = stablehlo.select %arg1, %arg0, %cst : tensor, tensor<10xf32> -// REVERSE-NEXT: %1 = "stablehlo.if"(%arg1) ({ -// REVERSE-NEXT: %2 = stablehlo.multiply %arg2, %0 : tensor<10xf32> -// REVERSE-NEXT: %3 = stablehlo.add %2, %2 : tensor<10xf32> -// REVERSE-NEXT: stablehlo.return %3 : tensor<10xf32> +// REVERSE-NEXT: %0 = "stablehlo.if"(%arg1) ({ +// REVERSE-NEXT: %1 = stablehlo.multiply %arg2, %arg0 : tensor<10xf32> +// REVERSE-NEXT: %2 = stablehlo.add %1, %1 : tensor<10xf32> +// REVERSE-NEXT: stablehlo.return %2 : tensor<10xf32> // REVERSE-NEXT: }, { // REVERSE-NEXT: stablehlo.return %cst : tensor<10xf32> // REVERSE-NEXT: }) : (tensor) -> tensor<10xf32> -// REVERSE-NEXT: return %1 : tensor<10xf32> +// REVERSE-NEXT: return %0 : tensor<10xf32> // REVERSE-NEXT: } diff --git a/test/lit_tests/select_if.mlir b/test/lit_tests/select_if.mlir new file mode 100644 index 000000000..c203be8ff --- /dev/null +++ b/test/lit_tests/select_if.mlir @@ -0,0 +1,26 @@ +// RUN: enzymexlamlir-opt %s --enzyme-hlo-opt | FileCheck %s + +module { + func.func @main(%pred: tensor, %a: tensor, %b: tensor) -> tensor { + %0 = stablehlo.select %pred, %a, %b : tensor, tensor + %1 = "stablehlo.if"(%pred) ({ + %2 = stablehlo.add %0, %0 : tensor + "stablehlo.return"(%2) : (tensor) -> () + }, { + %3 = stablehlo.add %0, %0 : tensor + "stablehlo.return"(%3) : (tensor) -> () + }) : (tensor) -> tensor + return %1 : tensor + } +} + +// CHECK: func.func @main(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +// CHECK-NEXT: %0 = "stablehlo.if"(%arg0) ({ +// CHECK-NEXT: %1 = stablehlo.add %arg1, %arg1 : tensor +// CHECK-NEXT: stablehlo.return %1 : tensor +// CHECK-NEXT: }, { +// CHECK-NEXT: %1 = stablehlo.add %arg2, %arg2 : tensor +// CHECK-NEXT: stablehlo.return %1 : tensor +// CHECK-NEXT: }) : (tensor) -> tensor +// CHECK-NEXT: return %0 : tensor +// CHECK-NEXT: }