diff --git a/lib/Dialect/Substrait/Transforms/CMakeLists.txt b/lib/Dialect/Substrait/Transforms/CMakeLists.txt index e4844e78e42c..ba4eef49b48c 100644 --- a/lib/Dialect/Substrait/Transforms/CMakeLists.txt +++ b/lib/Dialect/Substrait/Transforms/CMakeLists.txt @@ -9,5 +9,6 @@ add_mlir_dialect_library(MLIRSubstraitTransforms MLIRPass MLIRRewrite MLIRSubstraitDialect + MLIRTransforms MLIRTransformUtils ) diff --git a/lib/Dialect/Substrait/Transforms/EmitDeduplication.cpp b/lib/Dialect/Substrait/Transforms/EmitDeduplication.cpp index 9e0c22247337..35abb3fc868a 100644 --- a/lib/Dialect/Substrait/Transforms/EmitDeduplication.cpp +++ b/lib/Dialect/Substrait/Transforms/EmitDeduplication.cpp @@ -8,7 +8,9 @@ #include "structured/Dialect/Substrait/Transforms/Passes.h" +#include "mlir/IR/Dominance.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/CSE.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "structured/Dialect/Substrait/IR/Substrait.h" @@ -257,6 +259,13 @@ struct PushDuplicatesThroughFilterPattern : public OpRewritePattern { deduplicateRegionArgs(newOp.getCondition(), emitOp.getMapping(), newInput.getType(), rewriter); + // Deduplicating block args may create common subexpressions. Eliminate + // them immediately. + { + DominanceInfo domInfo; + mlir::eliminateCommonSubExpressions(rewriter, domInfo, newOp); + } + // Replace the old `filter` op with a new `emit` op that maps back to the // original emit order. ArrayAttr reverseMappingAttr = rewriter.getI64ArrayAttr(reverseMapping); @@ -318,6 +327,13 @@ struct PushDuplicatesThroughProjectPattern deduplicateRegionArgs(newOp.getExpressions(), emitOp.getMapping(), newInput.getType(), rewriter); + // Deduplicating block args may create common subexpressions. Eliminate + // them immediately. + { + DominanceInfo domInfo; + mlir::eliminateCommonSubExpressions(rewriter, domInfo, newOp); + } + // Compute output indices for the expressions added by the region. int64_t numTotalIndices = numDedupIndices + terminator->getNumOperands(); append_range(reverseMapping, seq(numDedupIndices, numTotalIndices)); diff --git a/test/Transforms/Substrait/emit-deduplication.mlir b/test/Transforms/Substrait/emit-deduplication.mlir index 1774cc4e0ac1..6c423d27919c 100644 --- a/test/Transforms/Substrait/emit-deduplication.mlir +++ b/test/Transforms/Substrait/emit-deduplication.mlir @@ -161,13 +161,11 @@ substrait.plan version 0 : 42 : 1 { // CHECK-NEXT: %[[V2:.*]] = filter %[[V1]] : {{.*}} { // CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: [[TYPE:.*]]): // CHECK-NEXT: %[[V3:.*]] = field_reference %[[ARG0]]{{\[}}[0]] : [[TYPE]] -// CHECK-NEXT: %[[V4:.*]] = field_reference %[[ARG0]]{{\[}}[0]] : [[TYPE]] // CHECK-NEXT: %[[V5:.*]] = field_reference %[[ARG0]]{{\[}}[1, 0]] : [[TYPE]] // CHECK-NEXT: %[[V6:.*]] = field_reference %[[ARG0]]{{\[}}[1]] : [[TYPE]] // CHECK-NEXT: %[[V7:.*]] = field_reference %[[V6]]{{\[}}[1]] : -// CHECK-NEXT: %[[V8:.*]] = field_reference %[[ARG0]]{{\[}}[0]] : [[TYPE]] // CHECK-NEXT: %[[V9:.*]] = field_reference %[[ARG0]]{{\[}}[2]] : [[TYPE]] -// CHECK-NEXT: %[[Va:.*]] = func.call @f(%[[V3]], %[[V4]], %[[V5]], %[[V7]], %[[V8]], %[[V9]]) +// CHECK-NEXT: %[[Va:.*]] = func.call @f(%[[V3]], %[[V3]], %[[V5]], %[[V7]], %[[V3]], %[[V9]]) // CHECK-NEXT: yield %[[Va]] : si1 // CHECK-NEXT: } // CHECK-NEXT: %[[Vb:.*]] = emit [0, 0, 1, 0, 2] from %[[V2]] @@ -209,8 +207,7 @@ substrait.plan version 0 : 42 : 1 { // CHECK-NEXT: %[[V2:.*]] = project %[[V1]] : tuple -> tuple { // CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: [[TYPE:.*]]): // CHECK-NEXT: %[[V3:.*]] = field_reference %[[ARG0]]{{\[}}[0]] : [[TYPE]] -// CHECK-NEXT: %[[V4:.*]] = field_reference %[[ARG0]]{{\[}}[0]] : [[TYPE]] -// CHECK-NEXT: %[[V5:.*]] = func.call @f(%[[V3]], %[[V4]]) : +// CHECK-NEXT: %[[V5:.*]] = func.call @f(%[[V3]], %[[V3]]) : // CHECK-NEXT: yield %[[V5]] : si1 // CHECK-NEXT: } // CHECK-NEXT: %[[V6:.*]] = emit [0, 0, 1] from %[[V2]]