Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement lowering for AccumulateOp.
Browse files Browse the repository at this point in the history
ingomueller-net committed Apr 13, 2023

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent d323284 commit 8466965
Showing 4 changed files with 410 additions and 0 deletions.
18 changes: 18 additions & 0 deletions lib/Conversion/IteratorsToLLVM/IteratorAnalysis.cpp
Original file line number Diff line number Diff line change
@@ -57,6 +57,23 @@ class StateTypeComputer {
TypeConverter typeConverter;
};

/// The state of AccumulateOp consists of the state of its upstream iterator,
/// i.e., the state of the iterator that produces its input stream, the initial
/// value of the accumulator, and a Boolean indicating whether the iterator has
/// returned a result already (which is initialized to false and set to true in
/// the first call to next in order to ensure that only a single result is
/// returned).
template <>
StateType
StateTypeComputer::operator()(AccumulateOp op,
llvm::SmallVector<StateType> upstreamStateTypes) {
MLIRContext *context = op->getContext();
Type hasReturned = IntegerType::get(context, /*width=*/1);
Type initValType = op.getInitVal().getType();
return StateType::get(context,
{upstreamStateTypes[0], initValType, hasReturned});
}

/// The state of ConstantStreamOp consists of a single number that corresponds
/// to the index of the next struct returned by the iterator.
template <>
@@ -180,6 +197,7 @@ mlir::iterators::IteratorAnalysis::IteratorAnalysis(
// TODO: Verify that operands do not come from bbArgs.
.Case<
// clang-format off
AccumulateOp,
ConstantStreamOp,
FilterOp,
MapOp,
242 changes: 242 additions & 0 deletions lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp
Original file line number Diff line number Diff line change
@@ -308,6 +308,244 @@ struct PrintOpLowering : public OpConversionPattern<PrintOp> {
}
};

//===----------------------------------------------------------------------===//
// AccumulateOp.
//===----------------------------------------------------------------------===//

/// Builds IR that opens the nested upstream iterator and sets `hasReturned` to
/// false. Possible output:
///
/// %0 = iterators.extractvalue %arg0[0] :
/// <!upstream_state, i1> -> !upstream_state
/// %1 = call @iterators.upstream.open.0(%0) :
/// (!upstream_state) -> !upstream_state
/// %2 = iterators.insertvalue %arg0[0] (%1 : !upstream_state) :
/// <!upstream_state, i1>
/// %false = arith.constant false
/// %3 = iterators.insertvalue %false into %2[1] :
/// !iterators.state<!upstream_state, i1>
static Value buildOpenBody(AccumulateOp op, OpBuilder &builder,
Value initialState,
ArrayRef<IteratorInfo> upstreamInfos) {
Location loc = op.getLoc();
ImplicitLocOpBuilder b(loc, builder);

Type upstreamStateType = upstreamInfos[0].stateType;

// Extract upstream state.
Value initialUpstreamState = b.create<iterators::ExtractValueOp>(
upstreamStateType, initialState, b.getIndexAttr(0));

// Call Open on upstream.
SymbolRefAttr openFunc = upstreamInfos[0].openFunc;
auto openCallOp =
b.create<func::CallOp>(openFunc, upstreamStateType, initialUpstreamState);

// Update upstream state.
Value updatedUpstreamState = openCallOp->getResult(0);
Value updatedState = b.create<iterators::InsertValueOp>(
initialState, b.getIndexAttr(0), updatedUpstreamState);

// Reset hasReturned to false.
Value constFalse = b.create<arith::ConstantIntOp>(/*value=*/0, /*width=*/1);
updatedState = b.create<iterators::InsertValueOp>(
updatedState, b.getIndexAttr(2), constFalse);

return updatedState;
}

