Skip to content

Commit

Permalink
Support nan removal
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Mar 10, 2024
1 parent df02ac1 commit 5c65020
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 0 deletions.
38 changes: 38 additions & 0 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1544,6 +1544,40 @@ template <typename T> struct BinBroadcastSplat final : OpRewritePattern<T> {
}
};

struct AllFinite : public OpRewritePattern<mlir::stablehlo::IsFiniteOp> {
using OpRewritePattern<mlir::stablehlo::IsFiniteOp>::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::stablehlo::IsFiniteOp op,
PatternRewriter &rewriter) const final {
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
op, op.getType(), makeAttr(op.getType(), 1).cast<ElementsAttr>());
return success();
}
};

struct NoNan : public OpRewritePattern<mlir::stablehlo::CompareOp> {
using OpRewritePattern<mlir::stablehlo::CompareOp>::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::stablehlo::CompareOp op,
PatternRewriter &rewriter) const final {
if (op.getLhs() == op.getRhs()) {
if (op.getComparisonDirection() ==
mlir::stablehlo::ComparisonDirection::EQ) {
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
op, op.getType(), makeAttr(op.getType(), 1).cast<ElementsAttr>());
return success();
}
if (op.getComparisonDirection() ==
mlir::stablehlo::ComparisonDirection::NE) {
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
op, op.getType(), makeAttr(op.getType(), 0).cast<ElementsAttr>());
return success();
}
}
return failure();
}
};

struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {

void runOnOperation() override {
Expand All @@ -1561,6 +1595,10 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
BinBroadcastSplat<stablehlo::SubtractOp>,
BinBroadcastSplat<stablehlo::DivOp>,
BinBroadcastSplat<stablehlo::MulOp>>(context);
if (all_finite)
patterns.add<AllFinite>(context);
if (no_nan || all_finite)
patterns.add<NoNan>(context);
mlir::stablehlo::populateStablehloCanonicalizationPatterns(context,
&patterns);

Expand Down
16 changes: 16 additions & 0 deletions src/enzyme_ad/jax/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,22 @@ def EnzymeHLOOptPass : Pass<"enzyme-hlo-opt"> {
"tensor::TensorDialect"
];
let constructor = "mlir::enzyme::createEnzymeHLOOptPass()";
let options = [
Option<
/*C++ variable name=*/"all_finite",
/*CLI argument=*/"all_finite",
/*type=*/"bool",
/*default=*/"false",
/*description=*/"Whether to raise to assume all variables are finite"
>,
Option<
/*C++ variable name=*/"no_nan",
/*CLI argument=*/"no_nan",
/*type=*/"bool",
/*default=*/"false",
/*description=*/"Whether to raise to assume no variables are nan"
>
];
}

def EnzymeHLOUnrollPass : Pass<"enzyme-hlo-unroll"> {
Expand Down
23 changes: 23 additions & 0 deletions test/lit_tests/isfinite.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// RUN: enzymexlamlir-opt --pass-pipeline="builtin.module(enzyme-hlo-opt{all_finite=true})" %s | FileCheck %s --check-prefix=REMOVED
// RUN: enzymexlamlir-opt --pass-pipeline="builtin.module(enzyme-hlo-opt{all_finite=false})" %s | FileCheck %s --check-prefix=SAME
// RUN: enzymexlamlir-opt --pass-pipeline="builtin.module(enzyme-hlo-opt)" %s | FileCheck %s --check-prefix=SAME
// RUN: enzymexlamlir-opt --pass-pipeline="builtin.module(enzyme-hlo-opt{no_nan=true})" %s | FileCheck %s --check-prefix=SAME
// RUN: enzymexlamlir-opt --pass-pipeline="builtin.module(enzyme-hlo-opt{no_nan=false})" %s | FileCheck %s --check-prefix=SAME

module {

func.func @main(%a : tensor<2x2xf32>) -> tensor<2x2xi1> {
%r = stablehlo.is_finite %a : (tensor<2x2xf32>) -> tensor<2x2xi1>
return %r : tensor<2x2xi1>
}
}

// REMOVED: func.func @main(%arg0: tensor<2x2xf32>) -> tensor<2x2xi1> {
// REMOVED-NEXT: %0 = stablehlo.constant dense<true> : tensor<2x2xi1>
// REMOVED-NEXT: return %0 : tensor<2x2xi1>
// REMOVED-NEXT: }

// SAME: func.func @main(%arg0: tensor<2x2xf32>) -> tensor<2x2xi1> {
// SAME-NEXT: %0 = stablehlo.is_finite %arg0 : (tensor<2x2xf32>) -> tensor<2x2xi1>
// SAME-NEXT: return %0 : tensor<2x2xi1>
// SAME-NEXT: }
26 changes: 26 additions & 0 deletions test/lit_tests/isnan.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// RUN: enzymexlamlir-opt --pass-pipeline="builtin.module(enzyme-hlo-opt{all_finite=true})" %s | FileCheck %s --check-prefix=REMOVED
// RUN: enzymexlamlir-opt --pass-pipeline="builtin.module(enzyme-hlo-opt{all_finite=false})" %s | FileCheck %s --check-prefix=SAME
// RUN: enzymexlamlir-opt --pass-pipeline="builtin.module(enzyme-hlo-opt)" %s | FileCheck %s --check-prefix=SAME
// RUN: enzymexlamlir-opt --pass-pipeline="builtin.module(enzyme-hlo-opt{no_nan=true})" %s | FileCheck %s --check-prefix=REMOVED
// RUN: enzymexlamlir-opt --pass-pipeline="builtin.module(enzyme-hlo-opt{no_nan=false})" %s | FileCheck %s --check-prefix=SAME

module {

func.func @main(%a : tensor<2x2xf32>) -> (tensor<2x2xi1>, tensor<2x2xi1>) {
%eq = stablehlo.compare EQ, %a, %a, FLOAT : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
%ne = stablehlo.compare NE, %a, %a, FLOAT : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
return %eq, %ne : tensor<2x2xi1>, tensor<2x2xi1>
}
}

// REMOVED: func.func @main(%arg0: tensor<2x2xf32>) -> (tensor<2x2xi1>, tensor<2x2xi1>) {
// REMOVED-NEXT: %0 = stablehlo.constant dense<true> : tensor<2x2xi1>
// REMOVED-NEXT: %1 = stablehlo.constant dense<false> : tensor<2x2xi1>
// REMOVED-NEXT: return %0, %1 : tensor<2x2xi1>, tensor<2x2xi1>
// REMOVED-NEXT: }

// SAME: func.func @main(%arg0: tensor<2x2xf32>) -> (tensor<2x2xi1>, tensor<2x2xi1>) {
// SAME-NEXT: %0 = stablehlo.compare EQ, %arg0, %arg0, FLOAT : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
// SAME-NEXT: %1 = stablehlo.compare NE, %arg0, %arg0, FLOAT : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
// SAME-NEXT: return %0, %1 : tensor<2x2xi1>, tensor<2x2xi1>
// SAME-NEXT: }

0 comments on commit 5c65020

Please sign in to comment.