Skip to content
Permalink

Comparing changes

This is a direct comparison between two commits made in this repository or its related repositories. View the default comparison for this range or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also . Learn more about diff comparisons here.
base repository: iree-org/iree-llvm-sandbox
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: e38524bd058cb73ce0d3e408509d09a040be92bd
Choose a base ref
..
head repository: iree-org/iree-llvm-sandbox
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: fca969e3ee447b2464868a4628d9bf247c534c54
Choose a head ref
Original file line number Diff line number Diff line change
@@ -13,6 +13,7 @@ include "iterators/Dialect/Iterators/IR/IteratorsDialect.td"
include "iterators/Dialect/Iterators/IR/IteratorsTypes.td"
include "iterators/Dialect/Tabular/IR/TabularTypes.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/OpBase.td"

class Iterators_Base_Op<string mnemonic, list<Trait> traits = []> :
@@ -119,7 +120,8 @@ def Iterators_AccumulateOp
AllMatch<["getAccumulateFunc().getArgumentTypes()[1]",
"$input.getType().dyn_cast<StreamType>().getElementType()"],
"the type of the second argument of the accumulate function must "
"match the input element type">]> {
"match the input element type">,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
let summary = "Accumulate the elements of a stream into one element";
let description = [{
Accumulate the elements of the input stream into a single element, i.e.,
@@ -152,8 +154,8 @@ def Iterators_AccumulateOp
);
let results = (outs Iterators_Stream:$result);
let assemblyFormat = [{
`(` $input `,` $initFuncRef `,` $accumulateFuncRef `)`
attr-dict `:` `(` type($input) `)` `->` type($result)
`(` $input `,` $initFuncRef `,` $accumulateFuncRef `)` attr-dict `:`
`(` qualified(type($input)) `)` `->` qualified(type($result))
}];
let extraClassDeclaration = [{
/// Return the init function op that the initFuncRef refers to.
@@ -168,6 +170,13 @@ def Iterators_AccumulateOp
*this, accumulateFuncRefAttr());
}
}];
let extraClassDefinition = [{
/// Implement OpAsmOpInterface.
void $cppClass::getAsmResultNames(
llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
setNameFn(getResult(), "accumulated");
}
}];
}

/// Verifies that the element types of nested arrays in the $value array
@@ -410,6 +419,52 @@ def Iterators_SinkOp : Iterators_Base_Op<"sink"> {
let arguments = (ins Iterators_StreamOfLLVMStructOfNumerics:$input);
}

def Iterators_StreamToValueOp : Iterators_Base_Op<"stream_to_value",
[TypesMatchWith<"element type of input stream must match result type",
"input", "result",
"$_self.cast<StreamType>().getElementType()">]> {
let summary = "Produce a value from the first element of the input stream";
let description = [{
Consumes the first element of the given input stream and returns it. The
remaining elements from the input stream are not consumed.

The purpose of this op is to pass from "stream land" to "value land," i.e.,
to enable non-iterator ops to consume the results of a tree of iterator ops.

Example:
```mlir
%stream = ...
%value = iterators.stream_to_value %stream : !iterators.stream<i32>
```
}];
let arguments = (ins Iterators_Stream:$input);
let results = (outs AnyType:$result, I1:$hasResult);
let assemblyFormat = "$input attr-dict `:` qualified(type($input))";
}

def Iterators_ValueToStreamOp : Iterators_Op<"value_to_stream",
[TypesMatchWith<"element type of result stream must match input type",
"result", "input",
"$_self.cast<StreamType>().getElementType()">]> {
let summary = "Produce a stream with the given value as a single element";
let description = [{
Produces a stream consisting of a single element, namely the value given as
input.

The purpose of this op is to pass from "value land" to "stream land," i.e.,
to enable iterator ops to consume arbitrary values as a (singleton) stream.

Example:
```mlir
%value = arith.constant 42 : i32
%stream = iterators.value_to_stream %value : !iterators.stream<i32>
```
}];
let arguments = (ins AnyType:$input);
let results = (outs Iterators_Stream:$result);
let assemblyFormat = "$input attr-dict `:` qualified(type($result))";
}

//===----------------------------------------------------------------------===//
// Ops related to Iterator bodies.
//===----------------------------------------------------------------------===//
Original file line number Diff line number Diff line change
@@ -127,6 +127,19 @@ StateType StateTypeComputer::operator()(
return StateType::get(context, {indexType, viewType});
}

