Skip to content

Commit

Permalink
Add dynamic update to concat
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Mar 10, 2024
1 parent 91e6faa commit df02ac1
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 11 deletions.
97 changes: 86 additions & 11 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,80 @@ struct DynamicUpdateSliceElim final
}
};

struct DynamicUpdateToConcat final
: OpRewritePattern<mlir::stablehlo::DynamicUpdateSliceOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::stablehlo::DynamicUpdateSliceOp op,
PatternRewriter &rewriter) const override {
auto type = dyn_cast<RankedTensorType>(op.getType());
if (!type)
return failure();

SmallVector<size_t> mismatches;
size_t idx = 0;
for (auto &&[start, update_size, res_size] :
llvm::zip(op.getStartIndices(), op.getUpdate().getType().getShape(),
op.getType().getShape())) {
DenseIntElementsAttr startattr;
if (!matchPattern(start, m_Constant(&startattr))) {
return failure();
}
int64_t startv = (*startattr.begin()).getSExtValue();
if (startv < 0)
return failure();

if (startv + update_size > res_size)
return failure();

if (startv == 0 && update_size == res_size) {
idx++;
continue;
}
mismatches.push_back(idx);
idx++;
}

if (mismatches.size() != 1)
return failure();
auto dim = mismatches[0];

DenseIntElementsAttr startattr;
if (!matchPattern(op.getStartIndices()[0], m_Constant(&startattr))) {
return failure();
}
int64_t startv = (*startattr.begin()).getSExtValue();

SmallVector<Value> toConcat;

if (startv != 0) {
SmallVector<int64_t> starts(op.getType().getShape().size(), 0);
SmallVector<int64_t> ends(op.getType().getShape().begin(),
op.getType().getShape().end());
SmallVector<int64_t> steps(op.getType().getShape().size(), 1);
ends[dim] = startv;
toConcat.push_back(rewriter.create<stablehlo::SliceOp>(
op.getLoc(), op.getOperand(), starts, ends, steps));
}
toConcat.push_back(op.getUpdate());
auto update_size = op.getUpdate().getType().getShape()[dim];
auto res_size = op.getType().getShape()[dim];
if (startv + update_size != res_size) {
SmallVector<int64_t> starts(op.getType().getShape().size(), 0);
SmallVector<int64_t> ends(op.getType().getShape().begin(),
op.getType().getShape().end());
SmallVector<int64_t> steps(op.getType().getShape().size(), 1);
starts[dim] = startv + update_size;
toConcat.push_back(rewriter.create<stablehlo::SliceOp>(
op.getLoc(), op.getOperand(), starts, ends, steps));
}

rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>(op, op.getType(),
toConcat, dim);
return success();
}
};

struct SliceOfDynamicUpdate final : OpRewritePattern<mlir::stablehlo::SliceOp> {
using OpRewritePattern::OpRewritePattern;

Expand Down Expand Up @@ -1475,17 +1549,18 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
void runOnOperation() override {
auto context = getOperation()->getContext();
RewritePatternSet patterns(context);
patterns.add<DynamicSliceToStatic, DynamicUpdateSliceElim,
SliceOfDynamicUpdate, SlicePad, SliceSlice, AddPad,
PadSimplify, DotReshapeDot, ConcatConstProp, ConcatFuse,
/*ScatterToPad, */ BroadcastToReshape, ReduceToReshape,
ReduceConcat, SliceConcat, SliceSimplification, CosSimplify,
SinSimplify, SqrtSimplify, AddSimplify, SubSimplify,
AndSimplify, OrSimplify, NegateSimplify, MulSimplify,
DivSimplify, PowSimplify, BinBroadcastSplat<stablehlo::AddOp>,
BinBroadcastSplat<stablehlo::SubtractOp>,
BinBroadcastSplat<stablehlo::DivOp>,
BinBroadcastSplat<stablehlo::MulOp>>(context);
patterns
.add<DynamicSliceToStatic, DynamicUpdateSliceElim,
DynamicUpdateToConcat, SliceOfDynamicUpdate, SlicePad, SliceSlice,
AddPad, PadSimplify, DotReshapeDot, ConcatConstProp, ConcatFuse,
/*ScatterToPad, */ BroadcastToReshape, ReduceToReshape,
ReduceConcat, SliceConcat, SliceSimplification, CosSimplify,
SinSimplify, SqrtSimplify, AddSimplify, SubSimplify, AndSimplify,
OrSimplify, NegateSimplify, MulSimplify, DivSimplify, PowSimplify,
BinBroadcastSplat<stablehlo::AddOp>,
BinBroadcastSplat<stablehlo::SubtractOp>,
BinBroadcastSplat<stablehlo::DivOp>,
BinBroadcastSplat<stablehlo::MulOp>>(context);
mlir::stablehlo::populateStablehloCanonicalizationPatterns(context,
&patterns);

Expand Down
17 changes: 17 additions & 0 deletions test/lit_tests/dynamicupdatetoconcat.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s

module {
func.func @main(%a : tensor<4x5xf32>, %b : tensor<2x5xf32>) -> tensor<4x5xf32> {
%c1 = stablehlo.constant dense<1> : tensor<i32>
%c0 = stablehlo.constant dense<0> : tensor<i32>
%r = stablehlo.dynamic_update_slice %a, %b, %c1, %c0 : (tensor<4x5xf32>, tensor<2x5xf32>, tensor<i32>, tensor<i32>) -> tensor<4x5xf32>
return %r : tensor<4x5xf32>
}
}

// CHECK: func.func @main(%arg0: tensor<4x5xf32>, %arg1: tensor<2x5xf32>) -> tensor<4x5xf32> {
// CHECK-NEXT: %0 = stablehlo.slice %arg0 [0:1, 0:5] : (tensor<4x5xf32>) -> tensor<1x5xf32>
// CHECK-NEXT: %1 = stablehlo.slice %arg0 [3:4, 0:5] : (tensor<4x5xf32>) -> tensor<1x5xf32>
// CHECK-NEXT: %2 = stablehlo.concatenate %0, %arg1, %1, dim = 0 : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32>
// CHECK-NEXT: return %2 : tensor<4x5xf32>
// CHECK-NEXT: }

0 comments on commit df02ac1

Please sign in to comment.