Skip to content

Commit

Permalink
[Substrait] Create emit op to represent output mapping.
Browse files Browse the repository at this point in the history
The commit introduces handling for the `emit_kind` field of the
`RelCommon` message, which is a common field of all (?) cases of the
`Rel` message. The current design models this field as a dedicated op
such that the output mapping is only ever present in that op, `emit`.

There are at least two alternatives to this IR design. The first one
consists of making the output mapping part of each of the ops the
represnet `Rel` messages and expose it through the `RelOpInterface`.
However, this would mean that (1) the custom assembly of each op would
have to represent the mapping, which is manual effort and a possible
source for inconsistencies, (2) each op would have to implement type
inference in the presence of a mapping, and (3) most rewrites of all ops
would have to take that mapping into account for their semantics. Having
the mapping in one place makes all of this simpler. The downside is that
what is kept in a single place in the Substrait protobuf format is now
spread across two ops in the MLIR representation. However, I believe
that this is the smaller of two evils and the current import and export
seems to work.

Another alternative would be to combine the two: make the mapping part
of all ops but *also* introduce a dedicated `emit` op. Then, two passes
could move the mapping from one to the other depending on which of the
two representations would be more convenient. However, this would not
get rid of Issues (1) and (2) above and lead to more concepts and code.

Signed-off-by: Ingo Müller <[email protected]>
  • Loading branch information
ingomueller-net committed May 27, 2024
1 parent 6375d94 commit dfd6939
Show file tree
Hide file tree
Showing 9 changed files with 500 additions and 13 deletions.
34 changes: 34 additions & 0 deletions include/structured/Dialect/Substrait/IR/SubstraitOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,40 @@ def Substrait_CrossOp : Substrait_RelOp<"cross", [
}];
}

def Substrait_EmitOp : Substrait_RelOp<"emit", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getDefaultDialect"]>,
DeclareOpInterfaceMethods<InferTypeOpInterface>
]> {
let summary = "Projection (a.k.a. 'emit') as dedicated operation";
let description = [{
Represents the `Emit` message of the `emit_kind` field in the `RelCommon`
message. While projection is inlined into all relations in the protobuf
format, this op separates out this functionality in a dedicated op in order
to simplify rewriting.

Example:

```mlir
%0 = ...
%1 = emit [2, 1] from %0 : tuple<si32, s1, si32> -> tuple<si32, si1>
```
}];
let arguments = (ins
Substrait_Relation:$input,
I64ArrayAttr:$mapping
);
let results = (outs Substrait_Relation:$result);
let assemblyFormat = [{
$mapping `from` $input attr-dict `:` type($input) `->` type($result)
}];
let extraClassDefinition = [{
/// Implement OpAsmOpInterface.
::llvm::StringRef $cppClass::getDefaultDialect() {
return SubstraitDialect::getDialectNamespace();
}
}];
}