/// Builds IR that consumes all elements of the upstream iterator and combines
/// them into a single one using the given accumulate function. Pseudo-code:
///
/// if hasReturned: return {}
/// hasReturned = True
/// accumulator = initVal
/// while (next = upstream->Next()):
/// accumulator = accumulate(accumulator, next)
/// return accumulator
///
/// Possible output:
///
/// %upstream_state = iterators.extractvalue %arg0[0] : !state_type
/// %init_val = iterators.extractvalue %arg0[1] : !state_type
/// %has_returned = iterators.extractvalue %arg0[2] : !state_type
/// %2:2 = scf.if %2 -> (!upstream_state, !element_type) {
/// scf.yield %upstream_state, %init_val : !upstream_state, !element_type
/// } else {
/// %5:3 = scf.while (%arg1 = %upsteram_state, %arg2 = %init_val) :
/// (!upstream_state, !element_type) ->
/// (!upstream_state, !element_type, !element_type) {
/// %6:3 = func.call @iterators.upstream.next.0(%arg1) :
/// (!upstream_state) -> (!upstream_state, i1, !element_type)
/// scf.condition(%6#1) %8#0, %arg2, %8#2 :
/// !upstream_state, !element_type, !element_type
//// } do {
/// ^bb0(%arg1: !upstream_state, %arg2: !element_type, %arg3: !element_type):
/// %8 = func.call @accumulate_func(%arg2, %arg3) :
/// (!element_type, !element_type) -> !element_type
/// scf.yield %arg1, %8 : !upstream_state, !element_type
/// }
/// scf.yield %7#0, %7#1 : !upstream_state, !element_type
/// }
/// %true = arith.constant true
/// %4 = arith.xori %true, %1 : i1
/// %state_0 = iterators.insertvalue %3#0 into %arg0[0] : !state_type
/// %state_1 = iterators.insertvalue %true into %state_0[1] : !state_type
static llvm::SmallVector<Value, 4>
buildNextBody(AccumulateOp op, OpBuilder &builder, Value initialState,
ArrayRef<IteratorInfo> upstreamInfos, Type elementType) {
Location loc = op.getLoc();
ImplicitLocOpBuilder b(loc, builder);
Type i1 = b.getI1Type();

// Extract input element type.
StreamType inputStreamType = op.getInput().getType().cast<StreamType>();
Type inputElementType = inputStreamType.getElementType();

// Extract upstream state and init value.
Type upstreamStateType = upstreamInfos[0].stateType;
Value initialUpstreamState = b.create<iterators::ExtractValueOp>(
upstreamStateType, initialState, b.getIndexAttr(0));
Value initValue = b.create<iterators::ExtractValueOp>(
elementType, initialState, b.getIndexAttr(1));

// Check if the iterator has returned an element already (since it should
// return one only in the first call to next).
Value hasReturned =
b.create<iterators::ExtractValueOp>(i1, initialState, b.getIndexAttr(2));
SmallVector<Type> ifReturnTypes{upstreamStateType, elementType};
auto ifOp = b.create<scf::IfOp>(
hasReturned,
/*thenBuilder=*/
[&](OpBuilder &builder, Location loc) {
ImplicitLocOpBuilder b(loc, builder);

// Don't modify state; return init value.
b.create<scf::YieldOp>(ValueRange{initialUpstreamState, initValue});
},
/*elseBuilder=*/
[&](OpBuilder &builder, Location loc) {
ImplicitLocOpBuilder b(loc, builder);

// Create while loop using init value as initial accumulator.
SmallVector<Value> whileInputs = {initialUpstreamState, initValue};
SmallVector<Type> whileResultTypes = {
upstreamStateType, // Updated upstream state.
elementType, // Accumulator.
inputElementType // Element from last next call.
};
scf::WhileOp whileOp = b.create<scf::WhileOp>(
whileResultTypes, whileInputs,
/*beforeBuilder=*/
[&](OpBuilder &builder, Location loc, ValueRange args) {
ImplicitLocOpBuilder b(loc, builder);

Value upstreamState = args[0];
Value accumulator = args[1];

// Call next function.
SmallVector<Type> nextResultTypes = {upstreamStateType, i1,
inputElementType};
SymbolRefAttr nextFunc = upstreamInfos[0].nextFunc;
auto nextCall = b.create<func::CallOp>(nextFunc, nextResultTypes,
upstreamState);

Value updatedUpstreamState = nextCall->getResult(0);
Value hasNext = nextCall->getResult(1);
Value maybeNextElement = nextCall->getResult(2);
b.create<scf::ConditionOp>(
hasNext, ValueRange{updatedUpstreamState, accumulator,
maybeNextElement});
},
/*afterBuilder=*/
[&](OpBuilder &builder, Location loc, ValueRange args) {
ImplicitLocOpBuilder b(loc, builder);

Value upstreamState = args[0];
Value accumulator = args[1];
Value nextElement = args[2];

// Call accumulate function.
auto accumulateCall =
b.create<func::CallOp>(elementType, op.getAccumulateFuncRef(),
ValueRange{accumulator, nextElement});
Value newAccumulator = accumulateCall->getResult(0);

b.create<scf::YieldOp>(ValueRange{upstreamState, newAccumulator});
});

Value updatedState = whileOp->getResult(0);
Value accumulator = whileOp->getResult(1);

b.create<scf::YieldOp>(ValueRange{updatedState, accumulator});
});

// Compute hasNext: we have an element iff we have not returned before, i.e.,
// iff "not hasReturend". We simulate "not" with "xor true".
Value constTrue = b.create<arith::ConstantIntOp>(/*value=*/1, /*width=*/1);
Value hasNext = b.create<arith::XOrIOp>(constTrue, hasReturned);

// Update state.
Value finalUpstreamState = ifOp->getResult(0);
Value finalState = b.create<iterators::InsertValueOp>(
initialState, b.getIndexAttr(0), finalUpstreamState); // upstreamState
finalState = b.create<iterators::InsertValueOp>(finalState, b.getIndexAttr(2),
constTrue); // hasReturned
Value nextElement = ifOp->getResult(1);

return {finalState, hasNext, nextElement};
}

