Skip to content

Commit

Permalink
[quidditch_snitch] Basic implementations of DMA operations (#54)
Browse files Browse the repository at this point in the history
This PR implements DMA copy and wait operations. Snitch's DMA is
asnychronous and has special hardware support for faster transfers
between L1 and L3 than doing it manually. The currentl lowering to LLVM
simple performs calls into the snitch runtime for convenience with the
idea of these being later inlined via LTO once implemented.

Noteworthy is that the support for various memref shapes, strides and
dimension is currently very limited but enough for what IREE generates.
  • Loading branch information
zero9178 authored Jun 27, 2024
1 parent e7f66ef commit 60afb0c
Show file tree
Hide file tree
Showing 11 changed files with 274 additions and 1 deletion.
146 changes: 146 additions & 0 deletions codegen/compiler/src/Quidditch/Conversion/ConvertSnitchToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,125 @@ struct L1MemoryViewOpLowering : ConvertOpToLLVMPattern<L1MemoryViewOp> {
}
};

struct StartDMATransferOp1DLowering
: ConvertOpToLLVMPattern<StartDMATransferOp> {

LLVM::LLVMFuncOp dmaStart1DFunc;

StartDMATransferOp1DLowering(LLVM::LLVMFuncOp dmaStart1DFunc,
const LLVMTypeConverter &converter)
: ConvertOpToLLVMPattern(converter, /*benefit=*/2),
dmaStart1DFunc(dmaStart1DFunc) {
setHasBoundedRewriteRecursion();
}

LogicalResult match(StartDMATransferOp op) const override {
MemRefLayoutAttrInterface sourceLayout =
op.getSource().getType().getLayout();
MemRefLayoutAttrInterface destLayout = op.getDest().getType().getLayout();
if (sourceLayout && !sourceLayout.isIdentity())
return failure();

if (destLayout && !destLayout.isIdentity())
return failure();

return success();
}

void rewrite(StartDMATransferOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MemRefDescriptor sourceDescriptor(adaptor.getSource());
MemRefDescriptor destDescriptor(adaptor.getDest());

Value source = sourceDescriptor.bufferPtr(
rewriter, op->getLoc(), *getTypeConverter(), op.getSource().getType());
Value dest = destDescriptor.bufferPtr(
rewriter, op->getLoc(), *getTypeConverter(), op.getSource().getType());

MemRefType sourceMemRef = op.getSource().getType();
SmallVector<Value> dynamicSizes;
for (std::int64_t dim : sourceMemRef.getShape())
if (ShapedType::isDynamic(dim))
dynamicSizes.push_back(
sourceDescriptor.size(rewriter, op->getLoc(), dim));

SmallVector<Value> sizes;
SmallVector<Value> strides;
Value totalSize;
getMemRefDescriptorSizes(op->getLoc(), op.getSource().getType(),
dynamicSizes, rewriter, sizes, strides, totalSize);

rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, dmaStart1DFunc,
ValueRange{
dest,
source,
totalSize,
});
}
};

struct StartDMATransferOp2DLowering
: ConvertOpToLLVMPattern<StartDMATransferOp> {

LLVM::LLVMFuncOp dmaStart2DFunc;

StartDMATransferOp2DLowering(LLVM::LLVMFuncOp dmaStart2DFunc,
const LLVMTypeConverter &converter)
: ConvertOpToLLVMPattern(converter), dmaStart2DFunc(dmaStart2DFunc) {}

LogicalResult
matchAndRewrite(StartDMATransferOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (op.getSource().getType().getRank() != 2)
return failure();

MemRefDescriptor sourceDescriptor(adaptor.getSource());
MemRefDescriptor destDescriptor(adaptor.getDest());

Value source = sourceDescriptor.bufferPtr(
rewriter, op->getLoc(), *getTypeConverter(), op.getSource().getType());
Value dest = destDescriptor.bufferPtr(
rewriter, op->getLoc(), *getTypeConverter(), op.getSource().getType());

Value size = rewriter.create<LLVM::ConstantOp>(
op->getLoc(),
rewriter.getI32IntegerAttr(llvm::divideCeil(
op.getSource().getType().getElementTypeBitWidth(), 8)));
size = rewriter.create<LLVM::MulOp>(
op->getLoc(), size, sourceDescriptor.size(rewriter, op->getLoc(), 0));

Value sourceStride = sourceDescriptor.stride(rewriter, op->getLoc(), 1);
Value destStride = destDescriptor.stride(rewriter, op->getLoc(), 1);

rewriter.replaceOpWithNewOp<LLVM::CallOp>(
op, dmaStart2DFunc,
ValueRange{dest, source, size, destStride, sourceStride,
sourceDescriptor.size(rewriter, op->getLoc(), 1)});
return success();
}
};

