diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 65b5d25cf..b12059502 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -131,6 +131,33 @@ struct DynamicSliceToStatic final } }; +struct DynamicUpdateSliceElim final + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::DynamicUpdateSliceOp op, + PatternRewriter &rewriter) const override { + auto type = dyn_cast(op.getType()); + if (!type) + return failure(); + + if (op.getOperand().getType() != type) + return failure(); + + for (auto start : op.getStartIndices()) { + DenseIntElementsAttr startattr; + if (!matchPattern(start, m_Constant(&startattr))) { + return failure(); + } + int64_t startv = (*startattr.begin()).getSExtValue(); + if (startv != 0) + return failure(); + } + rewriter.replaceOp(op, op.getUpdate()); + return success(); + } +}; + // slice(pad x) -> pad(slice x) struct SlicePad final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -1332,16 +1359,17 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase { void runOnOperation() override { auto context = getOperation()->getContext(); RewritePatternSet patterns(context); - patterns.add, - BinBroadcastSplat, - BinBroadcastSplat, - BinBroadcastSplat>(context); + patterns + .add, + BinBroadcastSplat, + BinBroadcastSplat, + BinBroadcastSplat>(context); mlir::stablehlo::populateStablehloCanonicalizationPatterns(context, &patterns); diff --git a/test/lit_tests/dynamicupdateelim.mlir b/test/lit_tests/dynamicupdateelim.mlir new file mode 100644 index 000000000..36fd65519 --- /dev/null +++ b/test/lit_tests/dynamicupdateelim.mlir @@ -0,0 +1,15 @@ +// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s + +module { + + func.func @main(%a : tensor<4x5xf32>, %b : tensor<4x5xf32>) -> tensor<4x5xf32> { + + %c0 = stablehlo.constant dense<0> : tensor + %r = stablehlo.dynamic_update_slice %a, %b, %c0, %c0 : (tensor<4x5xf32>, tensor<4x5xf32>, tensor, tensor) -> tensor<4x5xf32> + return %r : tensor<4x5xf32> + } +} + +// CHECK: func.func @main(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> { +// CHECK-NEXT: return %arg1 : tensor<4x5xf32> +// CHECK-NEXT: }