def Substrait_FilterOp : Substrait_RelOp<"filter", [
SingleBlockImplicitTerminator<"::mlir::substrait::YieldOp">,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getDefaultDialect"]>,
Expand Down
32 changes: 32 additions & 0 deletions lib/Dialect/Substrait/IR/Substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,38 @@ CrossOp::inferReturnTypes(MLIRContext *context, std::optional<Location> loc,
return success();
}

LogicalResult
EmitOp::inferReturnTypes(MLIRContext *context, std::optional<Location> loc,
ValueRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
llvm::SmallVectorImpl<Type> &inferredReturnTypes) {
auto *typedProperties = properties.as<Properties *>();
if (!loc)
loc = UnknownLoc::get(context);

ArrayAttr mapping = typedProperties->getMapping();
Type inputType = operands[0].getType();
ArrayRef<Type> inputTypes = inputType.cast<TupleType>().getTypes();

// Map input types to output types.
SmallVector<Type> outputTypes;
outputTypes.reserve(mapping.size());
for (auto indexAttr : mapping.getAsRange<IntegerAttr>()) {
int64_t index = indexAttr.getInt();
if (index < 0 || index >= static_cast<int64_t>(inputTypes.size()))
return ::emitError(loc.value())
<< index << " is not a valid index into " << inputType;
Type mappedType = inputTypes[index];
outputTypes.push_back(mappedType);
}

// Create final tuple type.
auto outputType = TupleType::get(context, outputTypes);
inferredReturnTypes.push_back(outputType);

return success();
}

/// Computes the type of the nested field of the given `type` identified by
/// `position`. Each entry `n` in the given index array `position` corresponds
/// to the `n`-th entry in that level. The function is thus implemented
Expand Down
101 changes: 92 additions & 9 deletions lib/Target/SubstraitPB/Export.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ namespace {
static FailureOr<std::unique_ptr<MESSAGE_TYPE>> exportOperation(OP_TYPE op);

DECLARE_EXPORT_FUNC(CrossOp, Rel)
DECLARE_EXPORT_FUNC(EmitOp, Rel)
DECLARE_EXPORT_FUNC(ExpressionOpInterface, Expression)
DECLARE_EXPORT_FUNC(FieldReferenceOp, Expression)
DECLARE_EXPORT_FUNC(FilterOp, Rel)
Expand All @@ -51,6 +52,65 @@ DECLARE_EXPORT_FUNC(RelOpInterface, Rel)
FailureOr<std::unique_ptr<pb::Message>> exportOperation(Operation *op);
FailureOr<std::unique_ptr<Rel>> exportOperation(RelOpInterface op);

/// Creates the `RelCommon` message with the `emit_kind` field for the given
/// op.
///
/// **This function has to be called during the export of every `Rel` case
/// that has a `RelCommon` message.**
///
/// If the result produced by the gien op is an `EmitOp`, then the returned
/// `RelCommon` message contains an `Emit` message that represents the
/// output mapping of the `EmitOp`. Otherwise, the returned `RelCommon`
/// message contains a `Direct` message.
///
/// If there is more than one `EmitOp` user or some `EmitOp` users and some
/// other users, then an error is returned because these cases can't be
/// expressed by a single `Emit` message. Some corner cases where export
/// might still be possible are cases with multiple `EmitOp`s that are all
/// identical and a subset of `EmitOp` users all with identity mappings. All
/// of these should go away through canonicalization and/or CSE.
FailureOr<std::unique_ptr<RelCommon>>
createRelComminWithEmit(RelOpInterface op) {
auto relCommon = std::make_unique<RelCommon>();

Value result = op->getResult(0);

// Collect all `EmitOp`s that use the result of this operation.
SmallVector<EmitOp> emitOps;
size_t numUsers = 0;
for (Operation *user : result.getUsers()) {
numUsers++;
if (isa<EmitOp>(user))
emitOps.push_back(dyn_cast<EmitOp>(user));
}

// If we don't have an `EmitOp` user, then `op` has direct emit behavior.
if (emitOps.empty()) {
auto direct = std::make_unique<RelCommon::Direct>();
relCommon->set_allocated_direct(direct.release());
return relCommon;
}

// If we have more that one `EmitOp` user or fewer `EmitOp` users than total
// users, then some mappings are different from the others, which can't be
// expressed by a single `Emit` message.
if (emitOps.size() > 1 || emitOps.size() != numUsers) {
return op->emitOpError("is consumed by different emit ops (try running "
"canonicalization and/or CSE)");
}

// Normal case: we have exactly one `EmitOp` user.
EmitOp emitOp = emitOps.front();

// Build the `Emit` message.
auto emit = std::make_unique<RelCommon::Emit>();
for (auto intAttr : emitOp.getMapping().getAsRange<IntegerAttr>())
emit->add_output_mapping(intAttr.getInt());

relCommon->set_allocated_emit(emit.release());
return relCommon;
}

FailureOr<std::unique_ptr<proto::Type>> exportType(Location loc,
mlir::Type mlirType) {
MLIRContext *context = mlirType.getContext();
Expand Down Expand Up @@ -141,6 +201,25 @@ FailureOr<std::unique_ptr<Rel>> exportOperation(CrossOp op) {
return rel;
}

/// We just forward to the overload for `RelOpInterface`, which will have to
/// export this op. We can't (easily) do it here because the emit op is
/// represented as part of the `RelCommon` message of one of the cases of the
/// `Rel` message but there is no generic way to access the `common` field of
/// the various cases.
FailureOr<std::unique_ptr<Rel>> exportOperation(EmitOp op) {
auto inputOp =
dyn_cast_if_present<RelOpInterface>(op.getInput().getDefiningOp());
if (!inputOp)
return op->emitOpError("input was not produced by Substrait relation op");

if (dyn_cast<EmitOp>(inputOp.getOperation()))
return op->emitOpError(
"with input produced by 'substrait.emit' op not supported for export "
"(try running canonicalization)");

return exportOperation(inputOp);
}

FailureOr<std::unique_ptr<Expression>>
exportOperation(ExpressionOpInterface op) {
return llvm::TypeSwitch<Operation *, FailureOr<std::unique_ptr<Expression>>>(
Expand Down Expand Up @@ -208,10 +287,12 @@ FailureOr<std::unique_ptr<Expression>> exportOperation(FieldReferenceOp op) {
}

FailureOr<std::unique_ptr<Rel>> exportOperation(FilterOp op) {
// Build `RelCommon` message.
auto relCommon = std::make_unique<RelCommon>();
auto direct = std::make_unique<RelCommon::Direct>();
relCommon->set_allocated_direct(direct.release());
// Build `RelCommon` message with emit mapping.
FailureOr<std::unique_ptr<RelCommon>> maybeRelCommon =
createRelComminWithEmit(op);
if (failed(maybeRelCommon))
return failure();
auto relCommon = std::move(maybeRelCommon.value());

// Build input `Rel` message.
auto inputOp =
Expand Down Expand Up @@ -310,10 +391,12 @@ FailureOr<std::unique_ptr<Rel>> exportOperation(NamedTableOp op) {
namedTable->add_names(attr.getLeafReference().str());
}

// Build `RelCommon` message.
auto relCommon = std::make_unique<RelCommon>();
auto direct = std::make_unique<RelCommon::Direct>();
relCommon->set_allocated_direct(direct.release());
// Build `RelCommon` message with emit mapping.
FailureOr<std::unique_ptr<RelCommon>> maybeRelCommon =
createRelComminWithEmit(op);
if (failed(maybeRelCommon))
return failure();
auto relCommon = std::move(maybeRelCommon.value());

// Build `Struct` message.
auto struct_ = std::make_unique<proto::Type::Struct>();
Expand Down Expand Up @@ -392,7 +475,7 @@ FailureOr<std::unique_ptr<Plan>> exportOperation(PlanOp op) {

FailureOr<std::unique_ptr<Rel>> exportOperation(RelOpInterface op) {
return llvm::TypeSwitch<Operation *, FailureOr<std::unique_ptr<Rel>>>(op)
.Case<CrossOp, FieldReferenceOp, FilterOp, NamedTableOp>(
.Case<CrossOp, EmitOp, FieldReferenceOp, FilterOp, NamedTableOp>(
[&](auto op) { return exportOperation(op); })
.Default([](auto op) {
op->emitOpError("not supported for export");
Expand Down
38 changes: 34 additions & 4 deletions lib/Target/SubstraitPB/Import.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,42 @@ DECLARE_IMPORT_FUNC(Expression, Expression, ExpressionOpInterface)
DECLARE_IMPORT_FUNC(FieldReference, Expression::FieldReference,
FieldReferenceOp)
DECLARE_IMPORT_FUNC(Literal, Expression::Literal, LiteralOp)
DECLARE_IMPORT_FUNC(NamedTable, Rel, NamedTableOp)
DECLARE_IMPORT_FUNC(NamedTable, Rel, RelOpInterface)
DECLARE_IMPORT_FUNC(Plan, Plan, PlanOp)
DECLARE_IMPORT_FUNC(PlanRel, PlanRel, PlanRelOp)
DECLARE_IMPORT_FUNC(ReadRel, Rel, RelOpInterface)
DECLARE_IMPORT_FUNC(Rel, Rel, RelOpInterface)

/// Imports the provided `RelCommon` message by producing an `EmitOp` that
/// expresses the `Emit` message if it exists.
///
/// **This function must be called at the end of the import function of every
/// `Rel` message with a `RelCommon` message.**
///
/// The provided `inputOp` is the op that was imported by the `Rel` message
/// containing the provided `RelCommon` message. The function returns the
/// `EmitOp` if one was created and the `inputOp` otherwise.
static mlir::FailureOr<RelOpInterface>
importMaybeEmit(ImplicitLocOpBuilder builder, const RelCommon &message,
RelOpInterface inputOp) {
// For the `direct`, we just forward the input op.
if (message.has_direct())
return inputOp;
assert(message.has_emit() && "expected either 'direct' or 'emit'");

// For the `emit` case, we need to insert an `EmitOp`.
const proto::RelCommon::Emit &emit = message.emit();
SmallVector<int64_t> mapping;
for (int64_t index : emit.output_mapping())
mapping.push_back(index);
ArrayAttr mappingAttr = builder.getI64ArrayAttr(mapping);

Value input = inputOp->getResult(0);
auto emitOp = builder.create<EmitOp>(input, mappingAttr);

return cast<RelOpInterface>(emitOp.getOperation());
}

static mlir::FailureOr<mlir::Type> importType(MLIRContext *context,
const proto::Type &type) {

Expand Down Expand Up @@ -246,10 +276,10 @@ static mlir::FailureOr<FilterOp> importFilterRel(ImplicitLocOpBuilder builder,
builder.create<YieldOp>(conditionOp.value()->getResult(0));
}

return filterOp;
return importMaybeEmit(builder, filterRel.common(), filterOp);
}

static mlir::FailureOr<NamedTableOp>
static mlir::FailureOr<RelOpInterface>
importNamedTable(ImplicitLocOpBuilder builder, const Rel &message) {
const ReadRel &readRel = message.read();
const ReadRel::NamedTable &namedTable = readRel.named_table();
Expand Down Expand Up @@ -294,7 +324,7 @@ importNamedTable(ImplicitLocOpBuilder builder, const Rel &message) {
auto namedTableOp =
builder.create<NamedTableOp>(resultType, tableName, fieldNamesAttr);

return namedTableOp;
return importMaybeEmit(builder, readRel.common(), namedTableOp);
}

static FailureOr<PlanOp> importPlan(ImplicitLocOpBuilder builder,
Expand Down
23 changes: 23 additions & 0 deletions test/Dialect/Substrait/emit-invalid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// RUN: structured-opt -verify-diagnostics -split-input-file %s

substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a"] : tuple<si32>
// expected-error@+2 {{'substrait.emit' op failed to infer returned types}}
// expected-error@+1 {{1 is not a valid index into 'tuple<si32>'}}
%1 = emit [1] from %0 : tuple<si32> -> tuple<si32>
yield %1 : tuple<si32>
}
}

// -----

substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a"] : tuple<si32>
// expected-error@+2 {{'substrait.emit' op failed to infer returned types}}
// expected-error@+1 {{-1 is not a valid index into 'tuple<si32>'}}
%1 = emit [-1] from %0 : tuple<si32> -> tuple<si32>
yield %1 : tuple<si32>
}
}
48 changes: 48 additions & 0 deletions test/Dialect/Substrait/emit.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// RUN: structured-opt -split-input-file %s \
// RUN: | FileCheck %s

// CHECK-LABEL: substrait.plan
// CHECK-NEXT: relation
// CHECK-NEXT: %[[V0:.*]] = named_table
// CHECK-NEXT: %[[V1:.*]] = emit [1, 0] from %[[V0]] :
// CHECK-SAME: tuple<si1, si32> -> tuple<si32, si1>

substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a", "b"] : tuple<si1, si32>
%1 = emit [1, 0] from %0 : tuple<si1, si32> -> tuple<si32, si1>
yield %1 : tuple<si32, si1>
}
}

// -----

// CHECK-LABEL: substrait.plan
// CHECK-NEXT: relation
// CHECK-NEXT: %[[V0:.*]] = named_table
// CHECK-NEXT: %[[V1:.*]] = emit [0, 0] from %[[V0]] :
// CHECK-SAME: tuple<si32> -> tuple<si32, si32>

substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a"] : tuple<si32>
%1 = emit [0, 0] from %0 : tuple<si32> -> tuple<si32, si32>
yield %1 : tuple<si32, si32>
}
}

// -----

// CHECK-LABEL: substrait.plan
// CHECK-NEXT: relation
// CHECK-NEXT: %[[V0:.*]] = named_table
// CHECK-NEXT: %[[V1:.*]] = emit [1] from %[[V0]] :
// CHECK-SAME: tuple<si32, si1> -> tuple<si1>

substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a", "b"] : tuple<si32, si1>
%1 = emit [1] from %0 : tuple<si32, si1> -> tuple<si1>
yield %1 : tuple<si1>
}
}
Loading

0 comments on commit dfd6939

Please sign in to comment.