struct WaitForDMATransfersOpLowering
: ConvertOpToLLVMPattern<WaitForDMATransfersOp> {

LLVM::LLVMFuncOp waitFunc;

WaitForDMATransfersOpLowering(LLVM::LLVMFuncOp waitFunc,
const LLVMTypeConverter &converter)
: ConvertOpToLLVMPattern(converter), waitFunc(waitFunc) {}

LogicalResult
matchAndRewrite(WaitForDMATransfersOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// TODO: This should wait for only a specific transfer not all.
// for (Value token : adaptor.getTokens())
// rewriter.create<LLVM::CallOp>(op->getLoc(), waitFunc, token);
rewriter.create<LLVM::CallOp>(op->getLoc(), waitFunc, ValueRange());
rewriter.eraseOp(op);
return success();
}
};

struct BarrierOpLowering : ConvertOpToLLVMPattern<BarrierOp> {

LLVM::LLVMFuncOp barrierFunc;
Expand All @@ -87,11 +206,35 @@ void ConvertSnitchToLLVM::runOnOperation() {
// TODO: This is horribly hardcoded when it shouldn't be.
options.overrideIndexBitwidth(32);
LLVMTypeConverter typeConverter(&getContext(), options, &dataLayoutAnalysis);
typeConverter.addConversion([](DMATokenType token) {
return IntegerType::get(token.getContext(), 32);
});

auto builder = OpBuilder::atBlockEnd(getOperation().getBody());
auto ptrType = builder.getType<LLVM::LLVMPointerType>();
IntegerType i32 = builder.getI32Type();
IntegerType sizeT = i32;
auto dmaStart1D = builder.create<LLVM::LLVMFuncOp>(
builder.getUnknownLoc(), "snrt_dma_start_1d",
LLVM::LLVMFunctionType::get(i32,
ArrayRef<Type>{ptrType, ptrType, sizeT}));
dmaStart1D->setAttr("hal.import.bitcode", builder.getUnitAttr());

auto dmaStart2D = builder.create<LLVM::LLVMFuncOp>(
builder.getUnknownLoc(), "snrt_dma_start_2d",
LLVM::LLVMFunctionType::get(
i32, ArrayRef<Type>{ptrType, ptrType, sizeT, sizeT, sizeT, sizeT}));
dmaStart2D->setAttr("hal.import.bitcode", builder.getUnitAttr());

// TODO: This should wait for only a specific transfer not all.
// This is currently bugged in the snitch_cluster repo and potentially
// the hardware.
auto dmaWait = builder.create<LLVM::LLVMFuncOp>(
builder.getUnknownLoc(), "snrt_dma_wait_all",
LLVM::LLVMFunctionType::get(builder.getType<LLVM::LLVMVoidType>(),
ArrayRef<Type>{}));
dmaWait->setAttr("hal.import.bitcode", builder.getUnitAttr());

auto barrier = builder.create<LLVM::LLVMFuncOp>(
builder.getUnknownLoc(), "snrt_cluster_hw_barrier",
LLVM::LLVMFunctionType::get(builder.getType<LLVM::LLVMVoidType>(),
Expand All @@ -100,6 +243,9 @@ void ConvertSnitchToLLVM::runOnOperation() {

RewritePatternSet patterns(&getContext());
patterns.insert<L1MemoryViewOpLowering>(typeConverter);
patterns.insert<StartDMATransferOp1DLowering>(dmaStart1D, typeConverter);
patterns.insert<StartDMATransferOp2DLowering>(dmaStart2D, typeConverter);
patterns.insert<WaitForDMATransfersOpLowering>(dmaWait, typeConverter);
patterns.insert<BarrierOpLowering>(barrier, typeConverter);

LLVMConversionTarget target(getContext());
Expand Down
15 changes: 14 additions & 1 deletion codegen/compiler/src/Quidditch/Dialect/Snitch/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,18 @@ iree_cc_library(
"QuidditchSnitchDialect.h.inc"
"QuidditchSnitchOps.cpp.inc"
"QuidditchSnitchOps.h.inc"
"QuidditchSnitchTypes.cpp.inc"
"QuidditchSnitchTypes.h.inc"
SRCS
"QuidditchSnitchAttrs.cpp"
"QuidditchSnitchDialect.cpp"
"QuidditchSnitchOps.cpp"
"QuidditchSnitchTypes.cpp"
DEPS
::QuidditchSnitchAttrsGen
::QuidditchSnitchDialectGen
::QuidditchSnitchOpsGen
::QuidditchSnitchTypesGen
LLVMSupport
MLIRIR
MLIRInferTypeOpInterface
Expand Down Expand Up @@ -48,7 +52,6 @@ iree_tablegen_library(
--gen-dialect-defs QuidditchSnitchDialect.cpp.inc
)


iree_tablegen_library(
NAME
QuidditchSnitchAttrsGen
Expand All @@ -58,3 +61,13 @@ iree_tablegen_library(
--gen-attrdef-decls QuidditchSnitchAttrs.h.inc
--gen-attrdef-defs QuidditchSnitchAttrs.cpp.inc
)

iree_tablegen_library(
NAME
QuidditchSnitchTypesGen
TD_FILE
"QuidditchSnitchTypes.td"
OUTS
--gen-typedef-decls QuidditchSnitchTypes.h.inc
--gen-typedef-defs QuidditchSnitchTypes.cpp.inc
)
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "QuidditchSnitchAttrs.h"
#include "QuidditchSnitchOps.h"
#include "QuidditchSnitchTypes.h"

#include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.cpp.inc"

Expand All @@ -16,4 +17,8 @@ void QuidditchSnitchDialect::initialize() {
#define GET_ATTRDEF_LIST
#include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchAttrs.cpp.inc"
>();
addTypes<
#define GET_TYPEDEF_LIST
#include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchTypes.cpp.inc"
>();
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def QuidditchSnitch_Dialect : Dialect {
);

let useDefaultAttributePrinterParser = 1;
let useDefaultTypePrinterParser = 1;
}

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,7 @@
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "QuidditchSnitchTypes.h"

#define GET_OP_CLASSES
#include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.h.inc"
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define QUIDDITCH_DIALECT_SNITCH_QUIDDITCHSNITCHOPS

include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.td"
include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchTypes.td"
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
include "mlir/IR/CommonTypeConstraints.td"
include "mlir/IR/OpBase.td"
Expand Down Expand Up @@ -121,6 +122,51 @@ def QuidditchSnitch_L1MemoryViewOp : QuidditchSnitch_Op<"l1_memory_view",
}];
}

def QuidditchSnitch_StartDMATransferOp : QuidditchSnitch_Op<"start_dma_transfer",
[MemoryEffects<[MemWrite]>, SameOperandsElementType, SameOperandsShape]> {

let description = [{
Operation performing a DMA transfer from one MemRef to another.
At least one of the two MemRefs must be in L1 memory.
The shapes (including dynamic ones at runtime) of both MemRefs must be
identical with different strides and offsets allowed.

The DMA operation is likely (but not guaranteed) to run asynchronous and
its completion only guaranteed by executing the `wait_for_dma_transfers`
operation with the token returned by this operation or a later one.
}];

// TODO: In reality what the constraint here is, is that both of them must either be contiguous or if a dimension is
// not contiguous in one of them, all dimensions prior to that must be contiguous (i.e. of equal size).
// Can support higher dimensions and funkier strides once needed.
let arguments = (ins
Arg<MemRefRankOf<[AnyType], [1, 2]>, "source", [MemRead]>:$source,
Arg<MemRefRankOf<[AnyType], [1, 2]>, "destination", [MemWrite]>:$dest
);

let results = (outs QuidditchSnitch_DMATokenType:$token);

let assemblyFormat = [{
`from` $source `:` type($source) `to` $dest `:` type($dest) attr-dict
}];
}

def QuidditchSnitch_WaitForDMATransfersOp
: QuidditchSnitch_Op<"wait_for_dma_transfers"> {

let description = [{
Operation awaiting for DMA transfers denoted by its tokens to be finished.
}];

let arguments = (ins
Variadic<QuidditchSnitch_DMATokenType>:$tokens
);

let assemblyFormat = [{
$tokens `:` type($tokens) attr-dict
}];
}

def QuidditchSnitch_BarrierOp : QuidditchSnitch_Op<"barrier"> {
let assemblyFormat = [{
attr-dict
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#include "QuidditchSnitchTypes.h"

#include "llvm/ADT/TypeSwitch.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"

#include "QuidditchSnitchDialect.h"

#define GET_TYPEDEF_CLASSES
#include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchTypes.cpp.inc"
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

#pragma once

#include "mlir/IR/Types.h"

#define GET_TYPEDEF_CLASSES
#include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchTypes.h.inc"
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#ifndef QUIDDITCH_DIALECT_SNITCH_QUIDDITCHSNITCHTYPES
#define QUIDDITCH_DIALECT_SNITCH_QUIDDITCHSNITCHTYPES

include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.td"
include "mlir/IR/AttrTypeBase.td"

class QuidditchSnitch_Type<string name, list<Trait> traits = []> :
TypeDef<QuidditchSnitch_Dialect, name, traits>;

def QuidditchSnitch_DMATokenType : QuidditchSnitch_Type<"DMAToken"> {
let mnemonic = "dma_token";

let description = [{
Type representing a potentially active DMA transfer.
}];
}

#endif
15 changes: 15 additions & 0 deletions codegen/tests/Conversion/ConvertSnitchToLLVM/dma_transfer_1d.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: quidditch-opt %s --quidditch-convert-snitch-to-llvm | FileCheck %s

// CHECK-LABEL: @test
func.func @test(%arg0 : memref<?xf32>, %arg1 : memref<?xf32>) -> !quidditch_snitch.dma_token {
// CHECK: %[[ARG0_PTR:.*]] = llvm.extractvalue %{{.*}}[1]
// CHECK: %[[ARG1_PTR:.*]] = llvm.extractvalue %{{.*}}[1]
// CHECK: %[[ARG0_SIZE:.*]] = llvm.extractvalue %{{.*}}[3, 0]
// CHECK: %[[GEP:.*]] = llvm.getelementptr %{{.*}}[%[[ARG0_SIZE]]]
// CHECK: %[[SIZE:.*]] = llvm.ptrtoint %[[GEP]]
// CHECK: %[[R:.*]] = llvm.call @snrt_dma_start_1d(%[[ARG1_PTR]], %[[ARG0_PTR]], %[[SIZE]])
%0 = quidditch_snitch.start_dma_transfer from %arg0 : memref<?xf32> to %arg1 : memref<?xf32>
// CHECK: %[[C:.*]] = builtin.unrealized_conversion_cast %[[R]]
// CHECK: return %[[C]]
return %0 : !quidditch_snitch.dma_token
}
9 changes: 9 additions & 0 deletions codegen/tests/Conversion/ConvertSnitchToLLVM/dma_wait.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// RUN: quidditch-opt %s --quidditch-convert-snitch-to-llvm | FileCheck %s

// CHECK-LABEL: @test
func.func @test(%arg0 : !quidditch_snitch.dma_token) {
// TODO: This should be a call to snrt_dma_wait but is currently bugged.
// CHECK: call @snrt_dma_wait_all()
quidditch_snitch.wait_for_dma_transfers %arg0 : !quidditch_snitch.dma_token
return
}

0 comments on commit 60afb0c

Please sign in to comment.