[quidditch_snitch] Reintroduce tensor.microkernel op (#106)
1b20d75 previously removed the `tensor.microkernel` operation as it did
not seem worth the extra code at the time.

Since then, we have noted that microkernels execute in a more
asynchronous manner, as Snitch's asynchronous FPU requires the use of an
explicit `microkernel_fence` operation. Optimizing the placement of
these operations is easier in tensor land, making the operation worth
reintroducing. Additionally, more experience with bufferization led to a
simpler implementation that restricts `microkernel_yield` to tensor
operands only.

A tensor counterpart of `microkernel_fence` called `sync_tensor` has
also been added, which makes a result tensor of a `tensor.microkernel`
operation available. It bufferizes to `microkernel_fence`, and its
placement could be further optimized in the future. The conservative
placement of `microkernel_fence` operations was also removed from
`specialize-dma-code`, leading to fewer barriers and fewer
`microkernel_fence` operations.
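
In IR, the reintroduced ops compose as in the following sketch (illustrative only: the `quidditch_snitch` dialect prefix, value names, and tensor types are assumed; the op syntax follows the assembly formats defined in this diff):

%result = quidditch_snitch.tensor.microkernel -> tensor<32x32xf64> {
  // Ops in the body may execute asynchronously on Snitch's FPU.
  %0 = linalg.matmul ins(%lhs, %rhs : tensor<32x64xf64>, tensor<64x32xf64>)
                     outs(%init : tensor<32x32xf64>) -> tensor<32x32xf64>
  quidditch_snitch.microkernel_yield %0 : tensor<32x32xf64>
}
// Makes the asynchronously computed result available; bufferizes to
// `microkernel_fence`.
%synced = quidditch_snitch.sync_tensor %result : tensor<32x32xf64>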
zero9178 authored Aug 15, 2024
1 parent 03aaad4 commit 5803d44
Showing 9 changed files with 345 additions and 82 deletions.
@@ -50,6 +50,199 @@ static void printRISCVAssembly(OpAsmPrinter &opAsmPrinter, Operation *,
}
}

//===----------------------------------------------------------------------===//
// TensorMicrokernelOp::RegionBranchOpInterface
//===----------------------------------------------------------------------===//

void TensorMicrokernelOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
if (point.isParent()) {
regions.emplace_back(&getBody());
return;
}
regions.emplace_back(getResults());
}

void TensorMicrokernelOp::getRegionInvocationBounds(
ArrayRef<Attribute>, SmallVectorImpl<InvocationBounds> &invocationBounds) {
invocationBounds.push_back({1, 1});
}

//===----------------------------------------------------------------------===//
// TensorMicrokernelOp::BufferizableOpInterface
//===----------------------------------------------------------------------===//

AliasingOpOperandList
TensorMicrokernelOp::getAliasingOpOperands(Value value,
const AnalysisState &state) {
return {{
&getYieldOp()
.getResultsMutable()[cast<OpResult>(value).getResultNumber()],
BufferRelation::Equivalent,
/*isDefinite=*/true,
}};
}

FailureOr<BaseMemRefType>
TensorMicrokernelOp::getBufferType(Value value,
const BufferizationOptions &options,
SmallVector<Value> &invocationStack) {
Value corresponding =
getYieldOp().getResults()[cast<OpResult>(value).getResultNumber()];
if (auto memRefType = dyn_cast<BaseMemRefType>(corresponding.getType()))
return memRefType;

return bufferization::getBufferType(corresponding, options, invocationStack);
}

LogicalResult
TensorMicrokernelOp::bufferize(RewriterBase &rewriter,
const BufferizationOptions &options) {
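  // First compute the buffer equivalents of all yielded values; non-tensor
  // values are passed through unchanged.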
SmallVector<Value> newYields;
for (Value result : getYieldOp().getResults()) {
if (!isa<TensorType>(result.getType())) {
newYields.push_back(result);
continue;
}
    auto buffer = bufferization::getBuffer(rewriter, result, options);
    if (failed(buffer))
      return failure();
    newYields.push_back(*buffer);
}

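  // Collect all values used within the body but defined outside of it. These
  // become the explicit inputs of the isolated `memref.microkernel` op, with
  // tensor values replaced by their buffers.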
SetVector<Value> inputs;
WalkResult walkResult = walk([&](Operation *operation) {
for (Value value : operation->getOperands()) {
if (isa<TensorType>(value.getType())) {
FailureOr<Value> newInput = getBuffer(rewriter, value, options);
if (failed(newInput))
return WalkResult::interrupt();
value = *newInput;
}

if (getBody().isAncestor(value.getParentRegion()))
continue;
inputs.insert(value);
}
return WalkResult::advance();
});
if (walkResult.wasInterrupted())
return failure();

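  // Create the `memref.microkernel` op, move the body into it, and remap all
  // captured values to the new block arguments.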
auto replacement =
rewriter.create<MemRefMicrokernelOp>(getLoc(), inputs.getArrayRef());
Block *newBlock = replacement.createEntryBlock();
{
OpBuilder::InsertionGuard guard{rewriter};
rewriter.setInsertionPointToStart(newBlock);

rewriter.mergeBlocks(&getBody().front(), newBlock);
rewriter.eraseOp(newBlock->getTerminator());

SmallVector<Value> vector = inputs.takeVector();
rewriter.setInsertionPointToStart(newBlock);
for (auto [oldV, newV] : llvm::zip(vector, newBlock->getArguments()))
rewriter.replaceUsesWithIf(oldV, newV, [&](OpOperand &operand) {
return replacement.getBody().isAncestor(
operand.getOwner()->getParentRegion());
});
}

replaceOpWithBufferizedValues(rewriter, *this, newYields);
return success();
}

//===----------------------------------------------------------------------===//
// MicrokernelYieldOp::BufferizableOpInterface
//===----------------------------------------------------------------------===//

bool MicrokernelYieldOp::bufferizesToMemoryRead(
OpOperand &, const bufferization::AnalysisState &) {
return false;
}

bool MicrokernelYieldOp::bufferizesToMemoryWrite(OpOperand &,
const AnalysisState &) {
return false;
}

AliasingValueList MicrokernelYieldOp::getAliasingValues(OpOperand &opOperand,
const AnalysisState &) {
return {{getParentOp()->getResult(opOperand.getOperandNumber()),
BufferRelation::Equivalent, /*isDefinite=*/true}};
}

bool MicrokernelYieldOp::mustBufferizeInPlace(OpOperand &,
const AnalysisState &) {
// Yield operands always bufferize inplace. Otherwise, an alloc + copy
// may be generated inside the block. We should not return/yield allocations
// when possible.
return true;
}

LogicalResult
MicrokernelYieldOp::bufferize(RewriterBase &rewriter,
const BufferizationOptions &options) {
SmallVector<Value> newResults;
for (auto &&[index, value] : llvm::enumerate(getResults())) {
if (!isa<TensorType>(value.getType())) {
newResults.push_back(value);
continue;
}

FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
if (failed(maybeBuffer))
return failure();

newResults.push_back(*maybeBuffer);
}
replaceOpWithNewBufferizedOp<MicrokernelYieldOp>(rewriter, *this, newResults);
return success();
}

//===----------------------------------------------------------------------===//
// SyncTensorOp::BufferizableOpInterface
//===----------------------------------------------------------------------===//

bool SyncTensorOp::bufferizesToMemoryRead(
OpOperand &, const bufferization::AnalysisState &) {
return false;
}

bool SyncTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
const AnalysisState &) {
assert(opOperand == getInputMutable());
// The op making the asynchronous result of the microkernel available is
// effectively a write operation to the MemRef.
return true;
}

AliasingValueList SyncTensorOp::getAliasingValues(OpOperand &opOperand,
const AnalysisState &) {
assert(opOperand == getInputMutable());
return {{getResult(), BufferRelation::Equivalent, /*isDefinite=*/true}};
}

bool SyncTensorOp::mustBufferizeInPlace(OpOperand &opOperand,
const AnalysisState &) {
assert(opOperand == getInputMutable());
  // The operation must bufferize in place, as a copy inserted by the
  // bufferization framework would be placed prior to the
  // `microkernel_fence` operation and would not be semantically equivalent.
return true;
}

LogicalResult SyncTensorOp::bufferize(RewriterBase &rewriter,
const BufferizationOptions &options) {
FailureOr<Value> inputTensorBuffer = getBuffer(rewriter, getInput(), options);
if (failed(inputTensorBuffer))
return failure();

rewriter.create<MicrokernelFenceOp>(getLoc());
replaceOpWithBufferizedValues(rewriter, *this, *inputTensorBuffer);
return success();
}

//===----------------------------------------------------------------------===//
// MemRefMicrokernelOp
//===----------------------------------------------------------------------===//
Expand Down
@@ -14,6 +14,81 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
class QuidditchSnitch_Op<string mnemonic, list<Trait> traits = []> :
Op<QuidditchSnitch_Dialect, mnemonic, traits>;

def QuidditchSnitch_TensorMicrokernelOp : QuidditchSnitch_Op<"tensor.microkernel",
[SingleBlock, NoRegionArguments, RecursivelySpeculatable, RecursiveMemoryEffects,
DeclareOpInterfaceMethods<RegionBranchOpInterface, [
"getRegionInvocationBounds"]>,
DeclareOpInterfaceMethods<BufferizableOpInterface, ["bufferize",
"getAliasingOpOperands", "getBufferType"]>]> {

let description = [{
Pre-bufferization version of `memref.microkernel`.
    Unlike `memref.microkernel`, it is not isolated from above and may also
    return tensor values as results via `microkernel_yield`.

    Like `memref.microkernel`, operations within the kernel may execute
    asynchronously, and their results cannot be used directly.
A `sync_tensor` operation must be used to make any result tensor of this
operation available.
Failing to do so results in unspecified values within the tensor.
}];

let results = (outs Variadic<AnyRankedTensor>:$results);

let regions = (region SizedRegion<1>:$body);

let assemblyFormat = [{
(`->` type($results)^ )? $body attr-dict
}];

let extraClassDeclaration = [{
MicrokernelYieldOp getYieldOp() {
return llvm::cast<MicrokernelYieldOp>(getBody().back().getTerminator());
}

mlir::Block* createEntryBlock();
}];
}

def QuidditchSnitch_MicrokernelYieldOp
: QuidditchSnitch_Op<"microkernel_yield", [Pure, Terminator,
HasParent<"TensorMicrokernelOp">, ReturnLike,
DeclareOpInterfaceMethods<BufferizableOpInterface,
["bufferize", "bufferizesToMemoryRead", "bufferizesToMemoryWrite",
"getAliasingValues", "mustBufferizeInPlace"]>]> {
let arguments = (ins Variadic<AnyRankedTensor>:$results);

let assemblyFormat = [{
$results (`:` type($results)^)? attr-dict
}];
}

def QuidditchSnitch_SyncTensorOp : QuidditchSnitch_Op<"sync_tensor",
[AllTypesMatch<["result", "input"]>, Pure,
DeclareOpInterfaceMethods<BufferizableOpInterface,
["bufferizesToMemoryRead", "bufferizesToMemoryWrite", "getAliasingValues",
"bufferize", "mustBufferizeInPlace"]>]> {

let description = [{
    Synchronizes a tensor returned by a `tensor.microkernel` operation.
    The resulting tensor is guaranteed to contain the results of all
    operations performed within the `tensor.microkernel` operation.
}];

let arguments = (ins
AnyRankedTensor:$input
);

let results = (outs
AnyRankedTensor:$result
);

let assemblyFormat = [{
$input `:` type($result) attr-dict
}];
}
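
As a sketch of the bufferization contract (assuming, as the `rewriter.create<MicrokernelFenceOp>(getLoc())` call in this diff suggests, that `microkernel_fence` takes no operands; names and types are illustrative):

%synced = quidditch_snitch.sync_tensor %result : tensor<32x32xf64>
// bufferizes in place to a fence on the underlying buffer:
quidditch_snitch.microkernel_fence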

def QuidditchSnitch_MemRefMicrokernelOp
: QuidditchSnitch_Op<"memref.microkernel", [IsolatedFromAbove, SingleBlock,
NoTerminator]> {
Expand Down
@@ -27,60 +27,28 @@ class FormMicrokernels
};
} // namespace

-static void outlineOpsToFunction(MutableArrayRef<linalg::LinalgOp> ops) {
-  if (ops.empty())
-    return;
-
-  auto builder = OpBuilder(ops.front());
-
-  SetVector<Value> inputs;
-  for (linalg::LinalgOp computeOp : ops) {
-    inputs.insert(computeOp->getOperands().begin(),
-                  computeOp->getOperands().end());
-
-    computeOp.walk([&](Operation *operation) {
-      for (Value value : operation->getOperands()) {
-        if (computeOp->getParentRegion()->isProperAncestor(
-                value.getParentRegion()))
-          continue;
-
-        inputs.insert(value);
-      }
-    });
-  }
-
-  auto kernelOp = builder.create<MemRefMicrokernelOp>(ops.front()->getLoc(),
-                                                      inputs.getArrayRef());
-
-  Block *block = kernelOp.createEntryBlock();
-  builder.setInsertionPointToStart(block);
-
-  for (Operation *op : ops) {
-    op->remove();
-    builder.insert(op);
-  }
-
-  SmallVector<Value> vector = inputs.takeVector();
-  for (auto [oldV, newV] : llvm::zip(vector, block->getArguments()))
-    oldV.replaceUsesWithIf(newV, [&](OpOperand &operand) {
-      return kernelOp.getBody().isAncestor(
-          operand.getOwner()->getParentRegion());
-    });
-}
-
void FormMicrokernels::runOnOperation() {
FunctionOpInterface func = getOperation();

-  SmallVector<linalg::LinalgOp> outlinedOps;
-  func.walk([&](Block *block) {
-    for (Operation &op : *block) {
-      auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
-      if (!linalgOp || !linalgOp.hasPureBufferSemantics()) {
-        outlineOpsToFunction(outlinedOps);
-        outlinedOps.clear();
-        continue;
-      }
-      outlinedOps.push_back(linalgOp);
+  func.walk([](linalg::LinalgOp linalgOp) {
+    if (!linalgOp.hasPureTensorSemantics())
+      return;
+
+    auto builder = OpBuilder(linalgOp);
+    auto kernelOp = builder.create<TensorMicrokernelOp>(
+        linalgOp.getLoc(), linalgOp->getResultTypes());
+    for (auto [oldResult, newResult] :
+         llvm::zip_equal(linalgOp->getResults(), kernelOp.getResults())) {
+      oldResult.replaceAllUsesWith(
+          builder.create<SyncTensorOp>(linalgOp.getLoc(), newResult));
+    }
+
+    Block *block = &kernelOp.getBody().emplaceBlock();
+    builder.setInsertionPointToStart(block);
+
+    linalgOp->remove();
+    builder.insert(linalgOp);
+    builder.create<MicrokernelYieldOp>(linalgOp->getLoc(),
+                                       linalgOp->getResults());
});
}
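
The effect of the rewritten pass on a single linalg op with tensor semantics, as a sketch (value names, types, and the `quidditch_snitch` dialect prefix are illustrative):

// Before:
%0 = linalg.matmul ins(%a, %b : tensor<8x8xf64>, tensor<8x8xf64>)
                   outs(%c : tensor<8x8xf64>) -> tensor<8x8xf64>

// After: the op is wrapped in a tensor.microkernel, and all former uses of
// %0 go through a sync_tensor.
%1 = quidditch_snitch.tensor.microkernel -> tensor<8x8xf64> {
  %2 = linalg.matmul ins(%a, %b : tensor<8x8xf64>, tensor<8x8xf64>)
                     outs(%c : tensor<8x8xf64>) -> tensor<8x8xf64>
  quidditch_snitch.microkernel_yield %2 : tensor<8x8xf64>
}
%3 = quidditch_snitch.sync_tensor %1 : tensor<8x8xf64>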
@@ -64,16 +64,10 @@ static void removeDmaCode(FunctionOpInterface computeCode) {
static void insertBarriers(FunctionOpInterface function) {
function->walk([](Operation *operation) {
OpBuilder builder(operation->getContext());
-    if (isa<WaitForDMATransfersOp>(operation)) {
+    if (isa<WaitForDMATransfersOp, MicrokernelFenceOp>(operation)) {
      // Barrier needs to be after the wait to signal to compute ops the
      // transfer is done.
      builder.setInsertionPointAfter(operation);
-    } else if (isa<StartDMATransferOp>(operation)) {
-      // Barrier needs to be before the transfer for compute ops to signal
-      // that a computation is done.
-      // TODO: This is overly conservative and could be optimized somewhere.
-      builder.setInsertionPoint(operation);
-      builder.create<MicrokernelFenceOp>(operation->getLoc());
} else
return;

Expand Down
codegen/compiler/src/Quidditch/Target/QuidditchTarget.cpp (6 changes: 3 additions & 3 deletions)
@@ -199,7 +199,8 @@ class QuidditchTargetBackend final : public IREE::HAL::TargetBackend {
})
.addPass(createCanonicalizerPass)
.addPass(createCSEPass)
-        .addPass(createLoopInvariantCodeMotionPass);
+        .addPass(createLoopInvariantCodeMotionPass)
+        .addPass(quidditch::Snitch::createFormMicrokernelsPass);

BufferizationOptions::AllocationFn allocationFn =
[](OpBuilder &builder, Location loc, MemRefType memRefType,
@@ -246,8 +247,7 @@ class QuidditchTargetBackend final : public IREE::HAL::TargetBackend {
.addPass(createCSEPass)
.addPass(createLoopInvariantCodeMotionPass)
.addPass(createLinalgGeneralizeNamedOpsPass)
-        .addPass(quidditch::createRemoveTrivialLoopsPass)
-        .addPass(quidditch::Snitch::createFormMicrokernelsPass);
+        .addPass(quidditch::createRemoveTrivialLoopsPass);

modulePassManager.addPass(quidditch::Snitch::createSpecializeDMACodePass());
FunctionLikeNest(modulePassManager)