From c30ab73a0ae9c4b8faf40f07e66ac99fd192b20d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Mon, 13 May 2024 11:04:03 +0000 Subject: [PATCH] [Substrait] Create `emit` op to represent output mapping. This commit introduces handling for the `emit_kind` field of the `RelCommon` message, which is a common field of all (?) cases of the `Rel` message. The current design models this field as a dedicated op such that the output mapping is only ever present in that op, `emit`. There are at least two alternatives to this IR design. The first one consists of making the output mapping part of each of the ops that represent `Rel` messages and exposing it through the `RelOpInterface`. However, this would mean that (1) the custom assembly of each op would have to represent the mapping, which is manual effort and a possible source of inconsistencies, (2) each op would have to implement type inference in the presence of a mapping, and (3) most rewrites of all ops would have to take that mapping into account for their semantics. Having the mapping in one place makes all of this simpler. The downside is that what is kept in a single place in the Substrait protobuf format is now spread across two ops in the MLIR representation. However, I believe that this is the lesser of two evils and the current import and export seem to work. Another alternative would be to combine the two: make the mapping part of all ops but *also* introduce a dedicated `emit` op. Then, two passes could move the mapping from one to the other depending on which of the two representations would be more convenient. However, this would not get rid of Issues (1) and (2) above and would lead to more concepts and code. 
--- .../Dialect/Substrait/IR/SubstraitOps.td | 34 +++++ lib/Dialect/Substrait/IR/Substrait.cpp | 32 ++++ lib/Target/SubstraitPB/Export.cpp | 96 ++++++++++-- lib/Target/SubstraitPB/Import.cpp | 38 ++++- test/Dialect/Substrait/emit-invalid.mlir | 23 +++ test/Dialect/Substrait/emit.mlir | 48 ++++++ .../SubstraitPB/Export/emit-invalid.mlir | 28 ++++ test/Target/SubstraitPB/Export/emit.mlir | 58 ++++++++ test/Target/SubstraitPB/Import/emit.textpb | 137 ++++++++++++++++++ 9 files changed, 481 insertions(+), 13 deletions(-) create mode 100644 test/Dialect/Substrait/emit-invalid.mlir create mode 100644 test/Dialect/Substrait/emit.mlir create mode 100644 test/Target/SubstraitPB/Export/emit-invalid.mlir create mode 100644 test/Target/SubstraitPB/Export/emit.mlir create mode 100644 test/Target/SubstraitPB/Import/emit.textpb diff --git a/include/structured/Dialect/Substrait/IR/SubstraitOps.td b/include/structured/Dialect/Substrait/IR/SubstraitOps.td index 5cbaaf9bad0c..a1b408281c10 100644 --- a/include/structured/Dialect/Substrait/IR/SubstraitOps.td +++ b/include/structured/Dialect/Substrait/IR/SubstraitOps.td @@ -247,6 +247,40 @@ def Substrait_CrossOp : Substrait_RelOp<"cross", [ }]; } +def Substrait_EmitOp : Substrait_RelOp<"emit", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods + ]> { + let summary = "Projection (a.k.a. 'emit') as dedicated operation"; + let description = [{ + Represents the `Emit` message of the `emit_kind` field in the `RelCommon` + message. While projection is inlined into all relations in the protobuf + format, this op separates out this functionality in a dedicated op in order + to simplify rewriting. + + Example: + + ```mlir + %0 = ... 
+ %1 = emit [2, 1] from %0 : tuple -> tuple + ``` + }]; + let arguments = (ins + Substrait_Relation:$input, + I64ArrayAttr:$mapping + ); + let results = (outs Substrait_Relation:$result); + let assemblyFormat = [{ + $mapping `from` $input attr-dict `:` type($input) `->` type($result) + }]; + let extraClassDefinition = [{ + /// Implement OpAsmOpInterface. + ::llvm::StringRef $cppClass::getDefaultDialect() { + return SubstraitDialect::getDialectNamespace(); + } + }]; +} + def Substrait_FilterOp : Substrait_RelOp<"filter", [ SingleBlockImplicitTerminator<"::mlir::substrait::YieldOp">, DeclareOpInterfaceMethods, diff --git a/lib/Dialect/Substrait/IR/Substrait.cpp b/lib/Dialect/Substrait/IR/Substrait.cpp index 7475dcc78340..a6c8b369b45f 100644 --- a/lib/Dialect/Substrait/IR/Substrait.cpp +++ b/lib/Dialect/Substrait/IR/Substrait.cpp @@ -72,6 +72,38 @@ CrossOp::inferReturnTypes(MLIRContext *context, std::optional loc, return success(); } +LogicalResult +EmitOp::inferReturnTypes(MLIRContext *context, std::optional loc, + ValueRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + llvm::SmallVectorImpl &inferredReturnTypes) { + auto *typedProperties = properties.as(); + if (!loc) + loc = UnknownLoc::get(context); + + ArrayAttr mapping = typedProperties->getMapping(); + Type inputType = operands[0].getType(); + ArrayRef inputTypes = inputType.cast().getTypes(); + + // Map input types to output types. + SmallVector outputTypes; + outputTypes.reserve(mapping.size()); + for (auto indexAttr : mapping.getAsRange()) { + int64_t index = indexAttr.getInt(); + if (index < 0 || index >= static_cast(inputTypes.size())) + return ::emitError(loc.value()) + << index << " is not a valid index into " << inputType; + Type mappedType = inputTypes[index]; + outputTypes.push_back(mappedType); + } + + // Create final tuple type. 
+ auto outputType = TupleType::get(context, outputTypes); + inferredReturnTypes.push_back(outputType); + + return success(); +} + /// Computes the type of the nested field of the given `type` identified by /// `position`. Each entry `n` in the given index array `position` corresponds /// to the `n`-th entry in that level. The function is thus implemented diff --git a/lib/Target/SubstraitPB/Export.cpp b/lib/Target/SubstraitPB/Export.cpp index 36e803153590..b192595cd85f 100644 --- a/lib/Target/SubstraitPB/Export.cpp +++ b/lib/Target/SubstraitPB/Export.cpp @@ -39,6 +39,7 @@ namespace { static FailureOr> exportOperation(OP_TYPE op); DECLARE_EXPORT_FUNC(CrossOp, Rel) +DECLARE_EXPORT_FUNC(EmitOp, Rel) DECLARE_EXPORT_FUNC(ExpressionOpInterface, Expression) DECLARE_EXPORT_FUNC(FieldReferenceOp, Expression) DECLARE_EXPORT_FUNC(FilterOp, Rel) @@ -51,6 +52,65 @@ DECLARE_EXPORT_FUNC(RelOpInterface, Rel) FailureOr> exportOperation(Operation *op); FailureOr> exportOperation(RelOpInterface op); +/// Creates the `RelCommon` message with the `emit_kind` field for the given +/// op. +/// +/// **This function has to be called during the export of every `Rel` case +/// that has a `RelCommon` message.** +/// +/// If the result of the given op is used by an `EmitOp`, then the returned +/// `RelCommon` message contains an `Emit` message that represents the +/// output mapping of the `EmitOp`. Otherwise, the returned `RelCommon` +/// message contains a `Direct` message. +/// +/// If there is more than one `EmitOp` user or some `EmitOp` users and some +/// other users, then an error is returned because these cases can't be +/// expressed by a single `Emit` message. Some corner cases where export +/// might still be possible are cases with multiple `EmitOp`s that are all +/// identical and a subset of `EmitOp` users all with identity mappings. All +/// of these should go away through canonicalization and/or CSE. 
+FailureOr> +createRelComminWithEmit(RelOpInterface op) { + auto relCommon = std::make_unique(); + + Value result = op->getResult(0); + + // Collect all `EmitOp`s that use the result of this operation. + SmallVector emitOps; + size_t numUsers = 0; + for (Operation *user : result.getUsers()) { + numUsers++; + if (isa(user)) + emitOps.push_back(dyn_cast(user)); + } + + // If we don't have an `EmitOp` user, then `op` has direct emit behavior. + if (emitOps.empty()) { + auto direct = std::make_unique(); + relCommon->set_allocated_direct(direct.release()); + return relCommon; + } + + // If we have more than one `EmitOp` user or fewer `EmitOp` users than total + // users, then some mappings are different from the others, which can't be + // expressed by a single `Emit` message. + if (emitOps.size() > 1 || emitOps.size() != numUsers) { + return op->emitOpError("is consumed by different emit ops (try running " + "canonicalization and/or CSE)"); + } + + // Normal case: we have exactly one `EmitOp` user. + EmitOp emitOp = emitOps.front(); + + // Build the `Emit` message. + auto emit = std::make_unique(); + for (auto intAttr : emitOp.getMapping().getAsRange()) + emit->add_output_mapping(intAttr.getInt()); + + relCommon->set_allocated_emit(emit.release()); + return relCommon; +} + FailureOr> exportType(Location loc, mlir::Type mlirType) { MLIRContext *context = mlirType.getContext(); @@ -141,6 +201,20 @@ FailureOr> exportOperation(CrossOp op) { return rel; } +/// We just forward to the overload for `RelOpInterface`, which will have to +/// export this op. We can't (easily) do it here because the emit op is +/// represented as part of the `RelCommon` message of one of the cases of the +/// `Rel` message but there is no generic way to access the `common` field of +/// the various cases. 
+FailureOr> exportOperation(EmitOp op) { + auto inputOp = + dyn_cast_if_present(op.getInput().getDefiningOp()); + if (!inputOp) + return op->emitOpError("input was not produced by Substrait relation op"); + + return exportOperation(inputOp); +} + FailureOr> exportOperation(ExpressionOpInterface op) { return llvm::TypeSwitch>>( @@ -208,10 +282,12 @@ FailureOr> exportOperation(FieldReferenceOp op) { } FailureOr> exportOperation(FilterOp op) { - // Build `RelCommon` message. - auto relCommon = std::make_unique(); - auto direct = std::make_unique(); - relCommon->set_allocated_direct(direct.release()); + // Build `RelCommon` message with emit mapping. + FailureOr> maybeRelCommon = + createRelComminWithEmit(op); + if (failed(maybeRelCommon)) + return failure(); + auto relCommon = std::move(maybeRelCommon.value()); // Build input `Rel` message. auto inputOp = @@ -299,10 +375,12 @@ FailureOr> exportOperation(NamedTableOp op) { namedTable->add_names(attr.getLeafReference().str()); } - // Build `RelCommon` message. - auto relCommon = std::make_unique(); - auto direct = std::make_unique(); - relCommon->set_allocated_direct(direct.release()); + // Build `RelCommon` message with emit mapping. + FailureOr> maybeRelCommon = + createRelComminWithEmit(op); + if (failed(maybeRelCommon)) + return failure(); + auto relCommon = std::move(maybeRelCommon.value()); // Build `Struct` message. 
auto struct_ = std::make_unique(); @@ -368,7 +446,7 @@ FailureOr> exportOperation(PlanOp op) { FailureOr> exportOperation(RelOpInterface op) { return llvm::TypeSwitch>>(op) - .Case( + .Case( [&](auto op) { return exportOperation(op); }) .Default([](auto op) { op->emitOpError("not supported for export"); diff --git a/lib/Target/SubstraitPB/Import.cpp b/lib/Target/SubstraitPB/Import.cpp index 905ab2d64903..ace8c71d7304 100644 --- a/lib/Target/SubstraitPB/Import.cpp +++ b/lib/Target/SubstraitPB/Import.cpp @@ -50,12 +50,42 @@ DECLARE_IMPORT_FUNC(Expression, Expression, ExpressionOpInterface) DECLARE_IMPORT_FUNC(FieldReference, Expression::FieldReference, FieldReferenceOp) DECLARE_IMPORT_FUNC(Literal, Expression::Literal, LiteralOp) -DECLARE_IMPORT_FUNC(NamedTable, Rel, NamedTableOp) +DECLARE_IMPORT_FUNC(NamedTable, Rel, RelOpInterface) DECLARE_IMPORT_FUNC(Plan, Plan, PlanOp) DECLARE_IMPORT_FUNC(PlanRel, PlanRel, PlanRelOp) DECLARE_IMPORT_FUNC(ReadRel, Rel, RelOpInterface) DECLARE_IMPORT_FUNC(Rel, Rel, RelOpInterface) +/// Imports the provided `RelCommon` message by producing an `EmitOp` that +/// expresses the `Emit` message if it exists. +/// +/// **This function must be called at the end of the import function of every +/// `Rel` message with a `RelCommon` message.** +/// +/// The provided `inputOp` is the op that was imported by the `Rel` message +/// containing the provided `RelCommon` message. The function returns the +/// `EmitOp` if one was created and the `inputOp` otherwise. +static mlir::FailureOr +importMaybeEmit(ImplicitLocOpBuilder builder, const RelCommon &message, + RelOpInterface inputOp) { + // For the `direct`, we just forward the input op. + if (message.has_direct()) + return inputOp; + assert(message.has_emit() && "expected either 'direct' or 'emit'"); + + // For the `emit` case, we need to insert an `EmitOp`. 
+ const proto::RelCommon::Emit &emit = message.emit(); + SmallVector mapping; + for (int64_t index : emit.output_mapping()) + mapping.push_back(index); + ArrayAttr mappingAttr = builder.getI64ArrayAttr(mapping); + + Value input = inputOp->getResult(0); + auto emitOp = builder.create(input, mappingAttr); + + return cast(emitOp.getOperation()); +} + static mlir::FailureOr importType(MLIRContext *context, const proto::Type &type) { @@ -246,10 +276,10 @@ static mlir::FailureOr importFilterRel(ImplicitLocOpBuilder builder, builder.create(conditionOp.value()->getResult(0)); } - return filterOp; + return importMaybeEmit(builder, filterRel.common(), filterOp); } -static mlir::FailureOr +static mlir::FailureOr importNamedTable(ImplicitLocOpBuilder builder, const Rel &message) { const ReadRel &readRel = message.read(); const ReadRel::NamedTable &namedTable = readRel.named_table(); @@ -294,7 +324,7 @@ importNamedTable(ImplicitLocOpBuilder builder, const Rel &message) { auto namedTableOp = builder.create(resultType, tableName, fieldNamesAttr); - return namedTableOp; + return importMaybeEmit(builder, readRel.common(), namedTableOp); } static FailureOr importPlan(ImplicitLocOpBuilder builder, diff --git a/test/Dialect/Substrait/emit-invalid.mlir b/test/Dialect/Substrait/emit-invalid.mlir new file mode 100644 index 000000000000..c1dc844fab88 --- /dev/null +++ b/test/Dialect/Substrait/emit-invalid.mlir @@ -0,0 +1,23 @@ +// RUN: structured-opt -verify-diagnostics -split-input-file %s + +substrait.plan version 0 : 42 : 1 { + relation { + %0 = named_table @t1 as ["a"] : tuple + // expected-error@+2 {{'substrait.emit' op failed to infer returned types}} + // expected-error@+1 {{1 is not a valid index into 'tuple'}} + %1 = emit [1] from %0 : tuple -> tuple + yield %1 : tuple + } +} + +// ----- + +substrait.plan version 0 : 42 : 1 { + relation { + %0 = named_table @t1 as ["a"] : tuple + // expected-error@+2 {{'substrait.emit' op failed to infer returned types}} + // expected-error@+1 
{{-1 is not a valid index into 'tuple'}} + %1 = emit [-1] from %0 : tuple -> tuple + yield %1 : tuple + } +} diff --git a/test/Dialect/Substrait/emit.mlir b/test/Dialect/Substrait/emit.mlir new file mode 100644 index 000000000000..6689c3798b7b --- /dev/null +++ b/test/Dialect/Substrait/emit.mlir @@ -0,0 +1,48 @@ +// RUN: structured-opt -split-input-file %s \ +// RUN: | FileCheck %s + +// CHECK-LABEL: substrait.plan +// CHECK-NEXT: relation +// CHECK-NEXT: %[[V0:.*]] = named_table +// CHECK-NEXT: %[[V1:.*]] = emit [1, 0] from %[[V0]] : +// CHECK-SAME: tuple -> tuple + +substrait.plan version 0 : 42 : 1 { + relation { + %0 = named_table @t1 as ["a", "b"] : tuple + %1 = emit [1, 0] from %0 : tuple -> tuple + yield %1 : tuple + } +} + +// ----- + +// CHECK-LABEL: substrait.plan +// CHECK-NEXT: relation +// CHECK-NEXT: %[[V0:.*]] = named_table +// CHECK-NEXT: %[[V1:.*]] = emit [0, 0] from %[[V0]] : +// CHECK-SAME: tuple -> tuple + +substrait.plan version 0 : 42 : 1 { + relation { + %0 = named_table @t1 as ["a"] : tuple + %1 = emit [0, 0] from %0 : tuple -> tuple + yield %1 : tuple + } +} + +// ----- + +// CHECK-LABEL: substrait.plan +// CHECK-NEXT: relation +// CHECK-NEXT: %[[V0:.*]] = named_table +// CHECK-NEXT: %[[V1:.*]] = emit [1] from %[[V0]] : +// CHECK-SAME: tuple -> tuple + +substrait.plan version 0 : 42 : 1 { + relation { + %0 = named_table @t1 as ["a", "b"] : tuple + %1 = emit [1] from %0 : tuple -> tuple + yield %1 : tuple + } +} diff --git a/test/Target/SubstraitPB/Export/emit-invalid.mlir b/test/Target/SubstraitPB/Export/emit-invalid.mlir new file mode 100644 index 000000000000..d740aa42bd71 --- /dev/null +++ b/test/Target/SubstraitPB/Export/emit-invalid.mlir @@ -0,0 +1,28 @@ +// RUN: structured-translate -verify-diagnostics -split-input-file %s \ +// RUN: -substrait-to-protobuf + +// Two different `emit` consumers: can't export into a single `Emit` message. 
+ +substrait.plan version 0 : 42 : 1 { + relation { + // expected-error@+1 {{'substrait.named_table' op is consumed by different emit ops (try running canonicalization and/or CSE)}} + %0 = named_table @t1 as ["a", "b"] : tuple + %1 = emit [1, 0] from %0 : tuple -> tuple + %2 = emit [0, 1] from %0 : tuple -> tuple + yield %1 : tuple + } +} + +// ----- + +// One `emit` consumer, one other consumer: can't export to a single `Emit` +// message. + +substrait.plan version 0 : 42 : 1 { + relation { + // expected-error@+1 {{'substrait.named_table' op is consumed by different emit ops (try running canonicalization and/or CSE)}} + %0 = named_table @t1 as ["a", "b"] : tuple + %1 = emit [1, 0] from %0 : tuple -> tuple + yield %0 : tuple + } +} diff --git a/test/Target/SubstraitPB/Export/emit.mlir b/test/Target/SubstraitPB/Export/emit.mlir new file mode 100644 index 000000000000..ff0c28373c1a --- /dev/null +++ b/test/Target/SubstraitPB/Export/emit.mlir @@ -0,0 +1,58 @@ +// RUN: structured-translate -substrait-to-protobuf --split-input-file %s \ +// RUN: | FileCheck %s + +// RUN: structured-translate -substrait-to-protobuf %s \ +// RUN: --split-input-file --output-split-marker="# -----" \ +// RUN: | structured-translate -protobuf-to-substrait \ +// RUN: --split-input-file="# -----" --output-split-marker="// ""-----" \ +// RUN: | structured-translate -substrait-to-protobuf \ +// RUN: --split-input-file --output-split-marker="# -----" \ +// RUN: | FileCheck %s + + +// Checks that the `emit` field of a `named_table` is exported correctly. 
+ +// CHECK-LABEL: relations { +// CHECK-NEXT: rel { +// CHECK-NEXT: read { +// CHECK-NEXT: common { +// CHECK-NEXT: emit { +// CHECK-NEXT: output_mapping: 1 +// CHECK-NEXT: } + +substrait.plan version 0 : 42 : 1 { + relation { + %0 = named_table @t1 as ["a", "b"] : tuple + %1 = emit [1] from %0 : tuple -> tuple + yield %1 : tuple + } +} + +// ----- + +// Checks that the `emit` field of a `filter` is exported correctly. 
+ +// CHECK-LABEL: relations { +// CHECK-NEXT: rel { +// CHECK-NEXT: filter { +// CHECK-NEXT: common { +// CHECK-NEXT: emit { +// CHECK-NEXT: output_mapping: 1 +// CHECK-NEXT: } +// CHECK-LABEL: input { +// CHECK-NEXT: read { +// CHECK-NEXT: common { +// CHECK-NEXT: direct + +substrait.plan version 0 : 42 : 1 { + relation { + %0 = named_table @t1 as ["a", "b"] : tuple + %1 = filter %0 : tuple { + ^bb0(%arg : tuple): + %2 = literal -1 : si1 + yield %2 : si1 + } + %2 = emit [1] from %1 : tuple -> tuple + yield %2 : tuple + } +} diff --git a/test/Target/SubstraitPB/Import/emit.textpb b/test/Target/SubstraitPB/Import/emit.textpb new file mode 100644 index 000000000000..30bf2104d0e0 --- /dev/null +++ b/test/Target/SubstraitPB/Import/emit.textpb @@ -0,0 +1,137 @@ +# RUN: structured-translate -protobuf-to-substrait %s \ +# RUN: --split-input-file="# ""-----" \ +# RUN: | FileCheck %s + +# RUN: structured-translate -protobuf-to-substrait %s \ +# RUN: --split-input-file="# ""-----" --output-split-marker="// -----" \ +# RUN: | structured-translate -substrait-to-protobuf \ +# RUN: --split-input-file --output-split-marker="# ""-----" \ +# RUN: | structured-translate -protobuf-to-substrait \ +# RUN: --split-input-file="# ""-----" --output-split-marker="// -----" \ +# RUN: | FileCheck %s + +# CHECK-LABEL: substrait.plan +# CHECK-NEXT: relation +# CHECK-NEXT: %[[V0:.*]] = named_table +# CHECK-NEXT: %[[V1:.*]] = emit [1, 0] from %[[V0]] +# CHECK-NEXT: yield %[[V1]] + +relations { + rel { + read { + common { + emit { + output_mapping: 1 + output_mapping: 0 + } + } + base_schema { + names: "a" + names: "b" + struct { + types { + bool { + nullability: NULLABILITY_REQUIRED + } + } + types { + i32 { + nullability: NULLABILITY_REQUIRED + } + } + nullability: NULLABILITY_REQUIRED + } + } + named_table { + names: "t1" + } + } + } +} +version { + minor_number: 42 + patch_number: 1 +} + +# ----- + +# CHECK-LABEL: substrait.plan +# CHECK-NEXT: relation +# CHECK-NEXT: %[[V0:.*]] = named_table 
+# CHECK-NEXT: %[[V1:.*]] = emit [0, 0] from %[[V0]] +# CHECK-NEXT: yield %[[V1]] + +relations { + rel { + read { + common { + emit { + output_mapping: 0 + output_mapping: 0 + } + } + base_schema { + names: "a" + struct { + types { + i32 { + nullability: NULLABILITY_REQUIRED + } + } + nullability: NULLABILITY_REQUIRED + } + } + named_table { + names: "t1" + } + } + } +} +version { + minor_number: 42 + patch_number: 1 +} + +# ----- + +# CHECK-LABEL: substrait.plan +# CHECK-NEXT: relation +# CHECK-NEXT: %[[V0:.*]] = named_table +# CHECK-NEXT: %[[V1:.*]] = emit [1] from %[[V0]] +# CHECK-NEXT: yield %[[V1]] + +relations { + rel { + read { + common { + emit { + output_mapping: 1 + } + } + base_schema { + names: "a" + names: "b" + struct { + types { + i32 { + nullability: NULLABILITY_REQUIRED + } + } + types { + bool { + nullability: NULLABILITY_REQUIRED + } + } + nullability: NULLABILITY_REQUIRED + } + } + named_table { + names: "t1" + } + } + } +} +version { + minor_number: 42 + patch_number: 1 +}