diff --git a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp index 36df6f36..5dbf3336 100644 --- a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp +++ b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp @@ -1942,7 +1942,7 @@ struct IfOpEnzymeOpsRemover removalBlockExplore(falseBlock, falseMapping, builder, gradients, pushedCaches); - if (gradients.size() == 0 || pushedCaches.size() == 0) + if (gradients.size() == 0 && pushedCaches.size() == 0) return success(); Operation *trueTerm = trueBlock->getTerminator(); diff --git a/test/lit_tests/diffrules/stablehlo/if_remove.mlir b/test/lit_tests/diffrules/stablehlo/if_remove.mlir index 3ddaa3a0..7c3effd6 100644 --- a/test/lit_tests/diffrules/stablehlo/if_remove.mlir +++ b/test/lit_tests/diffrules/stablehlo/if_remove.mlir @@ -1,4 +1,4 @@ -// RUN: enzymexlamlir-opt %s --enzyme-wrap="infn=main outfn= argTys=enzyme_active,enzyme_const retTys=enzyme_active mode=ReverseModeCombined" --canonicalize --remove-unnecessary-enzyme-ops --arith-raise --enzyme-hlo-opt | FileCheck %s --check-prefix=REVERSE +// RUN: enzymexlamlir-opt %s --enzyme-wrap="infn=main outfn= argTys=enzyme_active,enzyme_const retTys=enzyme_active mode=ReverseModeCombined" --canonicalize --remove-unnecessary-enzyme-ops --arith-raise --enzyme-hlo-opt --allow-unregistered-dialect | FileCheck %s --check-prefix=REVERSE module { func.func @main(%arg0: tensor<10xf32>, %pred: tensor) -> tensor<10xf32> { @@ -14,8 +14,36 @@ module { return %0 : tensor<10xf32> } + func.func @zmain2(%arg0: tensor<10xf32>, %arg1: tensor, %arg2: tensor<10xf32>) -> (tensor<10xf32>, tensor<10xf32>) { + %cst = stablehlo.constant dense<1.000000e+00> : tensor<10xf32> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<10xf32> + %0 = "enzyme.init"() : () -> !enzyme.Gradient> + %8 = "stablehlo.if"(%arg1) ({ + "enzyme.set"(%0, %cst_0) : (!enzyme.Gradient>, tensor<10xf32>) -> () + %13 = "test.foo"(%arg0) : (tensor<10xf32>) -> tensor<10xf32> + stablehlo.return %13 : tensor<10xf32> + }, { + "enzyme.set"(%0, %arg2) : (!enzyme.Gradient>, tensor<10xf32>) -> () + stablehlo.return %cst : tensor<10xf32> + }) : (tensor) -> tensor<10xf32> + %9 = "enzyme.get"(%0) : (!enzyme.Gradient>) -> tensor<10xf32> + return %8, %9 : tensor<10xf32>, tensor<10xf32> + } } +// REVERSE: func.func @zmain2(%arg0: tensor<10xf32>, %arg1: tensor, %arg2: tensor<10xf32>) -> (tensor<10xf32>, tensor<10xf32>) { +// REVERSE-NEXT: %cst = stablehlo.constant dense<1.000000e+00> : tensor<10xf32> +// REVERSE-NEXT: %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<10xf32> +// REVERSE-NEXT: %0 = stablehlo.select %arg1, %cst_0, %arg2 : tensor, tensor<10xf32> +// REVERSE-NEXT: %1 = "stablehlo.if"(%arg1) ({ +// REVERSE-NEXT: %2 = "test.foo"(%arg0) : (tensor<10xf32>) -> tensor<10xf32> +// REVERSE-NEXT: stablehlo.return %2 : tensor<10xf32> +// REVERSE-NEXT: }, { +// REVERSE-NEXT: stablehlo.return %cst : tensor<10xf32> +// REVERSE-NEXT: }) : (tensor) -> tensor<10xf32> +// REVERSE-NEXT: return %1, %0 : tensor<10xf32>, tensor<10xf32> +// REVERSE-NEXT: } + // REVERSE: func.func @main(%arg0: tensor<10xf32>, %arg1: tensor, %arg2: tensor<10xf32>) -> tensor<10xf32> { // REVERSE-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<10xf32> // REVERSE-NEXT: %0 = "stablehlo.if"(%arg1) ({