Skip to content

Commit

Permalink
Add opt to simplify select uses within if (#228)
Browse files Browse the repository at this point in the history
  • Loading branch information
Pangoraw authored Jan 9, 2025
1 parent 058fa1e commit bbb5a2e
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 17 deletions.
81 changes: 70 additions & 11 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5911,6 +5911,64 @@ struct CompareOpCanon final : OpRewritePattern<mlir::stablehlo::CompareOp> {
}
};

struct SelectOpUsedWithinIf final
: OpRewritePattern<mlir::stablehlo::SelectOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::stablehlo::SelectOp op,
PatternRewriter &rewriter) const override {
Value pred = op.getPred();
Value result = op.getResult();

if (pred.getType().cast<TensorType>().getShape().size() != 0)
return failure();

auto block = op->getBlock();

bool anyModified = false;

rewriter.replaceUsesWithIf(result, op.getOnTrue(), [&](auto &use) {
Operation *user = use.getOwner();
if (user->getBlock() == block)
return false;

Operation *p = user->getParentOp();
while (p && p != op) {
if (auto ifOp = dyn_cast<stablehlo::IfOp>(p)) {
if (ifOp.getPred() == pred &&
ifOp.getTrueBranch().isAncestor(user->getParentRegion())) {
anyModified = true;
return true;
}
}
p = p->getParentOp();
}
return false;
});

rewriter.replaceUsesWithIf(result, op.getOnFalse(), [&](auto &use) {
Operation *user = use.getOwner();
if (user->getBlock() == block)
return false;

Operation *p = user->getParentOp();
while (p && p != op) {
if (auto ifOp = dyn_cast<stablehlo::IfOp>(p)) {
if (ifOp.getPred() == pred &&
ifOp.getFalseBranch().isAncestor(user->getParentRegion())) {
anyModified = true;
return true;
}
}
p = p->getParentOp();
}
return false;
});

return success(anyModified);
}
};

struct SelectOpCanon final : OpRewritePattern<mlir::stablehlo::SelectOp> {
using OpRewritePattern::OpRewritePattern;
size_t max_constant_expansion;
Expand Down Expand Up @@ -7068,17 +7126,18 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
patterns.add<NoNan, NoNanSelfSubSimplify, NoNanAddSubSimplify>(context);
}

patterns
.add<CompareOpCanon, BroadcastInDimOpCanon, ConvertOpCanon,
DynamicBroadcastInDimOpNotActuallyDynamic,
ChainedDynamicBroadcastInDimCanonicalization,
DynamicBroadcastInDimAllDimsNonExpanding, NoopReduceOpCanon,
EmptyReduceOpCanon, DynamicReshapeOpCanon, GetTupleElementOpCanon,
RealOpCanon, ImagOpCanon, ConjComplexNegate,
GetDimensionSizeOpCanon, GatherOpCanon, ReshapeOpCanon,
MergeConsecutiveReshapes, TransposeIsReshape, IfInline, IfToSelect,
ZeroExtentTensorCanon, ReorderElementwiseAndShapeOp,
DynamicGatherOpIsNotDynamic, DivideSqrtToMultiplyRsqrt>(context);
patterns.add<CompareOpCanon, BroadcastInDimOpCanon, ConvertOpCanon,
DynamicBroadcastInDimOpNotActuallyDynamic,
ChainedDynamicBroadcastInDimCanonicalization,
DynamicBroadcastInDimAllDimsNonExpanding, NoopReduceOpCanon,
EmptyReduceOpCanon, DynamicReshapeOpCanon,
GetTupleElementOpCanon, RealOpCanon, ImagOpCanon,
ConjComplexNegate, GetDimensionSizeOpCanon, GatherOpCanon,
ReshapeOpCanon, MergeConsecutiveReshapes, TransposeIsReshape,
SelectOpUsedWithinIf, IfInline, IfToSelect,
ZeroExtentTensorCanon, ReorderElementwiseAndShapeOp,
DynamicGatherOpIsNotDynamic, DivideSqrtToMultiplyRsqrt>(
context);
patterns.add<SelectOpCanon>(max_constant_expansion, context,
PatternBenefit(65000));
patterns.add<ConcatenateOpCanon>(max_constant_expansion, context,
Expand Down
5 changes: 5 additions & 0 deletions src/enzyme_ad/jax/TransformOps/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,11 @@ def IfToSelect : EnzymeHLOPatternOp<
let patterns = ["IfToSelect"];
}

