From df02ac1aad2b3eb22fd2c8d74b001bd532be9803 Mon Sep 17 00:00:00 2001
From: "William S. Moses" <gh@wsmoses.com>
Date: Sun, 10 Mar 2024 16:00:32 -0400
Subject: [PATCH] Add dynamic update to concat

---
NOTE(review): this patch was recovered from an HTML-tag-stripped copy in
which every `<...>` span had been deleted (C++ template arguments, the
author's e-mail, and `tensor<iN>` element types in the test).  The
stripped spans below are reconstructions from the surrounding code and
the StableHLO API — confirm against the upstream commit.  In particular
the pattern list in `runOnOperation` lost several class names that could
not be recovered; only the names visible elsewhere in the patch are
listed, and the hunk headers/diffstat were recomputed to match.  One
genuine bug was also fixed: the start index of the mismatched dimension
was re-read with a hard-coded `op.getStartIndices()[0]`; it must be
`[dim]` or the rewrite miscompiles updates whose mismatched dimension is
not the leading one (the lit test only exercises dim 0).

 src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 93 +++++++++++++++++++---
 test/lit_tests/dynamicupdatetoconcat.mlir | 17 ++++
 2 files changed, 105 insertions(+), 5 deletions(-)
 create mode 100644 test/lit_tests/dynamicupdatetoconcat.mlir

diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
index 5304e81de..1a1b9a034 100644
--- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
+++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
@@ -158,6 +158,88 @@ struct DynamicUpdateSliceElim final
   }
 };
 
+// Rewrite a stablehlo.dynamic_update_slice whose start indices are all
+// constants and whose update spans the full result extent in every
+// dimension except one into slice(s) of the original operand concatenated
+// with the update along that single mismatched dimension.
+struct DynamicUpdateToConcat final
+    : OpRewritePattern<mlir::stablehlo::DynamicUpdateSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::stablehlo::DynamicUpdateSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    auto type = dyn_cast<RankedTensorType>(op.getType());
+    if (!type)
+      return failure();
+
+    // Dimensions where the update does not cover the whole result.
+    SmallVector<size_t> mismatches;
+    size_t idx = 0;
+    for (auto &&[start, update_size, res_size] :
+         llvm::zip(op.getStartIndices(), op.getUpdate().getType().getShape(),
+                   op.getType().getShape())) {
+      DenseIntElementsAttr startattr;
+      if (!matchPattern(start, m_Constant(&startattr))) {
+        return failure();
+      }
+      int64_t startv = (*startattr.begin()).getSExtValue();
+      if (startv < 0)
+        return failure();
+
+      if (startv + update_size > res_size)
+        return failure();
+
+      if (startv == 0 && update_size == res_size) {
+        idx++;
+        continue;
+      }
+      mismatches.push_back(idx);
+      idx++;
+    }
+
+    if (mismatches.size() != 1)
+      return failure();
+    auto dim = mismatches[0];
+
+    // Re-read the constant start index of the mismatched dimension.
+    // (Must index with `dim`, not 0: the mismatch need not be dim 0.)
+    DenseIntElementsAttr startattr;
+    if (!matchPattern(op.getStartIndices()[dim], m_Constant(&startattr))) {
+      return failure();
+    }
+    int64_t startv = (*startattr.begin()).getSExtValue();
+
+    SmallVector<Value> toConcat;
+
+    // Prefix [0, startv) of the operand along dim, if non-empty.
+    if (startv != 0) {
+      SmallVector<int64_t> starts(op.getType().getShape().size(), 0);
+      SmallVector<int64_t> ends(op.getType().getShape().begin(),
+                                op.getType().getShape().end());
+      SmallVector<int64_t> steps(op.getType().getShape().size(), 1);
+      ends[dim] = startv;
+      toConcat.push_back(rewriter.create<mlir::stablehlo::SliceOp>(
+          op.getLoc(), op.getOperand(), starts, ends, steps));
+    }
+    toConcat.push_back(op.getUpdate());
+    auto update_size = op.getUpdate().getType().getShape()[dim];
+    auto res_size = op.getType().getShape()[dim];
+    // Suffix [startv + update_size, res_size) of the operand, if non-empty.
+    if (startv + update_size != res_size) {
+      SmallVector<int64_t> starts(op.getType().getShape().size(), 0);
+      SmallVector<int64_t> ends(op.getType().getShape().begin(),
+                                op.getType().getShape().end());
+      SmallVector<int64_t> steps(op.getType().getShape().size(), 1);
+      starts[dim] = startv + update_size;
+      toConcat.push_back(rewriter.create<mlir::stablehlo::SliceOp>(
+          op.getLoc(), op.getOperand(), starts, ends, steps));
+    }
+
+    rewriter.replaceOpWithNewOp<mlir::stablehlo::ConcatenateOp>(
+        op, op.getType(), toConcat, dim);
+    return success();
+  }
+};
+
 struct SliceOfDynamicUpdate final
     : OpRewritePattern<mlir::stablehlo::SliceOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -1475,10 +1557,11 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
   void runOnOperation() override {
     auto context = getOperation()->getContext();
     RewritePatternSet patterns(context);
-    patterns.add<DynamicUpdateSliceElim, SliceOfDynamicUpdate,
-                 BinBroadcastSplat<stablehlo::AddOp>,
-                 BinBroadcastSplat<stablehlo::SubtractOp>,
-                 BinBroadcastSplat<stablehlo::DivOp>,
-                 BinBroadcastSplat<stablehlo::MulOp>>(context);
+    patterns
+        .add<DynamicUpdateSliceElim, DynamicUpdateToConcat,
+             SliceOfDynamicUpdate, BinBroadcastSplat<stablehlo::AddOp>,
+             BinBroadcastSplat<stablehlo::SubtractOp>,
+             BinBroadcastSplat<stablehlo::DivOp>,
+             BinBroadcastSplat<stablehlo::MulOp>>(context);
     mlir::stablehlo::populateStablehloCanonicalizationPatterns(context,
                                                                &patterns);
diff --git a/test/lit_tests/dynamicupdatetoconcat.mlir b/test/lit_tests/dynamicupdatetoconcat.mlir
new file mode 100644
index 000000000..3f2781efe
--- /dev/null
+++ b/test/lit_tests/dynamicupdatetoconcat.mlir
@@ -0,0 +1,17 @@
+// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s
+
+module {
+  func.func @main(%a : tensor<4x5xf32>, %b : tensor<2x5xf32>) -> tensor<4x5xf32> {
+    %c1 = stablehlo.constant dense<1> : tensor<i32>
+    %c0 = stablehlo.constant dense<0> : tensor<i32>
+    %r = stablehlo.dynamic_update_slice %a, %b, %c1, %c0 : (tensor<4x5xf32>, tensor<2x5xf32>, tensor<i32>, tensor<i32>) -> tensor<4x5xf32>
+    return %r : tensor<4x5xf32>
+  }
+}
+
+// CHECK: func.func @main(%arg0: tensor<4x5xf32>, %arg1: tensor<2x5xf32>) -> tensor<4x5xf32> {
+// CHECK-NEXT: %0 = stablehlo.slice %arg0 [0:1, 0:5] : (tensor<4x5xf32>) -> tensor<1x5xf32>
+// CHECK-NEXT: %1 = stablehlo.slice %arg0 [3:4, 0:5] : (tensor<4x5xf32>) -> tensor<1x5xf32>
+// CHECK-NEXT: %2 = stablehlo.concatenate %0, %arg1, %1, dim = 0 : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32>
+// CHECK-NEXT: return %2 : tensor<4x5xf32>
+// CHECK-NEXT: }