/// The state of ValueToStreamOp consists a Boolean indicating whether it has
/// already returned its value (which is initialized to false and set to true in
/// the first call to next) and the value it converts to a stream.
template <>
StateType StateTypeComputer::operator()(
ValueToStreamOp op, llvm::SmallVector<StateType> /*upstreamStateTypes*/) {
MLIRContext *context = op->getContext();
Type hasReturned = IntegerType::get(context, /*width=*/1);
Type valueType =
op->getResult(0).getType().cast<StreamType>().getElementType();
return StateType::get(context, {hasReturned, valueType});
}

/// Build IteratorInfo, assigning new unique names as needed. Takes the
/// `StateType` as a parameter, to ensure proper build order (all uses are
/// visited before any def).
@@ -174,7 +187,8 @@ mlir::iterators::IteratorAnalysis::IteratorAnalysis(
FilterOp,
MapOp,
ReduceOp,
TabularViewToStreamOp
TabularViewToStreamOp,
ValueToStreamOp
// clang-format on
>([&](auto op) {
llvm::SmallVector<StateType> upstreamStateTypes;
Original file line number Diff line number Diff line change
@@ -157,7 +157,7 @@ struct ConstantTupleLowering : public OpConversionPattern<ConstantTupleOp> {
// Create constant value op.
Attribute field = values[i];
Type fieldType = field.getType();
auto valueOp = rewriter.create<LLVM::ConstantOp>(loc, fieldType, field);
auto valueOp = rewriter.create<arith::ConstantOp>(loc, fieldType, field);

// Insert into struct.
structValue =
@@ -515,7 +515,7 @@ static Value buildOpenBody(ConstantStreamOp op, OpBuilder &builder,
// Insert constant zero into state.
Type i32 = b.getI32Type();
Attribute zeroAttr = b.getI32IntegerAttr(0);
Value zeroValue = b.create<LLVM::ConstantOp>(i32, zeroAttr);
Value zeroValue = b.create<arith::ConstantOp>(i32, zeroAttr);
Value updatedState = b.create<iterators::InsertValueOp>(
initialState, b.getIndexAttr(0), zeroValue);

@@ -1261,7 +1261,7 @@ static Value buildOpenBody(TabularViewToStreamOp op, OpBuilder &builder,
// Insert constant zero into state.
Type i64 = b.getI64Type();
Attribute zeroAttr = b.getI64IntegerAttr(0);
Value zeroValue = b.create<LLVM::ConstantOp>(i64, zeroAttr);
Value zeroValue = b.create<arith::ConstantOp>(i64, zeroAttr);
return b.create<iterators::InsertValueOp>(initialState, b.getIndexAttr(0),
zeroValue);
}
@@ -1405,6 +1405,92 @@ static Value buildStateCreation(TabularViewToStreamOp op,
input);
}

//===----------------------------------------------------------------------===//
// ValueToStreamOp.
//===----------------------------------------------------------------------===//

/// Builds IR that sets `hasReturned` to false. Possible output:
///
/// %3 = iterators.insertvalue %false into %arg0[1] : !iterators.state<i1, i32>
static Value buildOpenBody(ValueToStreamOp op, OpBuilder &builder,
Value initialState,
ArrayRef<IteratorInfo> /*upstreamInfos*/) {
Location loc = op.getLoc();
ImplicitLocOpBuilder b(loc, builder);

// Reset hasReturned to false.
Value constFalse = b.create<arith::ConstantIntOp>(/*value=*/0, /*width=*/1);
Value updatedState = b.create<iterators::InsertValueOp>(
initialState, b.getIndexAttr(0), constFalse);

return updatedState;
}

/// Builds IR that returns the value in the first call and end-of-stream
/// otherwise. Pseudo-code:
///
/// if hasReturned: return {}
/// return value
///
/// Possible output:
///
/// %0 = iterators.extractvalue %arg0[0] : !iterators.state<i1, i32>
/// %true = arith.constant true
/// %1 = arith.xori %true, %0 : i1
/// %2 = iterators.extractvalue %arg0[1] : !iterators.state<i1, i32>
/// %3 = iterators.insertvalue %true into %arg0[0] : !iterators.state<i1, i32>
static llvm::SmallVector<Value, 4>
buildNextBody(ValueToStreamOp op, OpBuilder &builder, Value initialState,
ArrayRef<IteratorInfo> upstreamInfos, Type elementType) {
Location loc = op.getLoc();
ImplicitLocOpBuilder b(loc, builder);
Type i1 = b.getI1Type();

// Check if the iterator has returned an element already (since it should
// return one only in the first call to next).
Value hasReturned =
b.create<iterators::ExtractValueOp>(i1, initialState, b.getIndexAttr(0));

// Compute hasNext: we have an element iff we have not returned before, i.e.,
// iff "not hasReturend". We simulate "not" with "xor true".
Value constTrue = b.create<arith::ConstantIntOp>(/*value=*/1, /*width=*/1);
Value hasNext = b.create<arith::XOrIOp>(constTrue, hasReturned);

// Extract value as next element.
Value nextElement = b.create<iterators::ExtractValueOp>(
elementType, initialState, b.getIndexAttr(1));

// Update state.
Value finalState = b.create<iterators::InsertValueOp>(
initialState, b.getIndexAttr(0), constTrue);

return {finalState, hasNext, nextElement};
}

/// Forwards the initial state. The ValueToStreamOp doesn't do anything on
/// Close.
static Value buildCloseBody(ValueToStreamOp /*op*/, OpBuilder & /*builder*/,
Value initialState,
ArrayRef<IteratorInfo> /*upstreamInfos*/) {
return initialState;
}

/// Builds IR that initializes the iterator state with value. Possible output:
///
/// %0 = ...
/// %1 = iterators.undefstate : !iterators.state<i1, i32>
/// %2 = iterators.insertvalue %0 into %1[1] : !iterators.state<i1, i32>
static Value buildStateCreation(ValueToStreamOp op,
ValueToStreamOp::Adaptor adaptor,
OpBuilder &builder, StateType stateType) {
Location loc = op.getLoc();
ImplicitLocOpBuilder b(loc, builder);
Value undefState = b.create<UndefStateOp>(loc, stateType);
Value value = adaptor.input();
return b.create<iterators::InsertValueOp>(undefState, b.getIndexAttr(1),
value);
}

//===----------------------------------------------------------------------===//
// Helpers for creating Open/Next/Close functions and state creation.
//===----------------------------------------------------------------------===//
@@ -1464,7 +1550,8 @@ static Value buildOpenBody(Operation *op, OpBuilder &builder,
FilterOp,
MapOp,
ReduceOp,
TabularViewToStreamOp
TabularViewToStreamOp,
ValueToStreamOp
// clang-format on
>([&](auto op) {
return buildOpenBody(op, builder, initialState, upstreamInfos);
@@ -1483,7 +1570,8 @@ buildNextBody(Operation *op, OpBuilder &builder, Value initialState,
FilterOp,
MapOp,
ReduceOp,
TabularViewToStreamOp
TabularViewToStreamOp,
ValueToStreamOp
// clang-format on
>([&](auto op) {
return buildNextBody(op, builder, initialState, upstreamInfos,
@@ -1503,7 +1591,8 @@ static Value buildCloseBody(Operation *op, OpBuilder &builder,
FilterOp,
MapOp,
ReduceOp,
TabularViewToStreamOp
TabularViewToStreamOp,
ValueToStreamOp
// clang-format on
>([&](auto op) {
return buildCloseBody(op, builder, initialState, upstreamInfos);
@@ -1521,7 +1610,8 @@ static Value buildStateCreation(IteratorOpInterface op, OpBuilder &builder,
FilterOp,
MapOp,
ReduceOp,
TabularViewToStreamOp
TabularViewToStreamOp,
ValueToStreamOp
// clang-format on
>([&](auto op) {
using OpAdaptor = typename decltype(op)::Adaptor;
@@ -1708,6 +1798,61 @@ static SmallVector<Value> convert(SinkOp op, SinkOpAdaptor adaptor,
return {};
}

/// Converts the given StreamToValueOp to LLVM using the converted operands.
/// This consists of opening the input iterator, consuming one element (which is
/// the result of this op), and closing it again. Pseudo code:
///
/// upstream->Open()
/// value = upstream->Next()
/// upstream->Close()
///
/// Possible result:
///
/// %0 = ...
/// %1 = call @iterators.upstream.open.0(%0) : (!nested_state) -> !nested_state
/// %2:3 = call @iterators.upstream.next.0(%1) :
/// (!nested_state) -> (!nested_state, i1, !element_type)
/// %3 = call @iterators.upstream.close.0(%2#0) :
/// (!nested_state) -> !nested_state
static SmallVector<Value> convert(StreamToValueOp op,
StreamToValueOpAdaptor adaptor,
ArrayRef<IteratorInfo> upstreamInfos,
OpBuilder &rewriter) {
Location loc = op->getLoc();
ImplicitLocOpBuilder b(loc, rewriter);

// Look up IteratorInfo about the upstream iterator.
IteratorInfo upstreamInfo = upstreamInfos[0];

Type stateType = upstreamInfo.stateType;
SymbolRefAttr openFunc = upstreamInfo.openFunc;
SymbolRefAttr nextFunc = upstreamInfo.nextFunc;
SymbolRefAttr closeFunc = upstreamInfo.closeFunc;

// Open upstream iterator. ---------------------------------------------------
Value initialState = adaptor.input();
auto openCallOp = b.create<func::CallOp>(openFunc, stateType, initialState);
Value openedUpstreamState = openCallOp->getResult(0);

// Consume one element from upstream iterator --------------------------------
// Input and return types.
auto elementType = op.input().getType().cast<StreamType>().getElementType();
Type i1 = b.getI1Type();
SmallVector<Type> nextResultTypes = {stateType, i1, elementType};

func::CallOp nextCallOp =
b.create<func::CallOp>(nextFunc, nextResultTypes, openedUpstreamState);

Value consumedUpstreamState = nextCallOp->getResult(0);
Value hasValue = nextCallOp->getResult(1);
Value value = nextCallOp->getResult(2);

// Close upstream iterator. --------------------------------------------------
b.create<func::CallOp>(closeFunc, stateType, consumedUpstreamState);

return {value, hasValue};
}

/// Converts the given op to LLVM using the converted operands from the upstream
/// iterator. This function is essentially a switch between conversion functions
/// for sink and non-sink iterator ops.
@@ -1740,7 +1885,7 @@ convertIteratorOp(Operation *op, ValueRange operands, OpBuilder &builder,
return SmallVector<Value>{
convert(op, operands, opInfo, upstreamInfos, builder)};
})
.Case<SinkOp>([&](auto op) {
.Case<SinkOp, StreamToValueOp>([&](auto op) {
using OpAdaptor = typename decltype(op)::Adaptor;
OpAdaptor adaptor(operands, op->getAttrDictionary());
return convert(op, adaptor, upstreamInfos, builder);
@@ -1830,8 +1975,9 @@ static void convertIteratorOps(ModuleOp module, TypeConverter &typeConverter) {
// to the worklist *after* all of its upstream iterators.
SmallVector<Operation *, 16> workList;
module->walk<WalkOrder::PreOrder>([&](Operation *op) {
TypeSwitch<Operation *, void>(op).Case<IteratorOpInterface, SinkOp>(
[&](Operation *op) { workList.push_back(op); });
TypeSwitch<Operation *, void>(op)
.Case<IteratorOpInterface, SinkOp, StreamToValueOp>(
[&](Operation *op) { workList.push_back(op); });
});

// Convert iterator ops in worklist order.
@@ -1857,6 +2003,8 @@ static void convertIteratorOps(ModuleOp module, TypeConverter &typeConverter) {
.create<UnrealizedConversionCastOp>(
loc, convertedType, operand)
.getResult(0);
} else {
mappedOperand = operand;
}
}

@@ -1868,11 +2016,20 @@ static void convertIteratorOps(ModuleOp module, TypeConverter &typeConverter) {
convertIteratorOp(op, mappedOperands, rewriter, analysis);
TypeSwitch<Operation *>(op)
.Case<IteratorOpInterface>([&](auto op) {
// Iterator op: remember result for conversion of later ops.
assert(converted.size() == 1 &&
"Expected iterator op to be converted to one value.");
mapping.map(op->getResult(0), converted[0]);
})
.Case<StreamToValueOp>([&](auto op) {
// Special case: uses will not be converted, so replace them.
assert(converted.size() == 2 &&
"Expected StreamToValueOp to be converted to two values.");
op->getResult(0).replaceAllUsesWith(converted[0]);
op->getResult(1).replaceAllUsesWith(converted[1]);
})
.Case<SinkOp>([&](auto op) {
// Special case: no result, nothing to do.
assert(converted.empty() &&
"Expected sink op to be converted to no value.");
});
Loading