Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[quidditch_snitch] Basic implementations of DMA operations
Browse files Browse the repository at this point in the history
This PR implements DMA copy and wait operations. Snitch's DMA is asnychronous and has special hardware support for faster transfers between L1 and L3 than doing it manually. The currentl lowering to LLVM simple performs calls into the snitch runtime for convenience with the idea of these being later inlined via LTO once implemented.

Noteworthy is that the support for various memref shapes, strides and dimension is currently very limited but enough for what IREE generates.
zero9178 committed Jun 27, 2024
1 parent e7f66ef commit 4900a27
Showing 11 changed files with 274 additions and 1 deletion.
146 changes: 146 additions & 0 deletions codegen/compiler/src/Quidditch/Conversion/ConvertSnitchToLLVM.cpp
Original file line number Diff line number Diff line change
@@ -62,6 +62,125 @@ struct L1MemoryViewOpLowering : ConvertOpToLLVMPattern<L1MemoryViewOp> {
}
};

struct StartDMATransferOp1DLowering
: ConvertOpToLLVMPattern<StartDMATransferOp> {

LLVM::LLVMFuncOp dmaStart1DFunc;

StartDMATransferOp1DLowering(LLVM::LLVMFuncOp dmaStart1DFunc,
const LLVMTypeConverter &converter)
: ConvertOpToLLVMPattern(converter, /*benefit=*/2),
dmaStart1DFunc(dmaStart1DFunc) {
setHasBoundedRewriteRecursion();
}

LogicalResult match(StartDMATransferOp op) const override {
MemRefLayoutAttrInterface sourceLayout =
op.getSource().getType().getLayout();
MemRefLayoutAttrInterface destLayout = op.getDest().getType().getLayout();
if (sourceLayout && !sourceLayout.isIdentity())
return failure();

if (destLayout && !destLayout.isIdentity())
return failure();

return success();
}

void rewrite(StartDMATransferOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MemRefDescriptor sourceDescriptor(adaptor.getSource());
MemRefDescriptor destDescriptor(adaptor.getDest());

Value source = sourceDescriptor.bufferPtr(
rewriter, op->getLoc(), *getTypeConverter(), op.getSource().getType());
Value dest = destDescriptor.bufferPtr(
rewriter, op->getLoc(), *getTypeConverter(), op.getSource().getType());

MemRefType sourceMemRef = op.getSource().getType();
SmallVector<Value> dynamicSizes;
for (std::int64_t dim : sourceMemRef.getShape())
if (ShapedType::isDynamic(dim))
dynamicSizes.push_back(
sourceDescriptor.size(rewriter, op->getLoc(), dim));

SmallVector<Value> sizes;
SmallVector<Value> strides;
Value totalSize;
getMemRefDescriptorSizes(op->getLoc(), op.getSource().getType(),
dynamicSizes, rewriter, sizes, strides, totalSize);

rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, dmaStart1DFunc,
ValueRange{
dest,
source,
totalSize,
});
}
};

struct StartDMATransferOp2DLowering
: ConvertOpToLLVMPattern<StartDMATransferOp> {

LLVM::LLVMFuncOp dmaStart2DFunc;

StartDMATransferOp2DLowering(LLVM::LLVMFuncOp dmaStart2DFunc,
const LLVMTypeConverter &converter)
: ConvertOpToLLVMPattern(converter), dmaStart2DFunc(dmaStart2DFunc) {}

LogicalResult
matchAndRewrite(StartDMATransferOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (op.getSource().getType().getRank() != 2)
return failure();

MemRefDescriptor sourceDescriptor(adaptor.getSource());
MemRefDescriptor destDescriptor(adaptor.getDest());

Value source = sourceDescriptor.bufferPtr(
rewriter, op->getLoc(), *getTypeConverter(), op.getSource().getType());
Value dest = destDescriptor.bufferPtr(
rewriter, op->getLoc(), *getTypeConverter(), op.getSource().getType());

Value size = rewriter.create<LLVM::ConstantOp>(
op->getLoc(),
rewriter.getI32IntegerAttr(llvm::divideCeil(
op.getSource().getType().getElementTypeBitWidth(), 8)));
size = rewriter.create<LLVM::MulOp>(
op->getLoc(), size, sourceDescriptor.size(rewriter, op->getLoc(), 0));

Value sourceStride = sourceDescriptor.stride(rewriter, op->getLoc(), 1);
Value destStride = destDescriptor.stride(rewriter, op->getLoc(), 1);

rewriter.replaceOpWithNewOp<LLVM::CallOp>(
op, dmaStart2DFunc,
ValueRange{dest, source, size, destStride, sourceStride,
sourceDescriptor.size(rewriter, op->getLoc(), 1)});
return success();
}
};