/// Builds IR that closes the nested upstream iterator. Possible output:
///
/// %0 = iterators.extractvalue %arg0[0] :
/// !iterators.state<!upstream_state, i1> -> !upstream_state
/// %1 = call @iterators.upstream.close.0(%0) :
/// (!upstream_state) -> !upstream_state
/// %2 = iterators.insertvalue %arg0[0] (%1 : !upstream_state) :
/// !iterators.state<!upstream_state, i1>
static Value buildCloseBody(AccumulateOp op, OpBuilder &builder,
Value initialState,
ArrayRef<IteratorInfo> upstreamInfos) {
Location loc = op.getLoc();
ImplicitLocOpBuilder b(loc, builder);

Type upstreamStateType = upstreamInfos[0].stateType;

// Extract upstream state.
Value initialUpstreamState = b.create<iterators::ExtractValueOp>(
upstreamStateType, initialState, b.getIndexAttr(0));

// Call Close on upstream.
SymbolRefAttr closeFunc = upstreamInfos[0].closeFunc;
auto closeCallOp = b.create<func::CallOp>(closeFunc, upstreamStateType,
initialUpstreamState);

// Update upstream state.
Value updatedUpstreamState = closeCallOp->getResult(0);
return b
.create<iterators::InsertValueOp>(initialState, b.getIndexAttr(0),
updatedUpstreamState)
.getResult();
}

/// Builds IR that initializes the iterator state with the state of the upstream
/// iterator. Possible output:
///
/// %0 = ...
/// %1 = arith.constant false
/// %2 = iterators.createstate(%0, %1) : !iterators.state<!upstream_state, i1>
static Value buildStateCreation(AccumulateOp op, AccumulateOp::Adaptor adaptor,
OpBuilder &builder, StateType stateType) {
Location loc = op.getLoc();
ImplicitLocOpBuilder b(loc, builder);
Value upstreamState = adaptor.getInput();
Value initVal = adaptor.getInitVal();
Value constFalse = b.create<arith::ConstantIntOp>(/*value=*/0, /*width=*/1);
return b.create<iterators::CreateStateOp>(
stateType, ValueRange{upstreamState, initVal, constFalse});
}

