Skip to content

Commit

Permalink
[Substrait] Extend emit deduplication to project op. (#836)
Browse files Browse the repository at this point in the history
This pushes duplicate fields from `emit` ops producing the operand of a
`project` op to an emit op on the result of the `project`.

Signed-off-by: Ingo Müller <[email protected]>
  • Loading branch information
ingomueller-net authored Jul 23, 2024
1 parent 138086d commit d9138c6
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 2 deletions.
74 changes: 72 additions & 2 deletions lib/Dialect/Substrait/Transforms/EmitDeduplication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,15 +266,85 @@ struct PushDuplicatesThroughFilterPattern : public OpRewritePattern<FilterOp> {
}
};

/// Pushes duplicates in the mappings of `emit` ops producing the input through
/// the `project` op. This works by introducing a new `emit` op without the
/// duplicates, creating a new `project` op updated to work on the deduplicated
/// element type, and finally a new `emit` op that maps back to the original
/// order.
///
/// The match fails (without modifying the IR) if the input of the `project` op
/// is not produced by an `emit` op or if that `emit` op has no duplicates.
/// Otherwise, the `project` op is replaced and the pattern reports success.
struct PushDuplicatesThroughProjectPattern
    : public OpRewritePattern<ProjectOp> {
  using OpRewritePattern<ProjectOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ProjectOp op,
                                PatternRewriter &rewriter) const override {
    auto emitOp = op.getInput().getDefiningOp<EmitOp>();
    if (!emitOp)
      return rewriter.notifyMatchFailure(
          op, "input operand is not produced by an 'emit' op");

    // Create input ops for the new `project` op. These may be the original
    // inputs or `emit` ops that remove duplicates.
    SmallVector<int64_t> reverseMapping;
    auto [newInput, numDedupIndices, hasDuplicates] =
        createDeduplicatingEmit(op.getInput(), reverseMapping, rewriter);

    if (!hasDuplicates)
      // Note: if we end up failing here, then the invocation of
      // `createDeduplicatingEmit` returned without creating a new (`emit`) op.
      return rewriter.notifyMatchFailure(
          op, "the 'emit' input does not have duplicates");

    MLIRContext *context = op.getContext();

    // Compute deduplicated output field types: the fields of the deduplicated
    // input followed by the fields yielded by the `expressions` region.
    Operation *terminator = op.getExpressions().front().getTerminator();
    auto newInputTupleType = cast<TupleType>(newInput.getType());

    SmallVector<Type> outputTypes;
    outputTypes.reserve(newInputTupleType.size() +
                        terminator->getNumOperands());
    append_range(outputTypes, newInputTupleType.getTypes());
    append_range(outputTypes, terminator->getOperandTypes());
    auto newOutputType = TupleType::get(context, outputTypes);

    // Create new `project` op. Move over the `expressions` region. This needs
    // to happen now because replacing the op will destroy the region.
    auto newOp =
        rewriter.create<ProjectOp>(op.getLoc(), newOutputType, newInput);
    rewriter.inlineRegionBefore(op.getExpressions(), newOp.getExpressions(),
                                newOp.getExpressions().end());

    // Update the `expressions` region to work on the deduplicated input type.
    deduplicateRegionArgs(newOp.getExpressions(), emitOp.getMapping(),
                          newInput.getType(), rewriter);

    // Compute output indices for the expressions added by the region. They
    // follow the deduplicated input fields and are passed through unchanged.
    int64_t numTotalIndices = numDedupIndices + terminator->getNumOperands();
    append_range(reverseMapping, seq(numDedupIndices, numTotalIndices));

    // Replace the old `project` op with a new `emit` op that maps back to the
    // original emit order.
    ArrayAttr reverseMappingAttr = rewriter.getI64ArrayAttr(reverseMapping);
    rewriter.replaceOpWithNewOp<EmitOp>(op, newOp, reverseMappingAttr);

    // The IR has been modified, so the rewrite must be reported as successful.
    return success();
  }
};

} // namespace

namespace mlir {
namespace substrait {

/// Adds the patterns that push duplicate `emit` fields through `cross`,
/// `filter`, and `project` ops to `patterns`. Each pattern is registered
/// exactly once.
void populateEmitDeduplicationPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<
      // clang-format off
      PushDuplicatesThroughCrossPattern,
      PushDuplicatesThroughFilterPattern,
      PushDuplicatesThroughProjectPattern
      // clang-format on
      >(context);
}

std::unique_ptr<Pass> createEmitDeduplicationPass() {
Expand Down
34 changes: 34 additions & 0 deletions test/Transforms/Substrait/emit-deduplication.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,37 @@ substrait.plan version 0 : 42 : 1 {
yield %2 : tuple<si1, si1, tuple<si1, si32>, si1, si1>
}
}

// -----

// `project` op (`PushDuplicatesThroughProjectPattern`).
//
// The input plan below feeds a `project` op from `emit [1, 1]`, i.e., an
// `emit` that selects field 1 twice. The expected output emits that field only
// once (`emit [1]`), runs the `project` on the deduplicated `tuple<si32>` (so
// both field references use index 0), and restores the original layout with a
// trailing `emit [0, 0, 1]`.

// CHECK-LABEL: substrait.plan
// CHECK-NEXT: relation
// CHECK-NEXT: %[[V0:.*]] = named_table
// CHECK-NEXT: %[[V1:.*]] = emit [1] from %[[V0]] :
// CHECK-NEXT: %[[V2:.*]] = project %[[V1]] : tuple<si32> -> tuple<si32, si1> {
// CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: [[TYPE:.*]]):
// CHECK-NEXT: %[[V3:.*]] = field_reference %[[ARG0]]{{\[}}[0]] : [[TYPE]]
// CHECK-NEXT: %[[V4:.*]] = field_reference %[[ARG0]]{{\[}}[0]] : [[TYPE]]
// CHECK-NEXT: %[[V5:.*]] = func.call @f(%[[V3]], %[[V4]]) :
// CHECK-NEXT: yield %[[V5]] : si1
// CHECK-NEXT: }
// CHECK-NEXT: %[[V6:.*]] = emit [0, 0, 1] from %[[V2]]

// Externally defined function; it is called on both duplicated fields inside
// the `project` region.
func.func private @f(si32, si32) -> si1

// Input plan: the `project` op consumes an `emit` op with a duplicated field.
substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a", "b"] : tuple<si1, si32>
%1 = emit [1, 1] from %0 : tuple<si1, si32> -> tuple<si32, si32>
%2 = project %1 : tuple<si32, si32> -> tuple<si32, si32, si1> {
^bb0(%arg : tuple<si32, si32>):
%3 = field_reference %arg[[0]] : tuple<si32, si32>
%4 = field_reference %arg[[1]] : tuple<si32, si32>
%5 = func.call @f(%3, %4) : (si32, si32) -> si1
yield %5 : si1
}
yield %2 : tuple<si32, si32, si1>
}
}

0 comments on commit d9138c6

Please sign in to comment.