Skip to content

Commit

Permalink
const prop for log and logp1 (#218)
Browse files Browse the repository at this point in the history
* const prop for log and logp1

* add lit test

* crashing splat

* Fixed crashing splat
  • Loading branch information
vimarsh6739 authored Jan 9, 2025
1 parent bbb5a2e commit 1b6c935
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 2 deletions.
58 changes: 56 additions & 2 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2209,6 +2209,59 @@ struct DynamicUpdateSliceConstProp final
}
};

template <auto f>
LogicalResult unaryConstProp(Operation *op, PatternRewriter &rewriter) {
// return if not constant
DenseElementsAttr inputAttr;
if (!matchPattern(op->getOperand(0), m_Constant(&inputAttr)))
return failure();

stablehlo::Tensor inputTen;
RankedTensorType ty = cast<RankedTensorType>(op->getResultTypes()[0]);

if (inputAttr.isSplat()) {

ty = RankedTensorType::get(
{}, cast<ShapedType>(op->getResultTypes()[0]).getElementType());
inputTen = stablehlo::makeTensor(inputAttr.resizeSplat(ty));
} else {
inputTen = mlir::stablehlo::constantOp(inputAttr);
}
// get the resultType
auto resultType = ty.cast<ShapedType>();

// Convert constant to tensor, compute log, then convert back to attribute
auto out = fromTensor(f(inputTen, resultType));

if (inputAttr.isSplat()) {
out = out.resizeSplat(cast<ShapedType>(op->getResultTypes()[0]));
}
// Replace with new constant op containing the computed result
auto tmp = rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
op, op->getResultTypes()[0], out);

return success();
}

struct LogConstProp final : OpRewritePattern<mlir::stablehlo::LogOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::stablehlo::LogOp op,
PatternRewriter &rewriter) const override {
return unaryConstProp<mlir::stablehlo::logOp>(op, rewriter);
}
};

struct LogPlusConstProp final : OpRewritePattern<mlir::stablehlo::Log1pOp> {

using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::stablehlo::Log1pOp op,
PatternRewriter &rewriter) const override {
return unaryConstProp<stablehlo::log1pOp>(op, rewriter);
}
};

struct ConcatConstProp final
: OpRewritePattern<mlir::stablehlo::ConcatenateOp> {
using OpRewritePattern::OpRewritePattern;
Expand Down Expand Up @@ -7021,8 +7074,9 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
patterns.add<ConvertConcat, DynamicUpdateToConcat, SliceOfDynamicUpdate,
SliceElementwise, SliceReshapeElementwise, SlicePad,
SliceReshapePad, DotReshapeDot, ConcatConstProp,
DynamicUpdateSliceConstProp, ConcatFuse, ConcatToBroadcast,
PadPad, PadReshapePad, ConcatPushBinop<stablehlo::AddOp>,
DynamicUpdateSliceConstProp, LogConstProp, LogPlusConstProp,
ConcatFuse, ConcatToBroadcast, PadPad, PadReshapePad,
ConcatPushBinop<stablehlo::AddOp>,
ConcatPushBinop<stablehlo::MulOp>, ScatterToDynamicUpdateSlice,
ReduceConcat, ConcatSlice, SliceConcat, SliceReshapeConcat,
BinBroadcastSplat<stablehlo::AddOp>,
Expand Down
9 changes: 9 additions & 0 deletions src/enzyme_ad/jax/TransformOps/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,15 @@ def ApplyGatherSimplifyPatterns : EnzymeHLOPatternOp<
"gather_simplify"> {
let patterns = ["GatherSimplify"];
}
def ApplyLogConstProp : EnzymeHLOPatternOp<
"log_const_prop">{
let patterns = ["LogConstProp"];
}

def ApplyLogPlusConstProp : EnzymeHLOPatternOp<
"log_plus_one_const_prop">{
let patterns = ["LogPlusConstProp"];
}

// regular benefit
def ApplyConvertConcatPatterns : EnzymeHLOPatternOp<
Expand Down
2 changes: 2 additions & 0 deletions src/enzyme_ad/jax/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ def hlo_opts():
dot_reshape_dot<1>;
concat_const_prop<1>;
dynamic_update_slice_const_prop<1>;
log_const_prop<1>;
log_plus_one_const_prop<1>;
concat_fuse<1>;
pad_reshape_pad<1>;
pad_pad<1>;
Expand Down
49 changes: 49 additions & 0 deletions test/lit_tests/log_const_prop.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// RUN: enzymexlamlir-opt %s --pass-pipeline="builtin.module(enzyme-hlo-opt)" | FileCheck %s

module {
func.func @log_f32() -> tensor<2xf32> {
%arg = stablehlo.constant dense<[1.000000e+00,2.000000e+00]> : tensor<2xf32>
%result = stablehlo.log %arg : tensor<2xf32>
func.return %result : tensor<2xf32>
}


func.func @log_f32_splat() -> tensor<2xf32> {
%arg = stablehlo.constant dense<2.000000e+00> : tensor<2xf32>
%result = stablehlo.log %arg : tensor<2xf32>
func.return %result : tensor<2xf32>
}

func.func @log_plus_one_op_test_f64() -> tensor<5xf64> {
%operand = stablehlo.constant dense<[0.0, -0.999, 7.0, 6.38905621, 15.0]> : tensor<5xf64>
%result = stablehlo.log_plus_one %operand : tensor<5xf64>
func.return %result : tensor<5xf64>
}

func.func @log_p1_f32_splat() -> tensor<2xf32> {
%arg = stablehlo.constant dense<1.000000e+00> : tensor<2xf32>
%result = stablehlo.log_plus_one %arg : tensor<2xf32>
func.return %result : tensor<2xf32>
}
}


// CHECK: func.func @log_f32() -> tensor<2xf32> {
// CHECK-NEXT: %cst = stablehlo.constant dense<[0.000000e+00, 0.693147182]> : tensor<2xf32>
// CHECK-NEXT: return %cst : tensor<2xf32>
// CHECK-NEXT: }

// CHECK: func.func @log_f32_splat() -> tensor<2xf32> {
// CHECK-NEXT: %cst = stablehlo.constant dense<0.693147182> : tensor<2xf32>
// CHECK-NEXT: return %cst : tensor<2xf32>
// CHECK-NEXT: }

// CHECK: func.func @log_plus_one_op_test_f64() -> tensor<5xf64> {
// CHECK-NEXT: %cst = stablehlo.constant dense<[0.000000e+00, -6.9077552789821359, 2.0794415416798357, 2.0000000150316017, 2.7725887222397811]> : tensor<5xf64>
// CHECK-NEXT: return %cst : tensor<5xf64>
// CHECK-NEXT: }

// CHECK: func.func @log_p1_f32_splat() -> tensor<2xf32> {
// CHECK-NEXT: %cst = stablehlo.constant dense<0.693147182> : tensor<2xf32>
// CHECK-NEXT: return %cst : tensor<2xf32>
// CHECK-NEXT: }

0 comments on commit 1b6c935

Please sign in to comment.