From d9138c6a89f1342506f4414d021c9088adf13662 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Tue, 23 Jul 2024 12:15:19 +0200 Subject: [PATCH] [Substrait] Extend emit deduplication to project op. (#836) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This pushes duplicate fields from `emit` ops producing the operand of a `project` op to an emit op on the result of the `project`. Signed-off-by: Ingo Müller --- .../Transforms/EmitDeduplication.cpp | 74 ++++++++++++++++++- .../Substrait/emit-deduplication.mlir | 34 +++++++++ 2 files changed, 106 insertions(+), 2 deletions(-) diff --git a/lib/Dialect/Substrait/Transforms/EmitDeduplication.cpp b/lib/Dialect/Substrait/Transforms/EmitDeduplication.cpp index e2fcf57c86ca..9e0c22247337 100644 --- a/lib/Dialect/Substrait/Transforms/EmitDeduplication.cpp +++ b/lib/Dialect/Substrait/Transforms/EmitDeduplication.cpp @@ -266,6 +266,71 @@ struct PushDuplicatesThroughFilterPattern : public OpRewritePattern { } }; +/// Pushes duplicates in the mappings of `emit` ops producing the input through +/// the `filter` op. This works by introducing a new `emit` op without the +/// duplicates, creating a new `filter` op updated to work on the deduplicated +/// element type, and finally a new `emit` op that maps back to the original +/// order. +struct PushDuplicatesThroughProjectPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ProjectOp op, + PatternRewriter &rewriter) const override { + auto emitOp = op.getInput().getDefiningOp(); + if (!emitOp) + return rewriter.notifyMatchFailure( + op, "input operand is not produced by an 'emit' op"); + + // Create input ops for the new `project` op. These may be the original + // inputs or `emit` ops that remove duplicates. + SmallVector reverseMapping; + auto [newInput, numDedupIndices, hasDuplicates] = + createDeduplicatingEmit(op.getInput(), reverseMapping, rewriter); + + if (!hasDuplicates) + // Note: if we end up failing here, then the invokation of + // `createDeduplicatingEmit` returned without creating a new (`emit`) op. + return rewriter.notifyMatchFailure( + op, "the 'emit' input does not have duplicates"); + + MLIRContext *context = op.getContext(); + + // Compute deduplicated output field types. + Operation *terminator = op.getExpressions().front().getTerminator(); + auto newInputTupleType = cast(newInput.getType()); + + SmallVector outputTypes; + outputTypes.reserve(newInputTupleType.size() + + terminator->getNumOperands()); + append_range(outputTypes, newInputTupleType.getTypes()); + append_range(outputTypes, terminator->getOperandTypes()); + auto newOutputType = TupleType::get(context, outputTypes); + + // Create new `project` op. Move over the `expressions` region. This needs + // to happen now because replacing the op will destroy the region. + auto newOp = + rewriter.create(op.getLoc(), newOutputType, newInput); + rewriter.inlineRegionBefore(op.getExpressions(), newOp.getExpressions(), + newOp.getExpressions().end()); + + // Update the `condition` region. + deduplicateRegionArgs(newOp.getExpressions(), emitOp.getMapping(), + newInput.getType(), rewriter); + + // Compute output indices for the expressions added by the region. + int64_t numTotalIndices = numDedupIndices + terminator->getNumOperands(); + append_range(reverseMapping, seq(numDedupIndices, numTotalIndices)); + + // Replace the old `project` op with a new `emit` op that maps back to the + // original emit order. + ArrayAttr reverseMappingAttr = rewriter.getI64ArrayAttr(reverseMapping); + rewriter.replaceOpWithNewOp(op, newOp, reverseMappingAttr); + + return failure(); + } +}; + } // namespace namespace mlir { @@ -273,8 +338,13 @@ namespace substrait { void populateEmitDeduplicationPatterns(RewritePatternSet &patterns) { MLIRContext *context = patterns.getContext(); - patterns.add(context); + patterns.add< + // clang-format off + PushDuplicatesThroughCrossPattern, + PushDuplicatesThroughFilterPattern, + PushDuplicatesThroughProjectPattern + // clang-format on + >(context); } std::unique_ptr createEmitDeduplicationPass() { diff --git a/test/Transforms/Substrait/emit-deduplication.mlir b/test/Transforms/Substrait/emit-deduplication.mlir index 85af2eac6f66..1774cc4e0ac1 100644 --- a/test/Transforms/Substrait/emit-deduplication.mlir +++ b/test/Transforms/Substrait/emit-deduplication.mlir @@ -197,3 +197,37 @@ substrait.plan version 0 : 42 : 1 { yield %2 : tuple, si1, si1> } } + +// ----- + +// `project` op (`PushDuplicatesThroughProjectPattern`). + +// CHECK-LABEL: substrait.plan +// CHECK-NEXT: relation +// CHECK-NEXT: %[[V0:.*]] = named_table +// CHECK-NEXT: %[[V1:.*]] = emit [1] from %[[V0]] : +// CHECK-NEXT: %[[V2:.*]] = project %[[V1]] : tuple -> tuple { +// CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: [[TYPE:.*]]): +// CHECK-NEXT: %[[V3:.*]] = field_reference %[[ARG0]]{{\[}}[0]] : [[TYPE]] +// CHECK-NEXT: %[[V4:.*]] = field_reference %[[ARG0]]{{\[}}[0]] : [[TYPE]] +// CHECK-NEXT: %[[V5:.*]] = func.call @f(%[[V3]], %[[V4]]) : +// CHECK-NEXT: yield %[[V5]] : si1 +// CHECK-NEXT: } +// CHECK-NEXT: %[[V6:.*]] = emit [0, 0, 1] from %[[V2]] + +func.func private @f(si32, si32) -> si1 + +substrait.plan version 0 : 42 : 1 { + relation { + %0 = named_table @t1 as ["a", "b"] : tuple + %1 = emit [1, 1] from %0 : tuple -> tuple + %2 = project %1 : tuple -> tuple { + ^bb0(%arg : tuple): + %3 = field_reference %arg[[0]] : tuple + %4 = field_reference %arg[[1]] : tuple + %5 = func.call @f(%3, %4) : (si32, si32) -> si1 + yield %5 : si1 + } + yield %2 : tuple + } +}