Skip to content

Commit

Permalink
Optimize away dynamic slice
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Mar 10, 2024
1 parent 9d7a82b commit 03f8fcf
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 10 deletions.
48 changes: 38 additions & 10 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,33 @@ struct DynamicSliceToStatic final
}
};

struct DynamicUpdateSliceElim final
: OpRewritePattern<mlir::stablehlo::DynamicUpdateSliceOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::stablehlo::DynamicUpdateSliceOp op,
PatternRewriter &rewriter) const override {
auto type = dyn_cast<RankedTensorType>(op.getType());
if (!type)
return failure();

if (op.getOperand().getType() != type)
return failure();

for (auto start : op.getStartIndices()) {
DenseIntElementsAttr startattr;
if (!matchPattern(start, m_Constant(&startattr))) {
return failure();
}
int64_t startv = (*startattr.begin()).getSExtValue();
if (startv != 0)
return failure();
}
rewriter.replaceOp(op, op.getUpdate());
return success();
}
};

// slice(pad x) -> pad(slice x)
struct SlicePad final : OpRewritePattern<mlir::stablehlo::SliceOp> {
using OpRewritePattern::OpRewritePattern;
Expand Down Expand Up @@ -1332,16 +1359,17 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
void runOnOperation() override {
auto context = getOperation()->getContext();
RewritePatternSet patterns(context);
patterns.add<DynamicSliceToStatic, SlicePad, SliceSlice, AddPad,
PadSimplify, DotReshapeDot, ConcatConstProp, ConcatFuse,
/*ScatterToPad, */ BroadcastToReshape, ReduceToReshape,
ReduceConcat, SliceConcat, SliceSimplification, CosSimplify,
SinSimplify, SqrtSimplify, AddSimplify, SubSimplify,
AndSimplify, OrSimplify, NegateSimplify, MulSimplify,
DivSimplify, PowSimplify, BinBroadcastSplat<stablehlo::AddOp>,
BinBroadcastSplat<stablehlo::SubtractOp>,
BinBroadcastSplat<stablehlo::DivOp>,
BinBroadcastSplat<stablehlo::MulOp>>(context);
patterns
.add<DynamicSliceToStatic, DynamicUpdateSliceElim, SlicePad, SliceSlice,
AddPad, PadSimplify, DotReshapeDot, ConcatConstProp, ConcatFuse,
/*ScatterToPad, */ BroadcastToReshape, ReduceToReshape,
ReduceConcat, SliceConcat, SliceSimplification, CosSimplify,
SinSimplify, SqrtSimplify, AddSimplify, SubSimplify, AndSimplify,
OrSimplify, NegateSimplify, MulSimplify, DivSimplify, PowSimplify,
BinBroadcastSplat<stablehlo::AddOp>,
BinBroadcastSplat<stablehlo::SubtractOp>,
BinBroadcastSplat<stablehlo::DivOp>,
BinBroadcastSplat<stablehlo::MulOp>>(context);
mlir::stablehlo::populateStablehloCanonicalizationPatterns(context,
&patterns);

Expand Down
15 changes: 15 additions & 0 deletions test/lit_tests/dynamicupdateelim.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s

module {

func.func @main(%a : tensor<4x5xf32>, %b : tensor<4x5xf32>) -> tensor<4x5xf32> {

%c0 = stablehlo.constant dense<0> : tensor<i32>
%r = stablehlo.dynamic_update_slice %a, %b, %c0, %c0 : (tensor<4x5xf32>, tensor<4x5xf32>, tensor<i32>, tensor<i32>) -> tensor<4x5xf32>
return %r : tensor<4x5xf32>
}
}

// CHECK: func.func @main(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> {
// CHECK-NEXT: return %arg1 : tensor<4x5xf32>
// CHECK-NEXT: }

0 comments on commit 03f8fcf

Please sign in to comment.