//===----------------------------------------------------------------------===//
// ConstantStreamOp.
//===----------------------------------------------------------------------===//
@@ -1543,6 +1781,7 @@ static Value buildOpenBody(Operation *op, OpBuilder &builder,
return llvm::TypeSwitch<Operation *, Value>(op)
.Case<
// clang-format off
AccumulateOp,
ConstantStreamOp,
FilterOp,
MapOp,
@@ -1563,6 +1802,7 @@ buildNextBody(Operation *op, OpBuilder &builder, Value initialState,
return llvm::TypeSwitch<Operation *, llvm::SmallVector<Value, 4>>(op)
.Case<
// clang-format off
AccumulateOp,
ConstantStreamOp,
FilterOp,
MapOp,
@@ -1584,6 +1824,7 @@ static Value buildCloseBody(Operation *op, OpBuilder &builder,
return llvm::TypeSwitch<Operation *, Value>(op)
.Case<
// clang-format off
AccumulateOp,
ConstantStreamOp,
FilterOp,
MapOp,
@@ -1603,6 +1844,7 @@ static Value buildStateCreation(IteratorOpInterface op, OpBuilder &builder,
return llvm::TypeSwitch<Operation *, Value>(op)
.Case<
// clang-format off
AccumulateOp,
ConstantStreamOp,
FilterOp,
MapOp,
71 changes: 71 additions & 0 deletions test/Conversion/IteratorsToLLVM/accumulate.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// RUN: iterators-opt %s -convert-iterators-to-llvm \
// RUN: | FileCheck --enable-var-scope %s

func.func private @sum_tuple(
%acc : tuple<i32>, %val : tuple<i32>) -> tuple<i32> {
%acci = tuple.to_elements %acc : tuple<i32>
%vali = tuple.to_elements %val : tuple<i32>
%i = arith.addi %acci, %vali : i32
%result = tuple.from_elements %i : tuple<i32>
return %result : tuple<i32>
}

// CHECK-LABEL: func.func private @iterators.accumulate.close.{{[0-9]+}}(
// CHECK-SAME: %[[ARG0:.*]]: !iterators.state<[[upstreamStateType:!iterators.state<[^>]*>]], tuple<i32>, i1>) ->
// CHECK-SAME: !iterators.state<[[upstreamStateType]], tuple<i32>, i1> {
// CHECK-NEXT: %[[V0:.*]] = iterators.extractvalue %[[ARG0]][0] : !iterators.state<[[upstreamStateType]], tuple<i32>, i1>
// CHECK-NEXT: %[[V1:.*]] = call @iterators.{{.*}}.close.{{.*}}(%[[V0]]) : ([[upstreamStateType]]) -> [[upstreamStateType]]
// CHECK-NEXT: %[[V2:.*]] = iterators.insertvalue %[[V1]] into %[[ARG0]][0] : !iterators.state<[[upstreamStateType]], tuple<i32>, i1>
// CHECK-NEXT: return %[[V2]] : !iterators.state<[[upstreamStateType]], tuple<i32>, i1>

// CHECK-LABEL: func.func private @iterators.accumulate.next.{{[0-9]+}}(
// CHECK-SAME: %[[ARG0:.*]]: !iterators.state<[[upstreamStateType:!iterators.state<[^>]*>]], tuple<i32>, i1>) ->
// CHECK-SAME: (!iterators.state<[[upstreamStateType]], tuple<i32>, i1>, i1, tuple<i32>) {
// CHECK-NEXT: %[[V1:.*]] = iterators.extractvalue %[[ARG0]][0] : !iterators.state<[[upstreamStateType]], tuple<i32>, i1>
// CHECK-NEXT: %[[V2:.*]] = iterators.extractvalue %[[ARG0]][1] : !iterators.state<[[upstreamStateType]], tuple<i32>, i1>
// CHECK-NEXT: %[[V3:.*]] = iterators.extractvalue %[[ARG0]][2] : !iterators.state<[[upstreamStateType]], tuple<i32>, i1>
// CHECK-NEXT: %[[V4:.*]]:2 = scf.if %[[V3]] -> ([[upstreamStateType]], tuple<i32>) {
// CHECK-NEXT: scf.yield %[[V1]], %[[V2]] : [[upstreamStateType]], tuple<i32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[V5:.*]]:3 = scf.while (%[[arg1:.*]] = %[[V1]], %[[arg2:.*]] = %[[V2]]) : ([[upstreamStateType]], tuple<i32>) -> ([[upstreamStateType]], tuple<i32>, tuple<i32>) {
// CHECK-NEXT: %[[V6:.*]]:3 = func.call @iterators.{{.*}}.next.{{.*}}(%[[arg1]]) : ([[upstreamStateType]]) -> ([[upstreamStateType]], i1, tuple<i32>)
// CHECK-NEXT: scf.condition(%[[V6]]#1) %[[V6]]#0, %[[arg2]], %[[V6]]#2 : [[upstreamStateType]], tuple<i32>, tuple<i32>
// CHECK-NEXT: } do {
// CHECK-NEXT: ^bb0(%[[arg1:.*]]: [[upstreamStateType]], %[[arg2:.*]]: tuple<i32>, %[[arg3:.*]]: tuple<i32>):
// CHECK-NEXT: %[[V7:.*]] = func.call @sum_tuple(%[[arg2]], %[[arg3]]) : (tuple<i32>, tuple<i32>) -> tuple<i32>
// CHECK-NEXT: scf.yield %[[arg1]], %[[V7]] : [[upstreamStateType]], tuple<i32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[V5]]#0, %[[V5]]#1 : [[upstreamStateType]], tuple<i32>
// CHECK-NEXT: }
// CHECK-NEXT: %[[V8:.*]] = arith.constant true
// CHECK-NEXT: %[[V9:.*]] = arith.xori %[[V8]], %[[V3]] : i1
// CHECK-NEXT: %[[Va:.*]] = iterators.insertvalue %[[V4]]#0 into %[[ARG0]][0] : !iterators.state<[[upstreamStateType]], tuple<i32>, i1>
// CHECK-NEXT: %[[Vb:.*]] = iterators.insertvalue %[[V8]] into %[[Va]][2] : !iterators.state<[[upstreamStateType]], tuple<i32>, i1>
// CHECK-NEXT: return %[[Vb]], %[[V9]], %[[V4]]#1 : !iterators.state<[[upstreamStateType]], tuple<i32>, i1>, i1, tuple<i32>

// CHECK-LABEL: func.func private @iterators.accumulate.open.{{[0-9]+}}(
// CHECK-SAME: %[[ARG0:.*]]: !iterators.state<[[upstreamStateType:!iterators.state<[^>]*>]], tuple<i32>, i1>) ->
// CHECK-SAME: !iterators.state<[[upstreamStateType]], tuple<i32>, i1> {
// CHECK-NEXT: %[[V1:.*]] = iterators.extractvalue %[[ARG0]][0] : !iterators.state<[[upstreamStateType]], tuple<i32>, i1>
// CHECK-NEXT: %[[V2:.*]] = call @iterators.{{.*}}.open.{{.*}}(%[[V1]]) : ([[upstreamStateType]]) -> [[upstreamStateType]]
// CHECK-NEXT: %[[V3:.*]] = iterators.insertvalue %[[V2]] into %[[ARG0]][0] : !iterators.state<[[upstreamStateType]], tuple<i32>, i1>
// CHECK-NEXT: %[[V4:.*]] = arith.constant false
// CHECK-NEXT: %[[V5:.*]] = iterators.insertvalue %[[V4]] into %[[V3]][2] : !iterators.state<[[upstreamStateType]], tuple<i32>, i1>
// CHECK-NEXT: return %[[V5]] : !iterators.state<[[upstreamStateType]], tuple<i32>, i1>

// CHECK-LABEL: func.func @main()
func.func @main() {
// CHECK-DAG: %[[V0:.*]] = iterators.createstate{{.*}} : [[upstreamStateType:!iterators.state<[^>]*>]]
%input = "iterators.constantstream"() { value = [] } : () -> (!iterators.stream<tuple<i32>>)

// CHECK-DAG: %[[V1:.*]] = tuple.from_elements %{{.*}} : tuple<i32>
%hundred = arith.constant 0 : i32
%init_value = tuple.from_elements %hundred : tuple<i32>

// CHECK-DAG: %[[V2:.*]] = arith.constant false
// CHECK-NEXT: %[[V3:.*]] = iterators.createstate(%[[V0]], %[[V1]], %[[V2]]) : !iterators.state<[[upstreamStateType]], tuple<i32>, i1>
%accumulated = iterators.accumulate(%input, %init_value) with @sum_tuple
: (!iterators.stream<tuple<i32>>) -> !iterators.stream<tuple<i32>>
return
// CHECK-NEXT: return
}
79 changes: 79 additions & 0 deletions test/Integration/Dialect/Iterators/CPU/accumulate.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// RUN: iterators-opt %s \
// RUN: -convert-iterators-to-llvm \
// RUN: -decompose-iterator-states \
// RUN: -decompose-tuples \
// RUN: -convert-func-to-llvm \
// RUN: -convert-scf-to-cf -convert-cf-to-llvm \
// RUN: | mlir-cpu-runner -e main -entry-point-result=void \
// RUN: | FileCheck %s

func.func private @accumulate_sum_tuple(
%acc : tuple<i32>, %val : tuple<i32>) -> tuple<i32> {
%acci = tuple.to_elements %acc : tuple<i32>
%vali = tuple.to_elements %val : tuple<i32>
%i = arith.addi %acci, %vali : i32
%result = tuple.from_elements %i : tuple<i32>
return %result : tuple<i32>
}

// CHECK-LABEL: test_accumulate_sum_tuple
// CHECK-NEXT: (160)
// CHECK-NEXT: -
func.func @test_accumulate_sum_tuple() {
iterators.print("test_accumulate_sum_tuple")
%input = "iterators.constantstream"()
{ value = [[0 : i32], [10 : i32], [20 : i32], [30 : i32]] }
: () -> (!iterators.stream<tuple<i32>>)
%hundred = arith.constant 100 : i32
%init_value = tuple.from_elements %hundred : tuple<i32>
%accumulated = iterators.accumulate(%input, %init_value)
with @accumulate_sum_tuple
: (!iterators.stream<tuple<i32>>) -> !iterators.stream<tuple<i32>>
"iterators.sink"(%accumulated) : (!iterators.stream<tuple<i32>>) -> ()
return
}

func.func private @accumulate_avg_tuple(
%acc : tuple<i32, i32>, %val : tuple<i32>) -> tuple<i32, i32> {
%cnt, %sum = tuple.to_elements %acc : tuple<i32, i32>
%vali = tuple.to_elements %val : tuple<i32>
%one = arith.constant 1 : i32
%new_cnt = arith.addi %cnt, %one : i32
%new_sum = arith.addi %sum, %vali : i32
%result = tuple.from_elements %new_cnt, %new_sum : tuple<i32, i32>
return %result : tuple<i32, i32>
}

func.func private @avg(%input : tuple<i32, i32>) -> tuple<f32> {
%cnt, %sum = tuple.to_elements %input : tuple<i32, i32>
%cntf = arith.sitofp %cnt : i32 to f32
%sumf = arith.sitofp %sum : i32 to f32
%avg = arith.divf %sumf, %cntf : f32
%result = tuple.from_elements %avg : tuple<f32>
return %result : tuple<f32>
}

// CHECK-LABEL: test_accumulate_avg_tuple
// CHECK-NEXT: (15)
// CHECK-NEXT: -
func.func @test_accumulate_avg_tuple() {
iterators.print("test_accumulate_avg_tuple")
%input = "iterators.constantstream"()
{ value = [[0 : i32], [10 : i32], [20 : i32], [30 : i32]] }
: () -> (!iterators.stream<tuple<i32>>)
%zero = arith.constant 0 : i32
%init_value = tuple.from_elements %zero, %zero : tuple<i32, i32>
%accumulated = iterators.accumulate(%input, %init_value)
with @accumulate_avg_tuple
: (!iterators.stream<tuple<i32>>) -> !iterators.stream<tuple<i32, i32>>
%mapped = "iterators.map"(%accumulated) {mapFuncRef = @avg}
: (!iterators.stream<tuple<i32, i32>>) -> (!iterators.stream<tuple<f32>>)
"iterators.sink"(%mapped) : (!iterators.stream<tuple<f32>>) -> ()
return
}

func.func @main() {
call @test_accumulate_sum_tuple() : () -> ()
call @test_accumulate_avg_tuple() : () -> ()
return
}

0 comments on commit 8466965

Please sign in to comment.