Skip to content

Commit

Permalink
[compiler] Implement DMA code specialization (#55)
Browse files Browse the repository at this point in the history
Up until now we've been performing memory copies on the compute cores
using `memref.copy` which is very slow. This PR instead starts making
use of the DMA cores DMA hardware to perform any copying from L1 to L3.
As only the DMA core can execute these instructions and only the compute
core can execute kernels, this PR also implements a "dma code
specialization" pass which copies the original compute core containing
DMA instructions, removes all compute instructions from the copy, and
inserts required synchronization between the compute and DMA functions.
  • Loading branch information
zero9178 authored Jun 27, 2024
1 parent dad35a1 commit 61ad419
Show file tree
Hide file tree
Showing 8 changed files with 150 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ def QuidditchSnitch_Dialect : Dialect {

let discardableAttrs = (ins
"mlir::StringAttr":$riscv_assembly,
"mlir::UnitAttr":$xdsl_compilation_failed
"mlir::UnitAttr":$xdsl_compilation_failed,
"mlir::FlatSymbolRefAttr":$dma_specialization
);

let useDefaultAttributePrinterParser = 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ iree_cc_library(
SRCS
"PromoteToL1.cpp"
"LowerL1Allocations.cpp"
"SpecializeDMACode.cpp"
DEPS
::PassesIncGen
Quidditch::Dialect::Snitch::IR::QuidditchSnitchDialect
Expand Down
16 changes: 16 additions & 0 deletions codegen/compiler/src/Quidditch/Dialect/Snitch/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,20 @@ def LowerL1AllocationsPass : InterfacePass<"quidditch-lower-l1-allocations",
];
}

def SpecializeDMACodePass : Pass<"quidditch-specialize-dma-code",
"mlir::ModuleOp"> {

let description = [{
Pass performing code specialization for DMA and compute cores while
inserting required synchronization primitives.

Every function originally present in the IR will be cloned and turned into
a "dma" version.
DMA versions have all compute operations (i.e. `mmemref.microkernel`s)
removed while the original version has all DMA transfer operations removed.
Barriers are inserted where data dependencies require either transfers or
computations to have finished.
}];
}

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#include "Passes.h"

#include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.h"
#include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.h"
#include "mlir/Interfaces/FunctionInterfaces.h"

namespace quidditch::Snitch {
#define GEN_PASS_DEF_SPECIALIZEDMACODEPASS
#include "Quidditch/Dialect/Snitch/Transforms/Passes.h.inc"
} // namespace quidditch::Snitch

namespace {
class SpecializeDMACode
: public quidditch::Snitch::impl::SpecializeDMACodePassBase<
SpecializeDMACode> {
public:
using Base::Base;

protected:
void runOnOperation() override;

private:
};

} // namespace

using namespace mlir;
using namespace quidditch::Snitch;

static void removeComputeOps(FunctionOpInterface dmaCode) {
dmaCode->walk([&](MemRefMicrokernelOp operation) {
// TODO: These can have results in theory which would make this crash!
operation->erase();
});
}

static void removeDmaCode(FunctionOpInterface computeCode) {
SmallVector<Operation *> toDelete;
computeCode->walk([&](Operation *operation) {
if (isa<WaitForDMATransfersOp, StartDMATransferOp>(operation))
toDelete.push_back(operation);
});
for (Operation *op : toDelete) {
op->dropAllUses();
op->erase();
}
}

static void insertBarriers(FunctionOpInterface function) {
function->walk([](Operation *operation) {
OpBuilder builder(operation->getContext());
if (isa<WaitForDMATransfersOp>(operation))
// Barrier needs to be after the wait to signal to compute ops the
// transfer is done.
builder.setInsertionPointAfter(operation);
else if (isa<StartDMATransferOp>(operation))
// Barrier needs to be before the transfer for compute ops to signal
// that a computation is done.
// TODO: This is overly conservative and could be optimized somewhere.
builder.setInsertionPoint(operation);
else
return;

builder.create<BarrierOp>(operation->getLoc());
});
}

void SpecializeDMACode::runOnOperation() {
auto *dialect = getContext().getLoadedDialect<QuidditchSnitchDialect>();
SymbolTable table(getOperation());
for (auto function : getOperation().getOps<FunctionOpInterface>()) {
if (function.isDeclaration())
continue;

insertBarriers(function);

FunctionOpInterface clone = function.clone();
clone.setName((clone.getName() + "$dma").str());
table.insert(clone, function->getIterator());
dialect->getDmaSpecializationAttrHelper().setAttr(
function, FlatSymbolRefAttr::get(clone));

removeComputeOps(clone);
removeDmaCode(function);
}
}
31 changes: 21 additions & 10 deletions codegen/compiler/src/Quidditch/Target/QuidditchTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

#include "Quidditch/Conversion/Passes.h"
#include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.h"
#include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.h"
#include "Quidditch/Dialect/Snitch/Transforms/Passes.h"

#include "compiler/plugins/target/LLVMCPU/LinkerTool.h"
Expand Down Expand Up @@ -183,12 +184,14 @@ class QuidditchTargetBackend final : public IREE::HAL::TargetBackend {
return builder.create<memref::AllocaOp>(
loc, memRefType, dynamicSizes, builder.getI64IntegerAttr(alignment));
};
BufferizationOptions::MemCpyFn memcpyFn =
[](OpBuilder &builder, Location loc, Value from, Value to) {
// TODO: DMA copy.
createLinalgCopyOp(builder, loc, from, to);
return success();
};
BufferizationOptions::MemCpyFn memcpyFn = [](OpBuilder &builder,
Location loc, Value from,
Value to) {
Value token =
builder.create<quidditch::Snitch::StartDMATransferOp>(loc, from, to);
builder.create<quidditch::Snitch::WaitForDMATransfersOp>(loc, token);
return success();
};

FunctionLikeNest(modulePassManager)
.addPass(createEliminateEmptyTensorsPass)
Expand All @@ -213,6 +216,10 @@ class QuidditchTargetBackend final : public IREE::HAL::TargetBackend {
.addPass(createCanonicalizerPass)
.addPass(createLinalgGeneralizeNamedOpsPass);

modulePassManager.addPass(quidditch::Snitch::createSpecializeDMACodePass());
FunctionLikeNest(modulePassManager)
.addPass(createCanonicalizerPass)
.addPass(createCSEPass);
modulePassManager.addPass(quidditch::createConvertToRISCVPass(
{targetOptions.xDSLOptPath, targetOptions.assertCompiled}));

Expand Down Expand Up @@ -344,10 +351,14 @@ class QuidditchTargetBackend final : public IREE::HAL::TargetBackend {
} else {
// xDSL kernel.

// TODO: Replace with real DMA pointer. For now just a place holder that
// is not used in the runtime but makes the dispatch recognizeable as
// an xDSL dispatch.
dmaPointer = llvmFunc;
// TODO: This should use the attribute attached to the LLVM::LLVMFuncOp.
dmaPointer =
llvmModule->getFunction((llvmFunc->getName() + "$dma").str());
if (!dmaPointer) {
module.emitError()
<< "failed to find DMA code for " << exportOp.getName();
return nullptr;
}
}
if (!llvmFunc)
continue;
Expand Down
14 changes: 10 additions & 4 deletions runtime/runtime/src/Quidditch/dispatch/dispatch.c
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,18 @@ void quidditch_dispatch_queue_workgroup(
}

void quidditch_dispatch_execute_workgroups() {
// Avoid needlessly waking any cores if no workgroups are queued.
if (nextCoreToUse == 0) return;

// First wake workers.
snrt_cluster_hw_barrier();
quidditch_dispatch_start_executing_workgroup();

// Then wait for workers to be done.
quidditch_dispatch_wait_for_workgroup();
}

void quidditch_dispatch_start_executing_workgroup() {
snrt_cluster_hw_barrier();
}

void quidditch_dispatch_wait_for_workgroup() {
snrt_cluster_hw_barrier();
reset_workgroup_state();
}
5 changes: 5 additions & 0 deletions runtime/runtime/src/Quidditch/dispatch/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,8 @@ void quidditch_dispatch_queue_workgroup(

/// Executes all queued workgroups and waits for them to finish.
void quidditch_dispatch_execute_workgroups();

/// Executes all queued workgroups and waits for them to finish.
void quidditch_dispatch_start_executing_workgroup();

void quidditch_dispatch_wait_for_workgroup();
10 changes: 9 additions & 1 deletion runtime/runtime/src/Quidditch/executable/executable.c
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,9 @@ iree_status_t quidditch_executable_issue_dispatch_inline(
// Snitch distributes workgroups to clusters.
// I.e., one workgroup runs on one cluster.
// TODO: Subgroup distribution.
iree_hal_executable_dispatch_v0_t const dmaCoreFunction =
((quidditch_executable_export_table_v0_t*)exports)
->dma_core_ptrs[ordinal];
for (uint32_t z = 0; z < workgroup_count_z; ++z) {
workgroup_state.workgroup_id_z = z;
for (uint32_t y = 0; y < workgroup_count_y; ++y) {
Expand All @@ -238,7 +241,12 @@ iree_status_t quidditch_executable_issue_dispatch_inline(
workgroup_state.workgroup_id_x = x;

quidditch_dispatch_queue_workgroup(&workgroup_state);
quidditch_dispatch_execute_workgroups();
quidditch_dispatch_start_executing_workgroup();

dmaCoreFunction(&executable->environment, dispatch_state,
&workgroup_state);

quidditch_dispatch_wait_for_workgroup();
}
}
}
Expand Down

0 comments on commit 61ad419

Please sign in to comment.