Skip to content

Commit

Permalink
simplify pad(pad) (#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
ftynse authored Mar 23, 2024
1 parent 64ed444 commit dfce78e
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/enzyme_ad/jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ cc_library(
"@stablehlo//:reference_ops",
"@stablehlo//:stablehlo_ops",
"@stablehlo//:stablehlo_passes",
"@stablehlo//:stablehlo_type_inference",
"@xla//xla/mlir_hlo",
],
)
Expand Down
41 changes: 40 additions & 1 deletion src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3251,6 +3251,45 @@ struct DivZeroPad : public OpRewritePattern<mlir::stablehlo::DivOp> {
}
};

struct PadPad : public OpRewritePattern<mlir::stablehlo::PadOp> {
  using OpRewritePattern<mlir::stablehlo::PadOp>::OpRewritePattern;

  /// Folds pad(pad(x)) into a single pad when both ops use the same padding
  /// value and neither applies interior padding. Low/high edge amounts are
  /// summed element-wise; this is exact even for negative (trimming) amounts
  /// because the trimmed elements are the shared padding value.
  LogicalResult matchAndRewrite(mlir::stablehlo::PadOp op,
                                PatternRewriter &rewriter) const final {
    // The operand must itself come from a pad with an identical padding value.
    auto producer = op.getOperand().getDefiningOp<stablehlo::PadOp>();
    if (!producer || producer.getPaddingValue() != op.getPaddingValue()) {
      return rewriter.notifyMatchFailure(op, "no compatible defining pad");
    }

    // Interior padding does not compose additively, so bail out unless both
    // ops have it fully disabled (all zeros).
    for (ArrayRef<int64_t> interior :
         {op.getInteriorPadding(), producer.getInteriorPadding()}) {
      if (llvm::any_of(interior, [](int64_t v) { return v != 0; }))
        return rewriter.notifyMatchFailure(op,
                                           "cannot combine interior padding");
    }

    // Element-wise sum of two i64 array attributes of equal length.
    auto sumAttrs = [](DenseI64ArrayAttr a, DenseI64ArrayAttr b) {
      SmallVector<int64_t> combined;
      combined.reserve(a.asArrayRef().size());
      for (auto [x, y] : llvm::zip(a.asArrayRef(), b.asArrayRef()))
        combined.push_back(x + y);
      return DenseI64ArrayAttr::get(a.getContext(), combined);
    };

    rewriter.replaceOpWithNewOp<stablehlo::PadOp>(
        op, producer.getOperand(), producer.getPaddingValue(),
        sumAttrs(op.getEdgePaddingLowAttr(), producer.getEdgePaddingLowAttr()),
        sumAttrs(op.getEdgePaddingHighAttr(),
                 producer.getEdgePaddingHighAttr()),
        sumAttrs(op.getInteriorPaddingAttr(),
                 producer.getInteriorPaddingAttr()));
    return success();
  }
};

struct PadDotGeneral : public OpRewritePattern<mlir::stablehlo::DotGeneralOp> {
using OpRewritePattern<mlir::stablehlo::DotGeneralOp>::OpRewritePattern;

Expand Down Expand Up @@ -3777,7 +3816,7 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
max_constant_expansion, context, PatternBenefit(65000));

patterns.add<ConvertConcat, DynamicUpdateToConcat, SliceOfDynamicUpdate,
SlicePad, DotReshapeDot, ConcatConstProp, ConcatFuse,
SlicePad, DotReshapeDot, ConcatConstProp, ConcatFuse, PadPad,
ConcatPushBinop<stablehlo::AddOp>,
ConcatPushBinop<stablehlo::MulOp>, ScatterToDynamicUpdateSlice,
ReduceConcat, SliceConcat, BinBroadcastSplat<stablehlo::AddOp>,
Expand Down
19 changes: 19 additions & 0 deletions test/lit_tests/padpad.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// RUN: enzymexlamlir-opt %s --enzyme-hlo-opt | FileCheck %s

// CHECK-LABEL: @pad_pad
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x1xf32>, %[[ARG1:.+]]: tensor<f32>
// CHECK: stablehlo.pad %[[ARG0]], %[[ARG1]], low = [6, 0, 1], high = [1, 0, 8], interior = [0, 0, 0]
// CHECK-NOT: pad
// Positive case: two stacked pads share the padding value %arg1 and have no
// interior padding, so the PadPad pattern must fuse them into a single pad
// whose low/high amounts are the element-wise sums ([2,0,0]+[4,0,1] and
// [0,0,3]+[1,0,5]).
func.func @pad_pad(%arg0: tensor<1x1x1xf32>, %arg1: tensor<f32>) -> tensor<8x1x10xf32> {
%0 = stablehlo.pad %arg0, %arg1, low = [2, 0, 0], high = [0, 0, 3], interior = [0, 0, 0] : (tensor<1x1x1xf32>, tensor<f32>) -> tensor<3x1x4xf32>
%1 = stablehlo.pad %0, %arg1, low = [4, 0, 1], high = [1, 0, 5], interior = [0, 0, 0] : (tensor<3x1x4xf32>, tensor<f32>) -> tensor<8x1x10xf32>
return %1 : tensor<8x1x10xf32>
}

// CHECK-LABEL: @pad_pad_interior2
// CHECK-COUNT-2: pad
// Negative case: each pad carries non-zero interior padding, which does not
// compose additively, so the pattern must NOT fire and both pads must remain
// (checked by CHECK-COUNT-2 above).
func.func @pad_pad_interior2(%arg0: tensor<1x1x1xf32>, %arg1: tensor<f32>) -> tensor<10x1x10xf32> {
%0 = stablehlo.pad %arg0, %arg1, low = [2, 0, 0], high = [0, 0, 3], interior = [0, 1, 0] : (tensor<1x1x1xf32>, tensor<f32>) -> tensor<3x1x4xf32>
%1 = stablehlo.pad %0, %arg1, low = [4, 0, 1], high = [1, 0, 5], interior = [1, 0, 0] : (tensor<3x1x4xf32>, tensor<f32>) -> tensor<10x1x10xf32>
return %1 : tensor<10x1x10xf32>
}

0 comments on commit dfce78e

Please sign in to comment.