def SelectOpUsedWithinIf : EnzymeHLOPatternOp<
"select_op_used_within_if"> {
let patterns = ["SelectOpUsedWithinIf"];
}

def ZeroExtentTensorCanonPatterns : EnzymeHLOPatternOp<
"zero_extent_tensor_canon"> {
let patterns = ["ZeroExtentTensorCanon"];
Expand Down
1 change: 1 addition & 0 deletions src/enzyme_ad/jax/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def hlo_opts():
reduce_concat<1>;
slice_concat<1>;
concat_slice<1>;
select_op_used_within_if<1>;
replace_neg_add_with_subtract<16>;
bin_broadcast_splat_add<1>;
Expand Down
11 changes: 5 additions & 6 deletions test/lit_tests/diffrules/stablehlo/if_remove.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,12 @@ module {

// REVERSE: func.func @main(%arg0: tensor<10xf32>, %arg1: tensor<i1>, %arg2: tensor<10xf32>) -> tensor<10xf32> {
// REVERSE-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<10xf32>
// REVERSE-NEXT: %0 = stablehlo.select %arg1, %arg0, %cst : tensor<i1>, tensor<10xf32>
// REVERSE-NEXT: %1 = "stablehlo.if"(%arg1) ({
// REVERSE-NEXT: %2 = stablehlo.multiply %arg2, %0 : tensor<10xf32>
// REVERSE-NEXT: %3 = stablehlo.add %2, %2 : tensor<10xf32>
// REVERSE-NEXT: stablehlo.return %3 : tensor<10xf32>
// REVERSE-NEXT: %0 = "stablehlo.if"(%arg1) ({
// REVERSE-NEXT: %1 = stablehlo.multiply %arg2, %arg0 : tensor<10xf32>
// REVERSE-NEXT: %2 = stablehlo.add %1, %1 : tensor<10xf32>
// REVERSE-NEXT: stablehlo.return %2 : tensor<10xf32>
// REVERSE-NEXT: }, {
// REVERSE-NEXT: stablehlo.return %cst : tensor<10xf32>
// REVERSE-NEXT: }) : (tensor<i1>) -> tensor<10xf32>
// REVERSE-NEXT: return %1 : tensor<10xf32>
// REVERSE-NEXT: return %0 : tensor<10xf32>
// REVERSE-NEXT: }
26 changes: 26 additions & 0 deletions test/lit_tests/select_if.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// RUN: enzymexlamlir-opt %s --enzyme-hlo-opt | FileCheck %s

module {
func.func @main(%pred: tensor<i1>, %a: tensor<f32>, %b: tensor<f32>) -> tensor<f32> {
%0 = stablehlo.select %pred, %a, %b : tensor<i1>, tensor<f32>
%1 = "stablehlo.if"(%pred) ({
%2 = stablehlo.add %0, %0 : tensor<f32>
"stablehlo.return"(%2) : (tensor<f32>) -> ()
}, {
%3 = stablehlo.add %0, %0 : tensor<f32>
"stablehlo.return"(%3) : (tensor<f32>) -> ()
}) : (tensor<i1>) -> tensor<f32>
return %1 : tensor<f32>
}
}

// CHECK: func.func @main(%arg0: tensor<i1>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
// CHECK-NEXT: %0 = "stablehlo.if"(%arg0) ({
// CHECK-NEXT: %1 = stablehlo.add %arg1, %arg1 : tensor<f32>
// CHECK-NEXT: stablehlo.return %1 : tensor<f32>
// CHECK-NEXT: }, {
// CHECK-NEXT: %1 = stablehlo.add %arg2, %arg2 : tensor<f32>
// CHECK-NEXT: stablehlo.return %1 : tensor<f32>
// CHECK-NEXT: }) : (tensor<i1>) -> tensor<f32>
// CHECK-NEXT: return %0 : tensor<f32>
// CHECK-NEXT: }

0 comments on commit bbb5a2e

Please sign in to comment.