From 4900a2781b0cf70e082d40ed92e6738780216eba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Markus=20B=C3=B6ck?= Date: Thu, 27 Jun 2024 19:05:52 +0100 Subject: [PATCH] [quidditch_snitch] Basic implementations of DMA operations This PR implements DMA copy and wait operations. Snitch's DMA is asynchronous and has special hardware support for faster transfers between L1 and L3 than doing it manually. The current lowering to LLVM simply performs calls into the snitch runtime for convenience with the idea of these being later inlined via LTO once implemented. Noteworthy is that the support for various memref shapes, strides and dimensions is currently very limited but enough for what IREE generates. --- .../Conversion/ConvertSnitchToLLVM.cpp | 146 ++++++++++++++++++ .../Dialect/Snitch/IR/CMakeLists.txt | 15 +- .../Snitch/IR/QuidditchSnitchDialect.cpp | 5 + .../Snitch/IR/QuidditchSnitchDialect.td | 1 + .../Dialect/Snitch/IR/QuidditchSnitchOps.h | 2 + .../Dialect/Snitch/IR/QuidditchSnitchOps.td | 46 ++++++ .../Snitch/IR/QuidditchSnitchTypes.cpp | 11 ++ .../Dialect/Snitch/IR/QuidditchSnitchTypes.h | 7 + .../Dialect/Snitch/IR/QuidditchSnitchTypes.td | 18 +++ .../ConvertSnitchToLLVM/dma_transfer_1d.mlir | 15 ++ .../ConvertSnitchToLLVM/dma_wait.mlir | 9 ++ 11 files changed, 274 insertions(+), 1 deletion(-) create mode 100644 codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchTypes.cpp create mode 100644 codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchTypes.h create mode 100644 codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchTypes.td create mode 100644 codegen/tests/Conversion/ConvertSnitchToLLVM/dma_transfer_1d.mlir create mode 100644 codegen/tests/Conversion/ConvertSnitchToLLVM/dma_wait.mlir diff --git a/codegen/compiler/src/Quidditch/Conversion/ConvertSnitchToLLVM.cpp b/codegen/compiler/src/Quidditch/Conversion/ConvertSnitchToLLVM.cpp index f212745..7568264 100644 --- a/codegen/compiler/src/Quidditch/Conversion/ConvertSnitchToLLVM.cpp 
+++ b/codegen/compiler/src/Quidditch/Conversion/ConvertSnitchToLLVM.cpp
@@ -62,6 +62,142 @@ struct L1MemoryViewOpLowering : ConvertOpToLLVMPattern<L1MemoryViewOp> {
   }
 };
 