struct WaitForDMATransfersOpLowering
: ConvertOpToLLVMPattern<WaitForDMATransfersOp> {

LLVM::LLVMFuncOp waitFunc;

WaitForDMATransfersOpLowering(LLVM::LLVMFuncOp waitFunc,
const LLVMTypeConverter &converter)
: ConvertOpToLLVMPattern(converter), waitFunc(waitFunc) {}

LogicalResult
matchAndRewrite(WaitForDMATransfersOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// TODO: This should wait for only a specific transfer not all.
// for (Value token : adaptor.getTokens())
// rewriter.create<LLVM::CallOp>(op->getLoc(), waitFunc, token);
rewriter.create<LLVM::CallOp>(op->getLoc(), waitFunc, ValueRange());
rewriter.eraseOp(op);
return success();
}
};

struct BarrierOpLowering : ConvertOpToLLVMPattern<BarrierOp> {

LLVM::LLVMFuncOp barrierFunc;
@@ -87,11 +206,35 @@ void ConvertSnitchToLLVM::runOnOperation() {
// TODO: This is horribly hardcoded when it shouldn't be.
options.overrideIndexBitwidth(32);
LLVMTypeConverter typeConverter(&getContext(), options, &dataLayoutAnalysis);
typeConverter.addConversion([](DMATokenType token) {
return IntegerType::get(token.getContext(), 32);
});

auto builder = OpBuilder::atBlockEnd(getOperation().getBody());
auto ptrType = builder.getType<LLVM::LLVMPointerType>();
IntegerType i32 = builder.getI32Type();
IntegerType sizeT = i32;
auto dmaStart1D = builder.create<LLVM::LLVMFuncOp>(
builder.getUnknownLoc(), "snrt_dma_start_1d",
LLVM::LLVMFunctionType::get(i32,
ArrayRef<Type>{ptrType, ptrType, sizeT}));
dmaStart1D->setAttr("hal.import.bitcode", builder.getUnitAttr());

auto dmaStart2D = builder.create<LLVM::LLVMFuncOp>(
builder.getUnknownLoc(), "snrt_dma_start_2d",
LLVM::LLVMFunctionType::get(
i32, ArrayRef<Type>{ptrType, ptrType, sizeT, sizeT, sizeT, sizeT}));
dmaStart2D->setAttr("hal.import.bitcode", builder.getUnitAttr());

// TODO: This should wait for only a specific transfer not all.
// This is currently bugged in the snitch_cluster repo and potentially
// the hardware.
auto dmaWait = builder.create<LLVM::LLVMFuncOp>(
builder.getUnknownLoc(), "snrt_dma_wait_all",
LLVM::LLVMFunctionType::get(builder.getType<LLVM::LLVMVoidType>(),
ArrayRef<Type>{}));
dmaWait->setAttr("hal.import.bitcode", builder.getUnitAttr());

auto barrier = builder.create<LLVM::LLVMFuncOp>(
builder.getUnknownLoc(), "snrt_cluster_hw_barrier",
LLVM::LLVMFunctionType::get(builder.getType<LLVM::LLVMVoidType>(),
@@ -100,6 +243,9 @@ void ConvertSnitchToLLVM::runOnOperation() {

RewritePatternSet patterns(&getContext());
patterns.insert<L1MemoryViewOpLowering>(typeConverter);
patterns.insert<StartDMATransferOp1DLowering>(dmaStart1D, typeConverter);
patterns.insert<StartDMATransferOp2DLowering>(dmaStart2D, typeConverter);
patterns.insert<WaitForDMATransfersOpLowering>(dmaWait, typeConverter);
patterns.insert<BarrierOpLowering>(barrier, typeConverter);

LLVMConversionTarget target(getContext());
15 changes: 14 additions & 1 deletion codegen/compiler/src/Quidditch/Dialect/Snitch/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -13,14 +13,18 @@ iree_cc_library(
"QuidditchSnitchDialect.h.inc"
"QuidditchSnitchOps.cpp.inc"
"QuidditchSnitchOps.h.inc"
"QuidditchSnitchTypes.cpp.inc"
"QuidditchSnitchTypes.h.inc"
SRCS
"QuidditchSnitchAttrs.cpp"
"QuidditchSnitchDialect.cpp"
"QuidditchSnitchOps.cpp"
"QuidditchSnitchTypes.cpp"
DEPS
::QuidditchSnitchAttrsGen
::QuidditchSnitchDialectGen
::QuidditchSnitchOpsGen
::QuidditchSnitchTypesGen
LLVMSupport
MLIRIR
MLIRInferTypeOpInterface
@@ -48,7 +52,6 @@ iree_tablegen_library(
--gen-dialect-defs QuidditchSnitchDialect.cpp.inc
)


iree_tablegen_library(
NAME
QuidditchSnitchAttrsGen
@@ -58,3 +61,13 @@ iree_tablegen_library(
--gen-attrdef-decls QuidditchSnitchAttrs.h.inc
--gen-attrdef-defs QuidditchSnitchAttrs.cpp.inc
)

iree_tablegen_library(
NAME
QuidditchSnitchTypesGen
TD_FILE
"QuidditchSnitchTypes.td"
OUTS
--gen-typedef-decls QuidditchSnitchTypes.h.inc
--gen-typedef-defs QuidditchSnitchTypes.cpp.inc
)
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@

#include "QuidditchSnitchAttrs.h"
#include "QuidditchSnitchOps.h"
#include "QuidditchSnitchTypes.h"

#include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.cpp.inc"

@@ -16,4 +17,8 @@ void QuidditchSnitchDialect::initialize() {
#define GET_ATTRDEF_LIST
#include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchAttrs.cpp.inc"
>();
addTypes<
#define GET_TYPEDEF_LIST
#include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchTypes.cpp.inc"
>();
}
Original file line number Diff line number Diff line change
@@ -13,6 +13,7 @@ def QuidditchSnitch_Dialect : Dialect {
);

let useDefaultAttributePrinterParser = 1;
let useDefaultTypePrinterParser = 1;
}

#endif
Original file line number Diff line number Diff line change
@@ -9,5 +9,7 @@
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "QuidditchSnitchTypes.h"

#define GET_OP_CLASSES
#include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.h.inc"
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@
#define QUIDDITCH_DIALECT_SNITCH_QUIDDITCHSNITCHOPS

include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.td"
include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchTypes.td"
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
include "mlir/IR/CommonTypeConstraints.td"
include "mlir/IR/OpBase.td"
@@ -121,6 +122,51 @@ def QuidditchSnitch_L1MemoryViewOp : QuidditchSnitch_Op<"l1_memory_view",
}];
}

def QuidditchSnitch_StartDMATransferOp : QuidditchSnitch_Op<"start_dma_transfer",
[MemoryEffects<[MemWrite]>, SameOperandsElementType, SameOperandsShape]> {

let description = [{
Operation performing a DMA transfer from one MemRef to another.
At least one of the two MemRefs must be in L1 memory.
The shapes (including dynamic ones at runtime) of both MemRefs must be
identical with different strides and offsets allowed.

The DMA operation is likely (but not guaranteed) to run asynchronous and
its completion only guaranteed by executing the `wait_for_dma_transfers`
operation with the token returned by this operation or a later one.
}];

// TODO: In reality what the constraint here is, is that both of them must either be contiguous or if a dimension is
// not contiguous in one of them, all dimensions prior to that must be contiguous (i.e. of equal size).
// Can support higher dimensions and funkier strides once needed.
let arguments = (ins
Arg<MemRefRankOf<[AnyType], [1, 2]>, "source", [MemRead]>:$source,
Arg<MemRefRankOf<[AnyType], [1, 2]>, "destination", [MemWrite]>:$dest
);

let results = (outs QuidditchSnitch_DMATokenType:$token);

let assemblyFormat = [{
`from` $source `:` type($source) `to` $dest `:` type($dest) attr-dict
}];
}

def QuidditchSnitch_WaitForDMATransfersOp
: QuidditchSnitch_Op<"wait_for_dma_transfers"> {

let description = [{
Operation awaiting for DMA transfers denoted by its tokens to be finished.
}];

let arguments = (ins
Variadic<QuidditchSnitch_DMATokenType>:$tokens
);

let assemblyFormat = [{
$tokens `:` type($tokens) attr-dict
}];
}

def QuidditchSnitch_BarrierOp : QuidditchSnitch_Op<"barrier"> {
let assemblyFormat = [{
attr-dict
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#include "QuidditchSnitchTypes.h"

#include "llvm/ADT/TypeSwitch.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"

#include "QuidditchSnitchDialect.h"

#define GET_TYPEDEF_CLASSES
#include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchTypes.cpp.inc"
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

#pragma once

#include "mlir/IR/Types.h"

#define GET_TYPEDEF_CLASSES
#include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchTypes.h.inc"
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#ifndef QUIDDITCH_DIALECT_SNITCH_QUIDDITCHSNITCHTYPES
#define QUIDDITCH_DIALECT_SNITCH_QUIDDITCHSNITCHTYPES

include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.td"
include "mlir/IR/AttrTypeBase.td"

class QuidditchSnitch_Type<string name, list<Trait> traits = []> :
TypeDef<QuidditchSnitch_Dialect, name, traits>;

def QuidditchSnitch_DMATokenType : QuidditchSnitch_Type<"DMAToken"> {
let mnemonic = "dma_token";

let description = [{
Type representing a potentially active DMA transfer.
}];
}

#endif
15 changes: 15 additions & 0 deletions codegen/tests/Conversion/ConvertSnitchToLLVM/dma_transfer_1d.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: quidditch-opt %s --quidditch-convert-snitch-to-llvm | FileCheck %s

// CHECK-LABEL: @test
func.func @test(%arg0 : memref<?xf32>, %arg1 : memref<?xf32>) -> !quidditch_snitch.dma_token {
// CHECK: %[[ARG0_PTR:.*]] = llvm.extractvalue %{{.*}}[1]
// CHECK: %[[ARG1_PTR:.*]] = llvm.extractvalue %{{.*}}[1]
// CHECK: %[[ARG0_SIZE:.*]] = llvm.extractvalue %{{.*}}[3, 0]
// CHECK: %[[GEP:.*]] = llvm.getelementptr %{{.*}}[%[[ARG0_SIZE]]]
// CHECK: %[[SIZE:.*]] = llvm.ptrtoint %[[GEP]]
// CHECK: %[[R:.*]] = llvm.call @snrt_dma_start_1d(%[[ARG1_PTR]], %[[ARG0_PTR]], %[[SIZE]])
%0 = quidditch_snitch.start_dma_transfer from %arg0 : memref<?xf32> to %arg1 : memref<?xf32>
// CHECK: %[[C:.*]] = builtin.unrealized_conversion_cast %[[R]]
// CHECK: return %[[C]]
return %0 : !quidditch_snitch.dma_token
}
9 changes: 9 additions & 0 deletions codegen/tests/Conversion/ConvertSnitchToLLVM/dma_wait.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// RUN: quidditch-opt %s --quidditch-convert-snitch-to-llvm | FileCheck %s

// CHECK-LABEL: @test
func.func @test(%arg0 : !quidditch_snitch.dma_token) {
// TODO: This should be a call to snrt_dma_wait but is currently bugged.
// CHECK: call @snrt_dma_wait_all()
quidditch_snitch.wait_for_dma_transfers %arg0 : !quidditch_snitch.dma_token
return
}

0 comments on commit 4900a27

Please sign in to comment.