diff --git a/include/structured/Dialect/Substrait/IR/SubstraitOps.td b/include/structured/Dialect/Substrait/IR/SubstraitOps.td index 8b8f09b6b86a..6078c1c4f5e4 100644 --- a/include/structured/Dialect/Substrait/IR/SubstraitOps.td +++ b/include/structured/Dialect/Substrait/IR/SubstraitOps.td @@ -395,6 +395,7 @@ def Substrait_ProjectOp : Substrait_RelOp<"project", [ $input attr-dict `:` type($input) `->` type($result) $expressions }]; let hasRegionVerifier = 1; + let hasFolder = 1; let extraClassDefinition = [{ /// Implement OpAsmOpInterface. ::llvm::StringRef $cppClass::getDefaultDialect() { diff --git a/lib/Dialect/Substrait/IR/Substrait.cpp b/lib/Dialect/Substrait/IR/Substrait.cpp index ed7d1c4020b2..3925d0a4c8bf 100644 --- a/lib/Dialect/Substrait/IR/Substrait.cpp +++ b/lib/Dialect/Substrait/IR/Substrait.cpp @@ -339,6 +339,17 @@ LogicalResult PlanRelOp::verifyRegions() { return verifyNamedStruct(getOperation(), fieldNames, tupleType); } +OpFoldResult ProjectOp::fold(FoldAdaptor adaptor) { + Operation *terminator = adaptor.getExpressions().front().getTerminator(); + + // If the region does not yield any values, the the `project` has no effect. + if (terminator->getNumOperands() == 0) { + return getInput(); + } + + return {}; +} + LogicalResult ProjectOp::verifyRegions() { // Verify that the expression block has a matching argument type. auto inputTupleType = llvm::cast(getInput().getType()); diff --git a/test/Dialect/Substrait/canonicalize.mlir b/test/Dialect/Substrait/canonicalize.mlir index cf1e748d3e20..79ed18a8e725 100644 --- a/test/Dialect/Substrait/canonicalize.mlir +++ b/test/Dialect/Substrait/canonicalize.mlir @@ -72,3 +72,22 @@ substrait.plan version 0 : 42 : 1 { yield %5 : tuple } } + +// ----- + +// Check that empty `project` folded. + +// CHECK-LABEL: substrait.plan +// CHECK-NEXT: relation +// CHECK-NEXT: %[[V0:.*]] = named_table +// CHECK-NEXT: yield %[[V0]] + +substrait.plan version 0 : 42 : 1 { + relation { + %0 = named_table @t1 as ["a"] : tuple + %1 = project %0 : tuple -> tuple { + ^bb0(%arg0: tuple): + } + yield %1 : tuple + } +}