+/// Lowers `start_dma_transfer` on MemRefs with identity layouts to a single
+/// call to `dmaStart1DFunc`, passing the total size of the source MemRef. Uses
+/// benefit 2 and is therefore tried before the 2D lowering (default benefit).
+struct StartDMATransferOp1DLowering
+    : ConvertOpToLLVMPattern<StartDMATransferOp> {
+
+  LLVM::LLVMFuncOp dmaStart1DFunc;
+
+  StartDMATransferOp1DLowering(LLVM::LLVMFuncOp dmaStart1DFunc,
+                               const LLVMTypeConverter &converter)
+      : ConvertOpToLLVMPattern(converter, /*benefit=*/2),
+        dmaStart1DFunc(dmaStart1DFunc) {
+    setHasBoundedRewriteRecursion();
+  }
+
+  /// Only matches if both MemRefs have an identity (or no explicit) layout.
+  LogicalResult match(StartDMATransferOp op) const override {
+    MemRefLayoutAttrInterface sourceLayout =
+        op.getSource().getType().getLayout();
+    MemRefLayoutAttrInterface destLayout = op.getDest().getType().getLayout();
+    if (sourceLayout && !sourceLayout.isIdentity())
+      return failure();
+
+    if (destLayout && !destLayout.isIdentity())
+      return failure();
+
+    return success();
+  }
+
+  /// Replaces `op` with a call `dmaStart1DFunc(dest, source, totalSize)`.
+  void rewrite(StartDMATransferOp op, OpAdaptor adaptor,
+               ConversionPatternRewriter &rewriter) const override {
+    MemRefDescriptor sourceDescriptor(adaptor.getSource());
+    MemRefDescriptor destDescriptor(adaptor.getDest());
+
+    Value source = sourceDescriptor.bufferPtr(
+        rewriter, op->getLoc(), *getTypeConverter(), op.getSource().getType());
+    // The destination pointer must be computed from the destination's type.
+    Value dest = destDescriptor.bufferPtr(
+        rewriter, op->getLoc(), *getTypeConverter(), op.getDest().getType());
+
+    MemRefType sourceMemRef = op.getSource().getType();
+    SmallVector<Value> dynamicSizes;
+    // Descriptor sizes are addressed by dimension index, not by extent.
+    for (unsigned index = 0, rank = sourceMemRef.getRank(); index < rank;
+         ++index)
+      if (sourceMemRef.isDynamicDim(index))
+        dynamicSizes.push_back(
+            sourceDescriptor.size(rewriter, op->getLoc(), index));
+
+    SmallVector<Value> sizes;
+    SmallVector<Value> strides;
+    Value totalSize;
+    getMemRefDescriptorSizes(op->getLoc(), op.getSource().getType(),
+                             dynamicSizes, rewriter, sizes, strides, totalSize);
+
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, dmaStart1DFunc,
+                                              ValueRange{
+                                                  dest,
+                                                  source,
+                                                  totalSize,
+                                              });
+  }
+};
+
+/// Lowers `start_dma_transfer` on rank 2 MemRefs to a call to
+/// `dmaStart2DFunc`; MemRefs of any other rank are rejected.
+struct StartDMATransferOp2DLowering
+    : ConvertOpToLLVMPattern<StartDMATransferOp> {
+
+  LLVM::LLVMFuncOp dmaStart2DFunc;
+
+  StartDMATransferOp2DLowering(LLVM::LLVMFuncOp dmaStart2DFunc,
+                               const LLVMTypeConverter &converter)
+      : ConvertOpToLLVMPattern(converter), dmaStart2DFunc(dmaStart2DFunc) {}
+
+  LogicalResult
+  matchAndRewrite(StartDMATransferOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (op.getSource().getType().getRank() != 2)
+      return failure();
+
+    MemRefDescriptor sourceDescriptor(adaptor.getSource());
+    MemRefDescriptor destDescriptor(adaptor.getDest());
+
+    Value source = sourceDescriptor.bufferPtr(
+        rewriter, op->getLoc(), *getTypeConverter(), op.getSource().getType());
+    // The destination pointer must be computed from the destination's type.
+    Value dest = destDescriptor.bufferPtr(
+        rewriter, op->getLoc(), *getTypeConverter(), op.getDest().getType());
+
+    // NOTE(review): the transfer size is derived from dimension 0 and the
+    // repeat count from dimension 1, while both strides are those of
+    // dimension 1 and are presumably in elements rather than bytes. Confirm
+    // this against the argument contract of `snrt_dma_start_2d`.
+    Value size = rewriter.create<LLVM::ConstantOp>(
+        op->getLoc(),
+        rewriter.getI32IntegerAttr(llvm::divideCeil(
+            op.getSource().getType().getElementTypeBitWidth(), 8)));
+    size = rewriter.create<LLVM::MulOp>(
+        op->getLoc(), size, sourceDescriptor.size(rewriter, op->getLoc(), 0));
+
+    Value sourceStride = sourceDescriptor.stride(rewriter, op->getLoc(), 1);
+    Value destStride = destDescriptor.stride(rewriter, op->getLoc(), 1);
+
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
+        op, dmaStart2DFunc,
+        ValueRange{dest, source, size, destStride, sourceStride,
+                   sourceDescriptor.size(rewriter, op->getLoc(), 1)});
+    return success();
+  }
+};
+
+/// Lowers `wait_for_dma_transfers` to a call to `waitFunc`. The token operands
+/// are currently not passed on (see the TODO in `matchAndRewrite`).
+struct WaitForDMATransfersOpLowering
+    : ConvertOpToLLVMPattern<WaitForDMATransfersOp> {
+
+  LLVM::LLVMFuncOp waitFunc;
+
+  WaitForDMATransfersOpLowering(LLVM::LLVMFuncOp waitFunc,
+                                const LLVMTypeConverter &converter)
+      : ConvertOpToLLVMPattern(converter), waitFunc(waitFunc) {}
+
+  LogicalResult
+  matchAndRewrite(WaitForDMATransfersOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // TODO: This should wait for only a specific transfer not all.
+    // for (Value token : adaptor.getTokens())
+    //   rewriter.create<LLVM::CallOp>(op->getLoc(), waitFunc, token);
+    rewriter.create<LLVM::CallOp>(op->getLoc(), waitFunc, ValueRange());
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 struct BarrierOpLowering : ConvertOpToLLVMPattern<BarrierOp> {
 
   LLVM::LLVMFuncOp barrierFunc;
@@ -87,11 +223,37 @@ void ConvertSnitchToLLVM::runOnOperation() {
   // TODO: This is horribly hardcoded when it shouldn't be.
   options.overrideIndexBitwidth(32);
   LLVMTypeConverter typeConverter(&getContext(), options, &dataLayoutAnalysis);
+  // DMA tokens are lowered to 32-bit integers, matching the return type of
+  // the DMA start functions declared below.
+  typeConverter.addConversion([](DMATokenType token) {
+    return IntegerType::get(token.getContext(), 32);
+  });
 
   auto builder = OpBuilder::atBlockEnd(getOperation().getBody());
   auto ptrType = builder.getType<LLVM::LLVMPointerType>();
   IntegerType i32 = builder.getI32Type();
   IntegerType sizeT = i32;
+  auto dmaStart1D = builder.create<LLVM::LLVMFuncOp>(
+      builder.getUnknownLoc(), "snrt_dma_start_1d",
+      LLVM::LLVMFunctionType::get(i32,
+                                  ArrayRef<Type>{ptrType, ptrType, sizeT}));
+  dmaStart1D->setAttr("hal.import.bitcode", builder.getUnitAttr());
+
+  auto dmaStart2D = builder.create<LLVM::LLVMFuncOp>(
+      builder.getUnknownLoc(), "snrt_dma_start_2d",
+      LLVM::LLVMFunctionType::get(
+          i32, ArrayRef<Type>{ptrType, ptrType, sizeT, sizeT, sizeT, sizeT}));
+  dmaStart2D->setAttr("hal.import.bitcode", builder.getUnitAttr());
+
+  // TODO: This should wait for only a specific transfer not all.
+  //  This is currently bugged in the snitch_cluster repo and potentially
+  //  the hardware.
+  auto dmaWait = builder.create<LLVM::LLVMFuncOp>(
+      builder.getUnknownLoc(), "snrt_dma_wait_all",
+      LLVM::LLVMFunctionType::get(builder.getType<LLVM::LLVMVoidType>(),
+                                  ArrayRef<Type>{}));
+  dmaWait->setAttr("hal.import.bitcode", builder.getUnitAttr());
+
   auto barrier = builder.create<LLVM::LLVMFuncOp>(
       builder.getUnknownLoc(), "snrt_cluster_hw_barrier",
       LLVM::LLVMFunctionType::get(builder.getType<LLVM::LLVMVoidType>(),
@@ -100,6 +262,9 @@ void ConvertSnitchToLLVM::runOnOperation() {
 
   RewritePatternSet patterns(&getContext());
   patterns.insert<L1MemoryViewOpLowering>(typeConverter);
+  patterns.insert<StartDMATransferOp1DLowering>(dmaStart1D, typeConverter);
+  patterns.insert<StartDMATransferOp2DLowering>(dmaStart2D, typeConverter);
+  patterns.insert<WaitForDMATransfersOpLowering>(dmaWait, typeConverter);
   patterns.insert<BarrierOpLowering>(barrier, typeConverter);
 
   LLVMConversionTarget target(getContext());
diff --git a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/CMakeLists.txt b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/CMakeLists.txt
index 1b2460e..23e1a60 100644
--- a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/CMakeLists.txt
+++ b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/CMakeLists.txt
@@ -13,14 +13,18 @@ iree_cc_library(
     "QuidditchSnitchDialect.h.inc"
     "QuidditchSnitchOps.cpp.inc"
     "QuidditchSnitchOps.h.inc"
+    "QuidditchSnitchTypes.cpp.inc"
+    "QuidditchSnitchTypes.h.inc"
   SRCS
     "QuidditchSnitchAttrs.cpp"
     "QuidditchSnitchDialect.cpp"
     "QuidditchSnitchOps.cpp"
+    "QuidditchSnitchTypes.cpp"
   DEPS
     ::QuidditchSnitchAttrsGen
     ::QuidditchSnitchDialectGen
     ::QuidditchSnitchOpsGen
+    ::QuidditchSnitchTypesGen
     LLVMSupport
     MLIRIR
     MLIRInferTypeOpInterface
@@ -48,7 +52,6 @@ iree_tablegen_library(
     --gen-dialect-defs QuidditchSnitchDialect.cpp.inc
 )
 
-
 iree_tablegen_library(
   NAME
     QuidditchSnitchAttrsGen
@@ -58,3 +61,13 @@ iree_tablegen_library(
     --gen-attrdef-decls QuidditchSnitchAttrs.h.inc
     --gen-attrdef-defs QuidditchSnitchAttrs.cpp.inc
 )
+
+iree_tablegen_library(
+  NAME
+    QuidditchSnitchTypesGen
+  TD_FILE
+    "QuidditchSnitchTypes.td"
+  OUTS
+    --gen-typedef-decls QuidditchSnitchTypes.h.inc
+    --gen-typedef-defs QuidditchSnitchTypes.cpp.inc
+)
diff --git
a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.cpp b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.cpp index 738d09b..c4b023b 100644 --- a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.cpp +++ b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.cpp @@ -2,6 +2,7 @@ #include "QuidditchSnitchAttrs.h" #include "QuidditchSnitchOps.h" +#include "QuidditchSnitchTypes.h" #include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.cpp.inc" @@ -16,4 +17,8 @@ void QuidditchSnitchDialect::initialize() { #define GET_ATTRDEF_LIST #include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchAttrs.cpp.inc" >(); + addTypes< +#define GET_TYPEDEF_LIST +#include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchTypes.cpp.inc" + >(); } diff --git a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.td b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.td index 8c95d6c..175b6fc 100644 --- a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.td +++ b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.td @@ -13,6 +13,7 @@ def QuidditchSnitch_Dialect : Dialect { ); let useDefaultAttributePrinterParser = 1; + let useDefaultTypePrinterParser = 1; } #endif diff --git a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.h b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.h index c321a71..27c049e 100644 --- a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.h +++ b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.h @@ -9,5 +9,7 @@ #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "QuidditchSnitchTypes.h" + #define GET_OP_CLASSES #include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.h.inc" diff --git a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.td 
b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.td index c2a1922..575ab2b 100644 --- a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.td +++ b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.td @@ -2,6 +2,7 @@ #define QUIDDITCH_DIALECT_SNITCH_QUIDDITCHSNITCHOPS include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.td" +include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchTypes.td" include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td" include "mlir/IR/CommonTypeConstraints.td" include "mlir/IR/OpBase.td" @@ -121,6 +122,51 @@ def QuidditchSnitch_L1MemoryViewOp : QuidditchSnitch_Op<"l1_memory_view", }]; } +def QuidditchSnitch_StartDMATransferOp : QuidditchSnitch_Op<"start_dma_transfer", + [MemoryEffects<[MemWrite]>, SameOperandsElementType, SameOperandsShape]> { + + let description = [{ + Operation performing a DMA transfer from one MemRef to another. + At least one of the two MemRefs must be in L1 memory. + The shapes (including dynamic ones at runtime) of both MemRefs must be + identical with different strides and offsets allowed. + + The DMA operation is likely (but not guaranteed) to run asynchronous and + its completion only guaranteed by executing the `wait_for_dma_transfers` + operation with the token returned by this operation or a later one. + }]; + + // TODO: The actual constraint is that both MemRefs must either be contiguous or, if a dimension is + // not contiguous in one of them, all dimensions prior to that one must be contiguous (i.e. of equal size). + // Higher dimensions and more unusual strides can be supported once needed. 
+ let arguments = (ins + Arg, "source", [MemRead]>:$source, + Arg, "destination", [MemWrite]>:$dest + ); + + let results = (outs QuidditchSnitch_DMATokenType:$token); + + let assemblyFormat = [{ + `from` $source `:` type($source) `to` $dest `:` type($dest) attr-dict + }]; +} + +def QuidditchSnitch_WaitForDMATransfersOp + : QuidditchSnitch_Op<"wait_for_dma_transfers"> { + + let description = [{ + Operation awaiting for DMA transfers denoted by its tokens to be finished. + }]; + + let arguments = (ins + Variadic:$tokens + ); + + let assemblyFormat = [{ + $tokens `:` type($tokens) attr-dict + }]; +} + def QuidditchSnitch_BarrierOp : QuidditchSnitch_Op<"barrier"> { let assemblyFormat = [{ attr-dict diff --git a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchTypes.cpp b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchTypes.cpp new file mode 100644 index 0000000..d4bcc5d --- /dev/null +++ b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchTypes.cpp @@ -0,0 +1,11 @@ +#include "QuidditchSnitchTypes.h" + +#include "llvm/ADT/TypeSwitch.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" + +#include "QuidditchSnitchDialect.h" + +#define GET_TYPEDEF_CLASSES +#include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchTypes.cpp.inc" diff --git a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchTypes.h b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchTypes.h new file mode 100644 index 0000000..9547332 --- /dev/null +++ b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchTypes.h @@ -0,0 +1,7 @@ + +#pragma once + +#include "mlir/IR/Types.h" + +#define GET_TYPEDEF_CLASSES +#include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchTypes.h.inc" diff --git a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchTypes.td b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchTypes.td new file mode 100644 index 
0000000..42a3bfc --- /dev/null +++ b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchTypes.td @@ -0,0 +1,18 @@ +#ifndef QUIDDITCH_DIALECT_SNITCH_QUIDDITCHSNITCHTYPES +#define QUIDDITCH_DIALECT_SNITCH_QUIDDITCHSNITCHTYPES + +include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.td" +include "mlir/IR/AttrTypeBase.td" + +class QuidditchSnitch_Type traits = []> : + TypeDef; + +def QuidditchSnitch_DMATokenType : QuidditchSnitch_Type<"DMAToken"> { + let mnemonic = "dma_token"; + + let description = [{ + Type representing a potentially active DMA transfer. + }]; +} + +#endif diff --git a/codegen/tests/Conversion/ConvertSnitchToLLVM/dma_transfer_1d.mlir b/codegen/tests/Conversion/ConvertSnitchToLLVM/dma_transfer_1d.mlir new file mode 100644 index 0000000..5d6fd52 --- /dev/null +++ b/codegen/tests/Conversion/ConvertSnitchToLLVM/dma_transfer_1d.mlir @@ -0,0 +1,15 @@ +// RUN: quidditch-opt %s --quidditch-convert-snitch-to-llvm | FileCheck %s + +// CHECK-LABEL: @test +func.func @test(%arg0 : memref, %arg1 : memref) -> !quidditch_snitch.dma_token { + // CHECK: %[[ARG0_PTR:.*]] = llvm.extractvalue %{{.*}}[1] + // CHECK: %[[ARG1_PTR:.*]] = llvm.extractvalue %{{.*}}[1] + // CHECK: %[[ARG0_SIZE:.*]] = llvm.extractvalue %{{.*}}[3, 0] + // CHECK: %[[GEP:.*]] = llvm.getelementptr %{{.*}}[%[[ARG0_SIZE]]] + // CHECK: %[[SIZE:.*]] = llvm.ptrtoint %[[GEP]] + // CHECK: %[[R:.*]] = llvm.call @snrt_dma_start_1d(%[[ARG1_PTR]], %[[ARG0_PTR]], %[[SIZE]]) + %0 = quidditch_snitch.start_dma_transfer from %arg0 : memref to %arg1 : memref + // CHECK: %[[C:.*]] = builtin.unrealized_conversion_cast %[[R]] + // CHECK: return %[[C]] + return %0 : !quidditch_snitch.dma_token +} diff --git a/codegen/tests/Conversion/ConvertSnitchToLLVM/dma_wait.mlir b/codegen/tests/Conversion/ConvertSnitchToLLVM/dma_wait.mlir new file mode 100644 index 0000000..2747d5e --- /dev/null +++ b/codegen/tests/Conversion/ConvertSnitchToLLVM/dma_wait.mlir @@ -0,0 +1,9 @@ +// RUN: quidditch-opt %s 
--quidditch-convert-snitch-to-llvm | FileCheck %s + +// CHECK-LABEL: @test +func.func @test(%arg0 : !quidditch_snitch.dma_token) { + // TODO: This should be a call to snrt_dma_wait for the given token, but that is currently bugged. + // CHECK: call @snrt_dma_wait_all() + quidditch_snitch.wait_for_dma_transfers %arg0 : !quidditch_snitch.dma_token + return +}