From 1b4fd91f547857a3621a6837a7ff8eb1fb9f7ef9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Markus=20B=C3=B6ck?= Date: Sat, 31 Aug 2024 22:37:02 +0100 Subject: [PATCH] [DMA] Split DMA operations into its own dialect (#127) `quidditch_snitch` has turned into a bit of a god-dialect which even contains things not specific to snitch. This PR is the first step in reducing this by splitting everything DMA related out of the dialect. Future goal is to have lowerings of just DMA operations to lower-level hardware specific dialects where hardware specific optimizations and legalizations can occur. --- .../src/Quidditch/Conversion/CMakeLists.txt | 15 + .../Quidditch/Conversion/ConvertDMAToLLVM.cpp | 484 ++++++++++++++++++ .../Quidditch/Conversion/ConvertDMAToLLVM.h | 10 + .../Conversion/ConvertSnitchToLLVM.cpp | 470 +---------------- .../src/Quidditch/Dialect/DMA/CMakeLists.txt | 1 + .../Dialect/DMA/Extensions/CMakeLists.txt | 12 + .../DMACoreSpecializationOpInterfaceImpl.cpp | 76 +++ .../DMACoreSpecializationOpInterfaceImpl.h | 10 + .../Quidditch/Dialect/DMA/IR/CMakeLists.txt | 73 +++ .../src/Quidditch/Dialect/DMA/IR/DMAAttrs.cpp | 1 + .../src/Quidditch/Dialect/DMA/IR/DMAAttrs.h | 7 + .../src/Quidditch/Dialect/DMA/IR/DMAAttrs.td | 20 + .../Quidditch/Dialect/DMA/IR/DMADialect.cpp | 44 ++ .../src/Quidditch/Dialect/DMA/IR/DMADialect.h | 7 + .../Quidditch/Dialect/DMA/IR/DMADialect.td | 15 + .../src/Quidditch/Dialect/DMA/IR/DMAOps.cpp | 434 ++++++++++++++++ .../src/Quidditch/Dialect/DMA/IR/DMAOps.h | 14 + .../src/Quidditch/Dialect/DMA/IR/DMAOps.td | 206 ++++++++ .../src/Quidditch/Dialect/DMA/IR/DMATypes.cpp | 11 + .../src/Quidditch/Dialect/DMA/IR/DMATypes.h | 7 + .../src/Quidditch/Dialect/DMA/IR/DMATypes.td | 18 + .../Dialect/Snitch/IR/QuidditchSnitchAttrs.td | 10 - .../Snitch/IR/QuidditchSnitchDialect.cpp | 10 - .../Snitch/IR/QuidditchSnitchDialect.td | 4 +- .../Dialect/Snitch/IR/QuidditchSnitchOps.cpp | 410 --------------- .../Dialect/Snitch/IR/QuidditchSnitchOps.td | 209 -------- .../Dialect/Snitch/IR/QuidditchSnitchTypes.td | 8 - .../Dialect/Snitch/Transforms/CMakeLists.txt | 1 + .../Dialect/Snitch/Transforms/Passes.td | 3 + .../Snitch/Transforms/PipelineCopyCompute.cpp | 3 + .../Dialect/Snitch/Transforms/PromoteToL1.cpp | 19 +- .../src/Quidditch/Target/CMakeLists.txt | 2 + .../src/Quidditch/Target/ConvertToLLVM.cpp | 2 + .../src/Quidditch/Target/QuidditchTarget.cpp | 22 +- .../dma_transfer.mlir | 52 +- .../dma_wait.mlir | 4 +- .../zero_mem_transfer.mlir | 12 +- .../ConvertSnitchToLLVM/completed_token.mlir | 6 +- .../tests/Dialect/DMA/IR/bufferization.mlir | 125 +++++ .../Dialect/DMA/IR/canonicalization.mlir | 111 ++++ codegen/tests/Dialect/DMA/IR/roundtrip.mlir | 11 + .../Dialect/Snitch/IR/bufferization.mlir | 124 ----- .../Dialect/Snitch/IR/canonicalization.mlir | 110 ---- .../tests/Dialect/Snitch/IR/roundtrip.mlir | 6 - .../Snitch/Transforms/lower-pipeline.mlir | 14 +- .../Transforms/pipeline-copy-compute.mlir | 24 +- .../Transforms/promote-operands-to-l1.mlir | 12 +- .../Snitch/Transforms/promote-pads-to-l1.mlir | 12 +- .../Transforms/specialize-dma-code.mlir | 46 +- codegen/tools/CMakeLists.txt | 1 + codegen/tools/quidditch-opt.cpp | 6 +- 51 files changed, 1849 insertions(+), 1465 deletions(-) create mode 100644 codegen/compiler/src/Quidditch/Conversion/ConvertDMAToLLVM.cpp create mode 100644 codegen/compiler/src/Quidditch/Conversion/ConvertDMAToLLVM.h create mode 100644 codegen/compiler/src/Quidditch/Dialect/DMA/CMakeLists.txt create mode 100644 codegen/compiler/src/Quidditch/Dialect/DMA/Extensions/CMakeLists.txt create mode 100644 codegen/compiler/src/Quidditch/Dialect/DMA/Extensions/DMACoreSpecializationOpInterfaceImpl.cpp create mode 100644 codegen/compiler/src/Quidditch/Dialect/DMA/Extensions/DMACoreSpecializationOpInterfaceImpl.h create mode 100644 codegen/compiler/src/Quidditch/Dialect/DMA/IR/CMakeLists.txt create mode 100644 codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAAttrs.cpp create mode 100644 codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAAttrs.h create mode 100644 codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAAttrs.td create mode 100644 codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMADialect.cpp create mode 100644 codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMADialect.h create mode 100644 codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMADialect.td create mode 100644 codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAOps.cpp create mode 100644 codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAOps.h create mode 100644 codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAOps.td create mode 100644 codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMATypes.cpp create mode 100644 codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMATypes.h create mode 100644 codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMATypes.td rename codegen/tests/Conversion/{ConvertSnitchToLLVM => ConvertDMAToLLVM}/dma_transfer.mlir (79%) rename codegen/tests/Conversion/{ConvertSnitchToLLVM => ConvertDMAToLLVM}/dma_wait.mlir (79%) rename codegen/tests/Conversion/{ConvertSnitchToLLVM => ConvertDMAToLLVM}/zero_mem_transfer.mlir (88%) create mode 100644 codegen/tests/Dialect/DMA/IR/bufferization.mlir create mode 100644 codegen/tests/Dialect/DMA/IR/canonicalization.mlir create mode 100644 codegen/tests/Dialect/DMA/IR/roundtrip.mlir diff --git a/codegen/compiler/src/Quidditch/Conversion/CMakeLists.txt b/codegen/compiler/src/Quidditch/Conversion/CMakeLists.txt index 491a442..f475b9d 100644 --- a/codegen/compiler/src/Quidditch/Conversion/CMakeLists.txt +++ b/codegen/compiler/src/Quidditch/Conversion/CMakeLists.txt @@ -38,3 +38,18 @@ iree_cc_library( MLIRSCFDialect MLIRTransforms ) + +iree_cc_library( + NAME + ConvertDMAToLLVM + SRCS + "ConvertDMAToLLVM.cpp" + DEPS + Quidditch::Dialect::DMA::IR::DMADialect + MLIRAnalysis + MLIRIR + MLIRLLVMCommonConversion + MLIRLLVMDialect + MLIRSCFDialect + MLIRTransforms +) diff --git a/codegen/compiler/src/Quidditch/Conversion/ConvertDMAToLLVM.cpp b/codegen/compiler/src/Quidditch/Conversion/ConvertDMAToLLVM.cpp new file mode 100644 index 0000000..d29c81b --- /dev/null +++ b/codegen/compiler/src/Quidditch/Conversion/ConvertDMAToLLVM.cpp @@ -0,0 +1,484 @@ +#include "ConvertDMAToLLVM.h" + +#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "Quidditch/Dialect/DMA/IR/DMAOps.h" + +using namespace mlir; +using namespace quidditch::dma; + +/// Returns the number of potentially non-contiguous outer dimensions of +/// 'memRefType'. The remaining inner dimensions (i.e. all dimensions at index +/// 'NonContiguousOuterDims' to the MemRef rank) are known to be contiguous. +/// Returns failure if the layout attribute of the MemRef is unsupported. +static FailureOr getNumNonContiguousOuterDims(MemRefType memRefType) { + auto stridesAttr = + dyn_cast_or_null(memRefType.getLayout()); + if (!stridesAttr) { + if (memRefType.getLayout() && !memRefType.getLayout().isIdentity()) + return failure(); + + // No layout or identity layouts are by definition fully contiguous. + return 0; + } + + int64_t innerSize = 1; + ArrayRef shape = memRefType.getShape(); + ArrayRef strides = stridesAttr.getStrides(); + for (; !shape.empty(); + shape = shape.drop_back(), strides = strides.drop_back()) { + int64_t dim = shape.back(); + // Unit dims can be dropped alongside the corresponding stride of that dim. + if (dim == 1) + continue; + + int64_t stride = strides.back(); + if (ShapedType::isDynamic(stride)) + break; + + if (innerSize != stride) + break; + + // Note: Dim may be dynamic with the value -1. This intentionally will only + // fail the 'if' above later if the outer dims are non-zero. + innerSize *= dim; + } + + return shape.size(); +} + +/// Returns true if this MemRef type is known to have a fully contiguous layout. +/// TODO: Could be upstreamed next to +/// 'memref::isStaticShapeAndContiguousRowMajor' +static bool isContiguous(MemRefType memRefType) { + return getNumNonContiguousOuterDims(memRefType) == 0; +} + +namespace { +struct StartTransferOp1DLowering : ConvertOpToLLVMPattern { + + LLVM::LLVMFuncOp dmaStart1DFunc; + + StartTransferOp1DLowering(LLVM::LLVMFuncOp dmaStart1DFunc, + const LLVMTypeConverter &converter) + : ConvertOpToLLVMPattern(converter, /*benefit=*/2), + dmaStart1DFunc(dmaStart1DFunc) {} + + LogicalResult match(StartTransferOp op) const override { + return success(isContiguous(op.getSource().getType()) && + isContiguous(op.getDest().getType())); + } + + void rewrite(StartTransferOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + MemRefDescriptor sourceDescriptor(adaptor.getSource()); + MemRefDescriptor destDescriptor(adaptor.getDest()); + + Value source = sourceDescriptor.bufferPtr( + rewriter, op->getLoc(), *getTypeConverter(), op.getSource().getType()); + Value dest = destDescriptor.bufferPtr( + rewriter, op->getLoc(), *getTypeConverter(), op.getDest().getType()); + + MemRefType sourceMemRef = op.getSource().getType(); + SmallVector dynamicSizes; + for (auto [index, dim] : llvm::enumerate(sourceMemRef.getShape())) + if (ShapedType::isDynamic(dim)) + dynamicSizes.push_back( + sourceDescriptor.size(rewriter, op->getLoc(), index)); + + SmallVector sizes; + SmallVector strides; + Value totalSize; + getMemRefDescriptorSizes( + op->getLoc(), + // Offsets are not considered an identity layout. + // Get rid of the layout entirely for the size calculation. + MemRefType::get(sourceMemRef.getShape(), sourceMemRef.getElementType(), + nullptr, sourceMemRef.getMemorySpace()), + dynamicSizes, rewriter, sizes, strides, totalSize); + + rewriter.replaceOpWithNewOp(op, dmaStart1DFunc, + ValueRange{ + dest, + source, + totalSize, + }); + } +}; + +struct StartTransferOp2DLowering : ConvertOpToLLVMPattern { + + LLVM::LLVMFuncOp dmaStart2DFunc; + + StartTransferOp2DLowering(LLVM::LLVMFuncOp dmaStart2DFunc, + const LLVMTypeConverter &converter) + : ConvertOpToLLVMPattern(converter), dmaStart2DFunc(dmaStart2DFunc) {} + + LogicalResult + matchAndRewrite(StartTransferOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + MemRefType sourceMemRef = op.getSource().getType(); + MemRefType destMemRef = op.getDest().getType(); + + // Compute the size of the contiguous inner loop common to both MemRefs and + // "shave" it off the ends of the shapes and strides. The remaining shapes + // and strides are considered our outer dimensions. + FailureOr sourceNonContiguous = + getNumNonContiguousOuterDims(sourceMemRef); + FailureOr destNonContiguous = + getNumNonContiguousOuterDims(destMemRef); + if (failed(sourceNonContiguous) || failed(destNonContiguous)) + return failure(); + size_t sharedNonContiguous = + std::max(*sourceNonContiguous, *destNonContiguous); + if (sharedNonContiguous == 0) + return failure(); + + Value elementSize = rewriter.create( + op->getLoc(), + rewriter.getI32IntegerAttr(llvm::divideCeil( + op.getSource().getType().getElementTypeBitWidth(), 8))); + SmallVector sizes = + memref::getMixedSizes(rewriter, op->getLoc(), op.getSource()); + + // Build a loop nest iterating over all outer dimensions - 1 and adjusts the + // source and destination pointers accordingly. The inner-most outer + // dimension is used in the DMA call for the repetition count and strides. + SmallVector lowerBounds; + SmallVector upperBounds; + SmallVector steps; + Value zeroIndex = rewriter.create(op.getLoc(), 0); + Value oneIndex = rewriter.create(op.getLoc(), 1); + for (size_t index : llvm::seq(sharedNonContiguous - 1)) { + lowerBounds.push_back(zeroIndex); + steps.push_back(oneIndex); + upperBounds.push_back(getValueOrCreateConstantIndexOp( + rewriter, op->getLoc(), sizes[index])); + } + + Value contiguousSize; + for (auto index : + llvm::seq(sharedNonContiguous, sourceMemRef.getRank())) { + Value dim = + getValueOrCreateConstantIndexOp(rewriter, op->getLoc(), sizes[index]); + if (!contiguousSize) { + contiguousSize = dim; + continue; + } + contiguousSize = + rewriter.create(op->getLoc(), contiguousSize, dim); + } + contiguousSize = typeConverter->materializeTargetConversion( + rewriter, op->getLoc(), getIndexType(), contiguousSize); + contiguousSize = + rewriter.create(op->getLoc(), contiguousSize, elementSize); + + Value completedToken = rewriter.create(op->getLoc()); + + scf::LoopNest loopNest = scf::buildLoopNest( + rewriter, op->getLoc(), lowerBounds, upperBounds, steps, completedToken, + [&](OpBuilder &builder, Location loc, ValueRange ivs, + ValueRange iterArgs) -> scf::ValueVector { + SmallVector offsets = ivs; + SmallVector subSizes(sharedNonContiguous - 1, + rewriter.getIndexAttr(1)); + for (unsigned i : llvm::seq(sharedNonContiguous - 1, + sourceMemRef.getRank())) { + offsets.push_back(rewriter.getIndexAttr(0)); + subSizes.push_back(sizes[i]); + } + SmallVector strides(sourceMemRef.getRank(), + rewriter.getIndexAttr(1)); + + TypedValue sourceMemRefSlice = + rewriter.create(loc, op.getSource(), offsets, + subSizes, strides); + TypedValue destMemRefSlice = + rewriter.create(loc, op.getDest(), offsets, + subSizes, strides); + + auto sourceDescriptor = + MemRefDescriptor(typeConverter->materializeTargetConversion( + rewriter, op->getLoc(), + typeConverter->convertType(sourceMemRefSlice.getType()), + sourceMemRefSlice)); + auto destDescriptor = + MemRefDescriptor(typeConverter->materializeTargetConversion( + rewriter, op->getLoc(), + typeConverter->convertType(destMemRefSlice.getType()), + destMemRefSlice)); + + Value sourceAdjusted = sourceDescriptor.bufferPtr( + rewriter, op->getLoc(), *getTypeConverter(), + sourceMemRefSlice.getType()); + Value destAdjusted = destDescriptor.bufferPtr( + rewriter, op->getLoc(), *getTypeConverter(), + destMemRefSlice.getType()); + + Value sourceStride = + sourceDescriptor.stride(builder, loc, sharedNonContiguous - 1); + sourceStride = rewriter.create( + op->getLoc(), sourceStride, elementSize); + Value destStride = + destDescriptor.stride(builder, loc, sharedNonContiguous - 1); + destStride = rewriter.create(op->getLoc(), destStride, + elementSize); + + Value outerLoopSize = + sourceDescriptor.size(builder, loc, sharedNonContiguous - 1); + return {builder + .create(loc, dmaStart2DFunc, + ValueRange{ + destAdjusted, + sourceAdjusted, + contiguousSize, + destStride, + sourceStride, + outerLoopSize, + }) + .getResult()}; + }); + + Type tokenType = typeConverter->convertType(op.getType()); + rewriter.replaceOp( + op, typeConverter->materializeTargetConversion( + rewriter, op->getLoc(), tokenType, loopNest.results.front())); + return success(); + } +}; + +// TODO: These should not be hardcoded. +constexpr unsigned zeroMemSize = 0x10000; +constexpr unsigned zeroMemAddress = 0x10030000; + +struct StartContiguousZeroMemTransferOpOpLowering + : ConvertOpToLLVMPattern { + + LLVM::LLVMFuncOp dmaStart1DFunc; + LLVM::LLVMFuncOp dmaStart2DFunc; + + StartContiguousZeroMemTransferOpOpLowering(LLVM::LLVMFuncOp dmaStart1DFunc, + LLVM::LLVMFuncOp dmaStart2DFunc, + const LLVMTypeConverter &converter) + : ConvertOpToLLVMPattern(converter, /*benefit=*/2), + dmaStart1DFunc(dmaStart1DFunc), dmaStart2DFunc(dmaStart2DFunc) {} + + LogicalResult + matchAndRewrite(StartZeroMemTransferOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isContiguous(op.getFilled().getType())) + return failure(); + + Value zeroPointer = rewriter.create( + op->getLoc(), rewriter.getType(), + rewriter.create( + op->getLoc(), rewriter.getI32IntegerAttr(zeroMemAddress))); + Value zeroMemSizeValue = rewriter.create( + op->getLoc(), rewriter.getI32IntegerAttr(zeroMemSize)); + + SmallVector sizes; + SmallVector strides; + Value size; + + auto filledDesc = MemRefDescriptor(adaptor.getFilled()); + + MemRefType memRefType = op.getFilled().getType(); + SmallVector dynamicSizes; + for (auto [index, shape] : llvm::enumerate(memRefType.getShape())) + if (ShapedType::isDynamic(shape)) + dynamicSizes.push_back(filledDesc.size(rewriter, op->getLoc(), index)); + + // Function does not support strided layout, even if it is contiguous. + // Lie about it and remove it. + // TODO: Consider fixing this upstream. + // TODO: Make a clone method of `MemRefType` that changes just the layout. + this->getMemRefDescriptorSizes( + op->getLoc(), + MemRefType::get(memRefType.getShape(), memRefType.getElementType()), + dynamicSizes, rewriter, sizes, strides, size); + + Value zero = + createIndexAttrConstant(rewriter, op->getLoc(), getIndexType(), 0); + Value bufferPointer = filledDesc.bufferPtr(rewriter, op->getLoc(), + *getTypeConverter(), memRefType); + Value times2D = + rewriter.create(op->getLoc(), size, zeroMemSizeValue); + // Note: This call would not be legal as a 'start_dma_transfer' call as + // MemRefs do not allow internal aliasing, which the below does via the + // stride of 0. + rewriter.create(op->getLoc(), dmaStart2DFunc, + ValueRange{bufferPointer, zeroPointer, + zeroMemSizeValue, zeroMemSizeValue, + zero, times2D}); + Value offset = + rewriter.create(op->getLoc(), times2D, zeroMemSizeValue); + bufferPointer = rewriter.create( + op->getLoc(), bufferPointer.getType(), rewriter.getI8Type(), + bufferPointer, offset); + Value rest = + rewriter.create(op->getLoc(), size, zeroMemSizeValue); + rewriter.replaceOpWithNewOp( + op, dmaStart1DFunc, ValueRange{bufferPointer, zeroPointer, rest}); + return success(); + } +}; + +struct StartZeroMemTransferOpOpLowering + : ConvertOpToLLVMPattern { + + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(StartZeroMemTransferOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + MemRefType memRefType = op.getFilled().getType(); + + FailureOr nonContiguousDims = + getNumNonContiguousOuterDims(memRefType); + if (failed(nonContiguousDims) || nonContiguousDims == 0) + return failure(); + + SmallVector sizes = + memref::getMixedSizes(rewriter, op->getLoc(), op.getFilled()); + + SmallVector lowerBounds; + SmallVector upperBounds; + SmallVector steps; + Value zeroIndex = rewriter.create(op.getLoc(), 0); + Value oneIndex = rewriter.create(op.getLoc(), 1); + for (size_t index : llvm::seq(*nonContiguousDims)) { + lowerBounds.push_back(zeroIndex); + steps.push_back(oneIndex); + upperBounds.push_back(getValueOrCreateConstantIndexOp( + rewriter, op->getLoc(), sizes[index])); + } + + // Loop over every non-contiguous dimension to zero every contiguous + // inner subview. + Value completedToken = rewriter.create(op->getLoc()); + scf::LoopNest loopNest = scf::buildLoopNest( + rewriter, op->getLoc(), lowerBounds, upperBounds, steps, completedToken, + [&](OpBuilder &builder, Location loc, ValueRange ivs, + ValueRange iterArgs) -> scf::ValueVector { + SmallVector offsets = ivs; + SmallVector subSizes(*nonContiguousDims, + rewriter.getIndexAttr(1)); + for (unsigned i : + llvm::seq(*nonContiguousDims, memRefType.getRank())) { + offsets.push_back(rewriter.getIndexAttr(0)); + subSizes.push_back(sizes[i]); + } + SmallVector strides(memRefType.getRank(), + rewriter.getIndexAttr(1)); + + Value subMemRef = rewriter.create( + loc, op.getFilled(), offsets, subSizes, strides); + return { + builder.create(op->getLoc(), subMemRef)}; + }); + + Type tokenType = typeConverter->convertType(op.getType()); + rewriter.replaceOp( + op, typeConverter->materializeTargetConversion( + rewriter, op->getLoc(), tokenType, loopNest.results.front())); + return success(); + } +}; + +struct WaitForTransfersOpLowering : ConvertOpToLLVMPattern { + + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(WaitForTransfersOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (adaptor.getTokens().empty()) { + rewriter.eraseOp(op); + return success(); + } + + Value current = adaptor.getTokens().front(); + for (Value iter : llvm::drop_begin(adaptor.getTokens())) + current = rewriter.create(op->getLoc(), current, iter); + + Block *prev = op->getBlock(); + Block *body = rewriter.splitBlock(prev, op->getIterator()); + Block *after = rewriter.splitBlock(body, op->getNextNode()->getIterator()); + rewriter.setInsertionPointToEnd(prev); + rewriter.create(op->getLoc(), body); + + rewriter.setInsertionPointToEnd(body); + Value lastCompleted = + rewriter + .create( + op->getLoc(), /*res=*/rewriter.getI32Type(), + /*operands=*/ValueRange(), + // dmstati $0, 0 + // opcode6=0x2b, func3=0, func7=0b100, rd=$0, rs1=zero, + // rs2=imm5(0) + ".insn r 0x2b, 0, 0b100, $0, zero, zero\n", + /*constraints=*/"=r", + /*has_side_effects=*/true, /*is_align_stack=*/false, + /*asm_dialect=*/nullptr, /*operand_attrs=*/nullptr) + .getRes(); + Value notDone = rewriter.create( + op->getLoc(), LLVM::ICmpPredicate::ult, lastCompleted, current); + rewriter.create(op->getLoc(), notDone, body, after); + + rewriter.setInsertionPointToStart(after); + rewriter.eraseOp(op); + return success(); + } +}; + +struct CompletedTokenOpLowering : ConvertOpToLLVMPattern { + + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(CompletedTokenOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + op, typeConverter->convertType(op.getType()), 0); + return success(); + } +}; + +} // namespace + +void quidditch::populateDMAToLLVMConversionPatterns( + mlir::ModuleOp moduleOp, LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns) { + + typeConverter.addConversion( + [](TokenType token) { return IntegerType::get(token.getContext(), 32); }); + + auto builder = OpBuilder::atBlockEnd(moduleOp.getBody()); + auto ptrType = builder.getType(); + IntegerType i32 = builder.getI32Type(); + IntegerType sizeT = i32; + auto dmaStart1D = builder.create( + builder.getUnknownLoc(), "snrt_dma_start_1d", + LLVM::LLVMFunctionType::get(i32, + ArrayRef{ptrType, ptrType, sizeT})); + dmaStart1D->setAttr("hal.import.bitcode", builder.getUnitAttr()); + + auto dmaStart2D = builder.create( + builder.getUnknownLoc(), "snrt_dma_start_2d", + LLVM::LLVMFunctionType::get( + i32, ArrayRef{ptrType, ptrType, sizeT, sizeT, sizeT, sizeT})); + dmaStart2D->setAttr("hal.import.bitcode", builder.getUnitAttr()); + + patterns.insert(typeConverter); + patterns.insert(dmaStart1D, typeConverter); + patterns.insert(dmaStart2D, typeConverter); + patterns.insert( + dmaStart1D, dmaStart2D, typeConverter); +} diff --git a/codegen/compiler/src/Quidditch/Conversion/ConvertDMAToLLVM.h b/codegen/compiler/src/Quidditch/Conversion/ConvertDMAToLLVM.h new file mode 100644 index 0000000..cbd351e --- /dev/null +++ b/codegen/compiler/src/Quidditch/Conversion/ConvertDMAToLLVM.h @@ -0,0 +1,10 @@ + +#pragma once + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" + +namespace quidditch { +void populateDMAToLLVMConversionPatterns(mlir::ModuleOp moduleOp, + mlir::LLVMTypeConverter &converter, + mlir::RewritePatternSet &patterns); +} diff --git a/codegen/compiler/src/Quidditch/Conversion/ConvertSnitchToLLVM.cpp b/codegen/compiler/src/Quidditch/Conversion/ConvertSnitchToLLVM.cpp index 4f404b7..063bae6 100644 --- a/codegen/compiler/src/Quidditch/Conversion/ConvertSnitchToLLVM.cpp +++ b/codegen/compiler/src/Quidditch/Conversion/ConvertSnitchToLLVM.cpp @@ -44,447 +44,6 @@ struct L1MemoryViewOpLowering : ConvertOpToLLVMPattern { return success(); } }; -} // namespace - -/// Returns the number of potentially non-contiguous outer dimensions of -/// 'memRefType'. The remaining inner dimensions (i.e. all dimensions at index -/// 'NonContiguousOuterDims' to the MemRef rank) are known to be contiguous. -/// Returns failure if the layout attribute of the MemRef is unsupported. -static FailureOr getNumNonContiguousOuterDims(MemRefType memRefType) { - auto stridesAttr = - dyn_cast_or_null(memRefType.getLayout()); - if (!stridesAttr) { - if (memRefType.getLayout() && !memRefType.getLayout().isIdentity()) - return failure(); - - // No layout or identity layouts are by definition fully contiguous. - return 0; - } - - int64_t innerSize = 1; - ArrayRef shape = memRefType.getShape(); - ArrayRef strides = stridesAttr.getStrides(); - for (; !shape.empty(); - shape = shape.drop_back(), strides = strides.drop_back()) { - int64_t dim = shape.back(); - // Unit dims can be dropped alongside the corresponding stride of that dim. - if (dim == 1) - continue; - - int64_t stride = strides.back(); - if (ShapedType::isDynamic(stride)) - break; - - if (innerSize != stride) - break; - - // Note: Dim may be dynamic with the value -1. This intentionally will only - // fail the 'if' above later if the outer dims are non-zero. - innerSize *= dim; - } - - return shape.size(); -} - -/// Returns true if this MemRef type is known to have a fully contiguous layout. -/// TODO: Could be upstreamed next to -/// 'memref::isStaticShapeAndContiguousRowMajor' -static bool isContiguous(MemRefType memRefType) { - return getNumNonContiguousOuterDims(memRefType) == 0; -} - -namespace { -struct StartDMATransferOp1DLowering - : ConvertOpToLLVMPattern { - - LLVM::LLVMFuncOp dmaStart1DFunc; - - StartDMATransferOp1DLowering(LLVM::LLVMFuncOp dmaStart1DFunc, - const LLVMTypeConverter &converter) - : ConvertOpToLLVMPattern(converter, /*benefit=*/2), - dmaStart1DFunc(dmaStart1DFunc) {} - - LogicalResult match(StartDMATransferOp op) const override { - return success(isContiguous(op.getSource().getType()) && - isContiguous(op.getDest().getType())); - } - - void rewrite(StartDMATransferOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - MemRefDescriptor sourceDescriptor(adaptor.getSource()); - MemRefDescriptor destDescriptor(adaptor.getDest()); - - Value source = sourceDescriptor.bufferPtr( - rewriter, op->getLoc(), *getTypeConverter(), op.getSource().getType()); - Value dest = destDescriptor.bufferPtr( - rewriter, op->getLoc(), *getTypeConverter(), op.getDest().getType()); - - MemRefType sourceMemRef = op.getSource().getType(); - SmallVector dynamicSizes; - for (auto [index, dim] : llvm::enumerate(sourceMemRef.getShape())) - if (ShapedType::isDynamic(dim)) - dynamicSizes.push_back( - sourceDescriptor.size(rewriter, op->getLoc(), index)); - - SmallVector sizes; - SmallVector strides; - Value totalSize; - getMemRefDescriptorSizes( - op->getLoc(), - // Offsets are not considered an identity layout. - // Get rid of the layout entirely for the size calculation. - MemRefType::get(sourceMemRef.getShape(), sourceMemRef.getElementType(), - nullptr, sourceMemRef.getMemorySpace()), - dynamicSizes, rewriter, sizes, strides, totalSize); - - rewriter.replaceOpWithNewOp(op, dmaStart1DFunc, - ValueRange{ - dest, - source, - totalSize, - }); - } -}; - -struct StartDMATransferOp2DLowering - : ConvertOpToLLVMPattern { - - LLVM::LLVMFuncOp dmaStart2DFunc; - - StartDMATransferOp2DLowering(LLVM::LLVMFuncOp dmaStart2DFunc, - const LLVMTypeConverter &converter) - : ConvertOpToLLVMPattern(converter), dmaStart2DFunc(dmaStart2DFunc) {} - - LogicalResult - matchAndRewrite(StartDMATransferOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - MemRefType sourceMemRef = op.getSource().getType(); - MemRefType destMemRef = op.getDest().getType(); - - // Compute the size of the contiguous inner loop common to both MemRefs and - // "shave" it off the ends of the shapes and strides. The remaining shapes - // and strides are considered our outer dimensions. - FailureOr sourceNonContiguous = - getNumNonContiguousOuterDims(sourceMemRef); - FailureOr destNonContiguous = - getNumNonContiguousOuterDims(destMemRef); - if (failed(sourceNonContiguous) || failed(destNonContiguous)) - return failure(); - size_t sharedNonContiguous = - std::max(*sourceNonContiguous, *destNonContiguous); - if (sharedNonContiguous == 0) - return failure(); - - Value elementSize = rewriter.create( - op->getLoc(), - rewriter.getI32IntegerAttr(llvm::divideCeil( - op.getSource().getType().getElementTypeBitWidth(), 8))); - SmallVector sizes = - memref::getMixedSizes(rewriter, op->getLoc(), op.getSource()); - - // Build a loop nest iterating over all outer dimensions - 1 and adjusts the - // source and destination pointers accordingly. The inner-most outer - // dimension is used in the DMA call for the repetition count and strides. - SmallVector lowerBounds; - SmallVector upperBounds; - SmallVector steps; - Value zeroIndex = rewriter.create(op.getLoc(), 0); - Value oneIndex = rewriter.create(op.getLoc(), 1); - for (size_t index : llvm::seq(sharedNonContiguous - 1)) { - lowerBounds.push_back(zeroIndex); - steps.push_back(oneIndex); - upperBounds.push_back(getValueOrCreateConstantIndexOp( - rewriter, op->getLoc(), sizes[index])); - } - - Value contiguousSize; - for (auto index : - llvm::seq(sharedNonContiguous, sourceMemRef.getRank())) { - Value dim = - getValueOrCreateConstantIndexOp(rewriter, op->getLoc(), sizes[index]); - if (!contiguousSize) { - contiguousSize = dim; - continue; - } - contiguousSize = - rewriter.create(op->getLoc(), contiguousSize, dim); - } - contiguousSize = typeConverter->materializeTargetConversion( - rewriter, op->getLoc(), getIndexType(), contiguousSize); - contiguousSize = - rewriter.create(op->getLoc(), contiguousSize, elementSize); - - Value completedToken = rewriter.create(op->getLoc()); - - scf::LoopNest loopNest = scf::buildLoopNest( - rewriter, op->getLoc(), lowerBounds, upperBounds, steps, completedToken, - [&](OpBuilder &builder, Location loc, ValueRange ivs, - ValueRange iterArgs) -> scf::ValueVector { - SmallVector offsets = ivs; - SmallVector subSizes(sharedNonContiguous - 1, - rewriter.getIndexAttr(1)); - for (unsigned i : llvm::seq(sharedNonContiguous - 1, - sourceMemRef.getRank())) { - offsets.push_back(rewriter.getIndexAttr(0)); - subSizes.push_back(sizes[i]); - } - SmallVector strides(sourceMemRef.getRank(), - rewriter.getIndexAttr(1)); - - TypedValue sourceMemRefSlice = - rewriter.create(loc, op.getSource(), offsets, - subSizes, strides); - TypedValue destMemRefSlice = - rewriter.create(loc, op.getDest(), offsets, - subSizes, strides); - - auto sourceDescriptor = - MemRefDescriptor(typeConverter->materializeTargetConversion( - rewriter, op->getLoc(), - typeConverter->convertType(sourceMemRefSlice.getType()), - sourceMemRefSlice)); - auto destDescriptor = - MemRefDescriptor(typeConverter->materializeTargetConversion( - rewriter, op->getLoc(), - typeConverter->convertType(destMemRefSlice.getType()), - destMemRefSlice)); - - Value sourceAdjusted = sourceDescriptor.bufferPtr( - rewriter, op->getLoc(), *getTypeConverter(), - sourceMemRefSlice.getType()); - Value destAdjusted = destDescriptor.bufferPtr( - rewriter, op->getLoc(), *getTypeConverter(), - destMemRefSlice.getType()); - - Value sourceStride = - sourceDescriptor.stride(builder, loc, sharedNonContiguous - 1); - sourceStride = rewriter.create( - op->getLoc(), sourceStride, elementSize); - Value destStride = - destDescriptor.stride(builder, loc, sharedNonContiguous - 1); - destStride = rewriter.create(op->getLoc(), destStride, - elementSize); - - Value outerLoopSize = - sourceDescriptor.size(builder, loc, sharedNonContiguous - 1); - return {builder - .create(loc, dmaStart2DFunc, - ValueRange{ - destAdjusted, - sourceAdjusted, - contiguousSize, - destStride, - sourceStride, - outerLoopSize, - }) - .getResult()}; - }); - - Type tokenType = typeConverter->convertType(op.getType()); - rewriter.replaceOp( - op, typeConverter->materializeTargetConversion( - rewriter, op->getLoc(), tokenType, loopNest.results.front())); - return success(); - } -}; - -// TODO: These should not be hardcoded. -constexpr unsigned zeroMemSize = 0x10000; -constexpr unsigned zeroMemAddress = 0x10030000; - -struct StartContiguousZeroMemTransferOpOpLowering - : ConvertOpToLLVMPattern { - - LLVM::LLVMFuncOp dmaStart1DFunc; - LLVM::LLVMFuncOp dmaStart2DFunc; - - StartContiguousZeroMemTransferOpOpLowering(LLVM::LLVMFuncOp dmaStart1DFunc, - LLVM::LLVMFuncOp dmaStart2DFunc, - const LLVMTypeConverter &converter) - : ConvertOpToLLVMPattern(converter, /*benefit=*/2), - dmaStart1DFunc(dmaStart1DFunc), dmaStart2DFunc(dmaStart2DFunc) {} - - LogicalResult - matchAndRewrite(StartZeroMemTransferOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!isContiguous(op.getFilled().getType())) - return failure(); - - Value zeroPointer = rewriter.create( - op->getLoc(), rewriter.getType(), - rewriter.create( - op->getLoc(), rewriter.getI32IntegerAttr(zeroMemAddress))); - Value zeroMemSizeValue = rewriter.create( - op->getLoc(), rewriter.getI32IntegerAttr(zeroMemSize)); - - SmallVector sizes; - SmallVector strides; - Value size; - - auto filledDesc = MemRefDescriptor(adaptor.getFilled()); - - MemRefType memRefType = op.getFilled().getType(); - SmallVector dynamicSizes; - for (auto [index, shape] : llvm::enumerate(memRefType.getShape())) - if (ShapedType::isDynamic(shape)) - dynamicSizes.push_back(filledDesc.size(rewriter, op->getLoc(), index)); - - // Function does not support strided layout, even if it is contiguous. - // Lie about it and remove it. - // TODO: Consider fixing this upstream. - // TODO: Make a clone method of `MemRefType` that changes just the layout. - this->getMemRefDescriptorSizes( - op->getLoc(), - MemRefType::get(memRefType.getShape(), memRefType.getElementType()), - dynamicSizes, rewriter, sizes, strides, size); - - Value zero = - createIndexAttrConstant(rewriter, op->getLoc(), getIndexType(), 0); - Value bufferPointer = filledDesc.bufferPtr(rewriter, op->getLoc(), - *getTypeConverter(), memRefType); - Value times2D = - rewriter.create(op->getLoc(), size, zeroMemSizeValue); - // Note: This call would not be legal as a 'start_dma_transfer' call as - // MemRefs do not allow internal aliasing, which the below does via the - // stride of 0. - rewriter.create(op->getLoc(), dmaStart2DFunc, - ValueRange{bufferPointer, zeroPointer, - zeroMemSizeValue, zeroMemSizeValue, - zero, times2D}); - Value offset = - rewriter.create(op->getLoc(), times2D, zeroMemSizeValue); - bufferPointer = rewriter.create( - op->getLoc(), bufferPointer.getType(), rewriter.getI8Type(), - bufferPointer, offset); - Value rest = - rewriter.create(op->getLoc(), size, zeroMemSizeValue); - rewriter.replaceOpWithNewOp( - op, dmaStart1DFunc, ValueRange{bufferPointer, zeroPointer, rest}); - return success(); - } -}; - -struct StartZeroMemTransferOpOpLowering - : ConvertOpToLLVMPattern { - - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(StartZeroMemTransferOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - MemRefType memRefType = op.getFilled().getType(); - - FailureOr nonContiguousDims = - getNumNonContiguousOuterDims(memRefType); - if (failed(nonContiguousDims) || nonContiguousDims == 0) - return failure(); - - SmallVector sizes = - memref::getMixedSizes(rewriter, op->getLoc(), op.getFilled()); - - SmallVector lowerBounds; - SmallVector upperBounds; - SmallVector steps; - Value zeroIndex = rewriter.create(op.getLoc(), 0); - Value oneIndex = rewriter.create(op.getLoc(), 1); - for (size_t index : llvm::seq(*nonContiguousDims)) { - lowerBounds.push_back(zeroIndex); - steps.push_back(oneIndex); - upperBounds.push_back(getValueOrCreateConstantIndexOp( - rewriter, op->getLoc(), sizes[index])); - } - - // Loop over every non-contiguous dimension to zero every contiguous - // inner subview. - Value completedToken = rewriter.create(op->getLoc()); - scf::LoopNest loopNest = scf::buildLoopNest( - rewriter, op->getLoc(), lowerBounds, upperBounds, steps, completedToken, - [&](OpBuilder &builder, Location loc, ValueRange ivs, - ValueRange iterArgs) -> scf::ValueVector { - SmallVector offsets = ivs; - SmallVector subSizes(*nonContiguousDims, - rewriter.getIndexAttr(1)); - for (unsigned i : - llvm::seq(*nonContiguousDims, memRefType.getRank())) { - offsets.push_back(rewriter.getIndexAttr(0)); - subSizes.push_back(sizes[i]); - } - SmallVector strides(memRefType.getRank(), - rewriter.getIndexAttr(1)); - - Value subMemRef = rewriter.create( - loc, op.getFilled(), offsets, subSizes, strides); - return { - builder.create(op->getLoc(), subMemRef)}; - }); - - Type tokenType = typeConverter->convertType(op.getType()); - rewriter.replaceOp( - op, typeConverter->materializeTargetConversion( - rewriter, op->getLoc(), tokenType, loopNest.results.front())); - return success(); - } -}; - -struct WaitForDMATransfersOpLowering - : ConvertOpToLLVMPattern { - - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(WaitForDMATransfersOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (adaptor.getTokens().empty()) { - rewriter.eraseOp(op); - return success(); - } - - Value current = adaptor.getTokens().front(); - for (Value iter : llvm::drop_begin(adaptor.getTokens())) - current = rewriter.create(op->getLoc(), current, iter); - - Block *prev = op->getBlock(); - Block *body = rewriter.splitBlock(prev, op->getIterator()); - Block *after = rewriter.splitBlock(body, op->getNextNode()->getIterator()); - rewriter.setInsertionPointToEnd(prev); - rewriter.create(op->getLoc(), body); - - rewriter.setInsertionPointToEnd(body); - Value lastCompleted = - rewriter - .create( - op->getLoc(), /*res=*/rewriter.getI32Type(), - /*operands=*/ValueRange(), - // dmstati $0, 0 - // opcode6=0x2b, func3=0, func7=0b100, rd=$0, rs1=zero, - // rs2=imm5(0) - ".insn r 0x2b, 0, 0b100, $0, zero, zero\n", - /*constraints=*/"=r", - /*has_side_effects=*/true, /*is_align_stack=*/false, - /*asm_dialect=*/nullptr, /*operand_attrs=*/nullptr) - .getRes(); - Value notDone = rewriter.create( - op->getLoc(), LLVM::ICmpPredicate::ult, lastCompleted, current); - rewriter.create(op->getLoc(), notDone, body, after); - - rewriter.setInsertionPointToStart(after); - rewriter.eraseOp(op); - return success(); - } -}; - -struct CompletedTokenOpLowering : ConvertOpToLLVMPattern { - - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(CompletedTokenOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp( - op, typeConverter->convertType(op.getType()), 0); - return success(); - } -}; struct BarrierOpLowering : ConvertOpToLLVMPattern { @@ -623,40 +182,15 @@ void quidditch::populateSnitchToLLVMConversionPatterns( mlir::ModuleOp moduleOp, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) { - typeConverter.addConversion([](DMATokenType token) { - return IntegerType::get(token.getContext(), 32); - }); - auto builder = OpBuilder::atBlockEnd(moduleOp.getBody()); - auto ptrType = builder.getType(); IntegerType i32 = builder.getI32Type(); - IntegerType sizeT = i32; - auto dmaStart1D = builder.create( - builder.getUnknownLoc(), "snrt_dma_start_1d", - LLVM::LLVMFunctionType::get(i32, - ArrayRef{ptrType, ptrType, sizeT})); - dmaStart1D->setAttr("hal.import.bitcode", builder.getUnitAttr()); - - auto dmaStart2D = builder.create( - builder.getUnknownLoc(), "snrt_dma_start_2d", - LLVM::LLVMFunctionType::get( - i32, ArrayRef{ptrType, ptrType, sizeT, sizeT, sizeT, sizeT})); - dmaStart2D->setAttr("hal.import.bitcode", builder.getUnitAttr()); - auto computeCoreIndex = builder.create( builder.getUnknownLoc(), "snrt_cluster_core_idx", LLVM::LLVMFunctionType::get(i32, ArrayRef{})); computeCoreIndex->setAttr("hal.import.bitcode", builder.getUnitAttr()); - patterns - .insert( - typeConverter); - patterns.insert(dmaStart1D, typeConverter); - patterns.insert(dmaStart2D, typeConverter); - patterns.insert( - dmaStart1D, dmaStart2D, typeConverter); + patterns.insert(typeConverter); patterns.insert(computeCoreIndex, typeConverter); patterns.insert(SymbolTable(moduleOp), typeConverter); diff --git a/codegen/compiler/src/Quidditch/Dialect/DMA/CMakeLists.txt b/codegen/compiler/src/Quidditch/Dialect/DMA/CMakeLists.txt new file mode 100644 index 0000000..0e9c88b --- /dev/null +++ b/codegen/compiler/src/Quidditch/Dialect/DMA/CMakeLists.txt @@ -0,0 +1 @@ +iree_add_all_subdirs() diff --git a/codegen/compiler/src/Quidditch/Dialect/DMA/Extensions/CMakeLists.txt b/codegen/compiler/src/Quidditch/Dialect/DMA/Extensions/CMakeLists.txt new file mode 100644 index 0000000..2333a5b --- /dev/null +++ b/codegen/compiler/src/Quidditch/Dialect/DMA/Extensions/CMakeLists.txt @@ -0,0 +1,12 @@ + +iree_cc_library( + NAME + DMACoreSpecializationOpInterfaceImpl + HDRS + "DMACoreSpecializationOpInterfaceImpl.h" + SRCS + "DMACoreSpecializationOpInterfaceImpl.cpp" + DEPS + Quidditch::Dialect::Snitch::IR::QuidditchSnitchDialect + Quidditch::Dialect::DMA::IR::DMADialect +) diff --git a/codegen/compiler/src/Quidditch/Dialect/DMA/Extensions/DMACoreSpecializationOpInterfaceImpl.cpp b/codegen/compiler/src/Quidditch/Dialect/DMA/Extensions/DMACoreSpecializationOpInterfaceImpl.cpp new file mode 100644 index 0000000..1bb4cb0 --- /dev/null +++ b/codegen/compiler/src/Quidditch/Dialect/DMA/Extensions/DMACoreSpecializationOpInterfaceImpl.cpp @@ -0,0 +1,76 @@ +#include "DMACoreSpecializationOpInterfaceImpl.h" + +#include "Quidditch/Dialect/DMA/IR/DMADialect.h" +#include "Quidditch/Dialect/DMA/IR/DMAOps.h" +#include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchInterfaces.h" +#include "mlir/IR/DialectRegistry.h" + +using namespace mlir; +using namespace quidditch::dma; +using namespace quidditch::Snitch; + +namespace { + +//===----------------------------------------------------------------------===// +// StartTransferOp::DMACoreSpecializationOpInterface +//===----------------------------------------------------------------------===// + +struct StartTransferOpImpl + : CoreSpecializationOpInterface::ExternalModel { + void replaceWithNoop(Operation *op, RewriterBase &rewriter) const { + rewriter.replaceOpWithNewOp(op); + } +}; + +struct StartTransferOpDMAImpl + : DMACoreSpecializationOpInterface::ExternalModel {}; + +//===----------------------------------------------------------------------===// +// StartZeroMemTransferOp::DMACoreSpecializationOpInterface +//===----------------------------------------------------------------------===// + +struct StartZeroMemTransferOpImpl + : CoreSpecializationOpInterface::ExternalModel { + void replaceWithNoop(Operation *op, RewriterBase &rewriter) const { + rewriter.replaceOpWithNewOp(op); + } + + // bool needsSynchronization(Operation *op) const { return true; } +}; + +struct StartZeroMemTransferOpDMAImpl + : DMACoreSpecializationOpInterface::ExternalModel< + StartZeroMemTransferOpDMAImpl, StartZeroMemTransferOp> {}; + +//===----------------------------------------------------------------------===// +// WaitForTransfersOpImpl::DMACoreSpecializationOpInterface +//===----------------------------------------------------------------------===// + +struct WaitForTransfersOpImpl + : CoreSpecializationOpInterface::ExternalModel { + void replaceWithNoop(Operation *op, RewriterBase &rewriter) const { + rewriter.eraseOp(op); + } + + bool needsSynchronization(Operation *op) const { return true; } +}; + +struct WaitForTransfersOpDMAImpl + : DMACoreSpecializationOpInterface::ExternalModel {}; + +} // namespace + +void quidditch::dma::registerDMACoreSpecializationOpInterface( + mlir::DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *context, DMADialect *dialect) { +#define REGISTER_IMPLS(Op) Op::attachInterface(*context) + REGISTER_IMPLS(StartTransferOp); + REGISTER_IMPLS(StartZeroMemTransferOp); + REGISTER_IMPLS(WaitForTransfersOp); + }); +} diff --git a/codegen/compiler/src/Quidditch/Dialect/DMA/Extensions/DMACoreSpecializationOpInterfaceImpl.h b/codegen/compiler/src/Quidditch/Dialect/DMA/Extensions/DMACoreSpecializationOpInterfaceImpl.h new file mode 100644 index 0000000..28f67f7 --- /dev/null +++ b/codegen/compiler/src/Quidditch/Dialect/DMA/Extensions/DMACoreSpecializationOpInterfaceImpl.h @@ -0,0 +1,10 @@ + +#pragma once + +namespace mlir { +class DialectRegistry; +} + +namespace quidditch::dma { +void registerDMACoreSpecializationOpInterface(mlir::DialectRegistry ®istry); +} diff --git a/codegen/compiler/src/Quidditch/Dialect/DMA/IR/CMakeLists.txt b/codegen/compiler/src/Quidditch/Dialect/DMA/IR/CMakeLists.txt new file mode 100644 index 0000000..fee69c0 --- /dev/null +++ b/codegen/compiler/src/Quidditch/Dialect/DMA/IR/CMakeLists.txt @@ -0,0 +1,73 @@ +iree_add_all_subdirs() + +iree_cc_library( + NAME + DMADialect + HDRS + "DMADialect.h" + "DMAOps.h" + TEXTUAL_HDRS + "DMAAttrs.cpp.inc" + "DMAAttrs.h.inc" + "DMADialect.cpp.inc" + "DMADialect.h.inc" + "DMAOps.cpp.inc" + "DMAOps.h.inc" + "DMATypes.cpp.inc" + "DMATypes.h.inc" + SRCS + "DMAAttrs.cpp" + "DMADialect.cpp" + "DMAOps.cpp" + "DMATypes.cpp" + DEPS + ::DMAAttrsGen + ::DMADialectGen + ::DMAOpsGen + ::DMATypesGen + LLVMSupport + MLIRIR + MLIRInferTypeOpInterface + MLIRSupport + PUBLIC +) + +iree_tablegen_library( + NAME + DMAOpsGen + TD_FILE + "DMAOps.td" + OUTS + --gen-op-decls DMAOps.h.inc + --gen-op-defs DMAOps.cpp.inc +) + +iree_tablegen_library( + NAME + DMADialectGen + TD_FILE + "DMADialect.td" + OUTS + --gen-dialect-decls DMADialect.h.inc + --gen-dialect-defs DMADialect.cpp.inc +) + +iree_tablegen_library( + NAME + DMAAttrsGen + TD_FILE + "DMAAttrs.td" + OUTS + --gen-attrdef-decls DMAAttrs.h.inc + --gen-attrdef-defs DMAAttrs.cpp.inc +) + +iree_tablegen_library( + NAME + DMATypesGen + TD_FILE + "DMATypes.td" + OUTS + --gen-typedef-decls DMATypes.h.inc + --gen-typedef-defs DMATypes.cpp.inc +) diff --git a/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAAttrs.cpp b/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAAttrs.cpp new file mode 100644 index 0000000..ced2b88 --- /dev/null +++ b/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAAttrs.cpp @@ -0,0 +1 @@ +#include "DMAAttrs.h" diff --git a/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAAttrs.h b/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAAttrs.h new file mode 100644 index 0000000..1b8b88e --- /dev/null +++ b/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAAttrs.h @@ -0,0 +1,7 @@ + +#pragma once + +#include "mlir/IR/Attributes.h" + +#define GET_ATTRDEF_CLASSES +#include "Quidditch/Dialect/DMA/IR/DMAAttrs.h.inc" diff --git a/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAAttrs.td b/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAAttrs.td new file mode 100644 index 0000000..1fb3895 --- /dev/null +++ b/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAAttrs.td @@ -0,0 +1,20 @@ +#ifndef QUIDDITCH_DIALECT_DMA_DMAATTRS +#define QUIDDITCH_DIALECT_DMA_DMAATTRS + +include "Quidditch/Dialect/DMA/IR/DMADialect.td" +include "mlir/IR/AttrTypeBase.td" + +class DMA_Attr traits = []> : + AttrDef; + +def DMA_CompletedTokenAttr : DMA_Attr<"CompletedToken"> { + + let mnemonic = "completed_token"; + + let description = [{ + Attribute representing an instance of a `!dma.token` + signaling a complete transfer. + }]; +} + +#endif diff --git a/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMADialect.cpp b/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMADialect.cpp new file mode 100644 index 0000000..058947f --- /dev/null +++ b/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMADialect.cpp @@ -0,0 +1,44 @@ +#include "DMADialect.h" + +#include "DMAAttrs.h" +#include "DMAOps.h" +#include "DMATypes.h" +#include "llvm/ADT/TypeSwitch.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" + +#define GET_ATTRDEF_CLASSES +#include "Quidditch/Dialect/DMA/IR/DMAAttrs.cpp.inc" + +#include "Quidditch/Dialect/DMA/IR/DMADialect.cpp.inc" + +using namespace mlir; +using namespace quidditch::dma; + +//===----------------------------------------------------------------------===// +// DMADialect +//===----------------------------------------------------------------------===// + +void DMADialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "Quidditch/Dialect/DMA/IR/DMAOps.cpp.inc" + >(); + addAttributes< +#define GET_ATTRDEF_LIST +#include "Quidditch/Dialect/DMA/IR/DMAAttrs.cpp.inc" + >(); + addTypes< +#define GET_TYPEDEF_LIST +#include "Quidditch/Dialect/DMA/IR/DMATypes.cpp.inc" + >(); +} + +Operation *DMADialect::materializeConstant(OpBuilder &builder, Attribute value, + Type type, Location loc) { + if (isa(value)) + return builder.create(loc); + + return nullptr; +} diff --git a/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMADialect.h b/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMADialect.h new file mode 100644 index 0000000..9f4a0e2 --- /dev/null +++ b/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMADialect.h @@ -0,0 +1,7 @@ + +#pragma once + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Operation.h" + +#include "Quidditch/Dialect/DMA/IR/DMADialect.h.inc" diff --git a/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMADialect.td b/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMADialect.td new file mode 100644 index 0000000..69afe78 --- /dev/null +++ b/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMADialect.td @@ -0,0 +1,15 @@ +#ifndef QUIDDITCH_DIALECT_SNITCH_DMADIALECT +#define QUIDDITCH_DIALECT_SNITCH_DMADIALECT + +include "mlir/IR/DialectBase.td" + +def DMA_Dialect : Dialect { + let name = "dma"; + let cppNamespace = "::quidditch::dma"; + + let useDefaultAttributePrinterParser = 1; + let useDefaultTypePrinterParser = 1; + let hasConstantMaterializer = 1; +} + +#endif diff --git a/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAOps.cpp b/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAOps.cpp new file mode 100644 index 0000000..1208ad0 --- /dev/null +++ b/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAOps.cpp @@ -0,0 +1,434 @@ +#include "DMAOps.h" + +#include "llvm/ADT/ScopeExit.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/TypeUtilities.h" + +#include "DMAAttrs.h" + +static mlir::ParseResult +parseTensorCopyTypes(mlir::OpAsmParser &parser, + mlir::DenseI64ArrayAttr staticHighPad, + mlir::Type ©Type, mlir::Type &resultType); + +static void printTensorCopyTypes(mlir::OpAsmPrinter &printer, mlir::Operation *, + mlir::DenseI64ArrayAttr staticHighPad, + mlir::Type copyType, mlir::Type resultType); + +#define GET_OP_CLASSES +#include "Quidditch/Dialect/DMA/IR/DMAOps.cpp.inc" + +using namespace mlir; +using namespace mlir::bufferization; +using namespace quidditch::dma; + +//===----------------------------------------------------------------------===// +// StartTensorCopyOp +//===----------------------------------------------------------------------===// + +ParseResult parseTensorCopyTypes(OpAsmParser &parser, + DenseI64ArrayAttr staticHighPad, + Type ©Type, Type &resultType) { + if (staticHighPad && !staticHighPad.empty()) { + if (parser.parseColon() || parser.parseType(copyType)) + return failure(); + } + if (parser.parseArrow() || parser.parseType(resultType)) + return failure(); + if (!staticHighPad || staticHighPad.empty()) + copyType = resultType; + return success(); +} + +static void printTensorCopyTypes(OpAsmPrinter &printer, mlir::Operation *, + DenseI64ArrayAttr staticHighPad, Type copyType, + Type resultType) { + if (staticHighPad && !staticHighPad.empty()) + printer << ": " << copyType; + printer << " -> " << resultType; +} + +LogicalResult StartTensorCopyOp::verify() { + if (getStaticHighPadAttr()) + if (getStaticHighPadAttr().size() != getCopy().getType().getRank()) + return emitOpError("expected padding number for every dimension"); + + unsigned numDynamicPads = llvm::count( + getStaticHighPad().value_or(std::nullopt), ShapedType::kDynamic); + if (numDynamicPads != getHighPad().size()) + return emitOpError("expected ") + << numDynamicPads << " dynamic padding values"; + + return success(); +} + +LogicalResult StartTensorCopyOp::fold(FoldAdaptor adaptor, + SmallVectorImpl &results) { + if (hasPadding()) { + // Remove noop padding. + if (llvm::all_of(getStaticHighPadAttr().asArrayRef(), + [](int64_t value) { return value == 0; })) { + removeStaticHighPadAttr(); + return success(); + } + + // Fold dynamic indices with constant values into the static list. + { + bool changed = false; + SmallVector padding = + llvm::to_vector(getStaticHighPadAttr().asArrayRef()); + unsigned dynamicIndex = 0; + for (int64_t &value : padding) { + if (!ShapedType::isDynamic(value)) + continue; + + if (auto integer = dyn_cast_or_null( + adaptor.getHighPad()[dynamicIndex])) { + value = integer.getValue().getZExtValue(); + getHighPadMutable().erase(dynamicIndex); + changed = true; + } else { + dynamicIndex++; + } + } + if (changed) { + setStaticHighPad(padding); + return success(); + } + } + } + + auto waitOp = getCopy().getDefiningOp(); + if (!waitOp) + return failure(); + auto copyOp = waitOp.getTransferTensor().getDefiningOp(); + if (!copyOp) + return failure(); + + if (hasPadding() && + (copyOp.getStaticHighPadAttr() != getStaticHighPadAttr() || + copyOp.getHighPad() != getHighPad())) + return failure(); + + results.emplace_back(waitOp); + results.emplace_back(CompletedTokenAttr::get(getContext())); + return success(); +} + +SmallVector StartTensorCopyOp::getMixedHighPad() { + Builder builder(getContext()); + if (!hasPadding()) + return SmallVector(getResult().getType().getRank(), + builder.getIndexAttr(0)); + + return getMixedValues(getStaticHighPadAttr().asArrayRef(), getHighPad(), + builder); +} + +//===----------------------------------------------------------------------===// +// StartTensorCopyOp::BufferizableOpInterface +//===----------------------------------------------------------------------===// + +/// Returns whether the allocation can be elided entirely. +/// Returns an empty optional if it was not possible to determine. +std::optional StartTensorCopyOp::elidesAllocation( + const bufferization::BufferizationOptions &options, + SmallVector *invocationStack) { + // Padding cannot be elided in general, even if the copied buffer is in L1. + if (hasPadding()) + return false; + + FailureOr copyType = + invocationStack + ? bufferization::getBufferType(getCopy(), options, *invocationStack) + : bufferization::getBufferType(getCopy(), options); + if (failed(copyType)) + return std::nullopt; + + return copyType->getMemorySpace() == getMemorySpaceAttr(); +} + +bool StartTensorCopyOp::resultBufferizesToMemoryWrite( + OpResult opResult, const bufferization::AnalysisState &state) { + assert(opResult == getResult() && "no other result"); + + std::optional matches = elidesAllocation(state.getOptions()); + // Conservative answer. + if (!matches) + return true; + + // No copy is performed unless the address space does not match. + // Copy in this context implies that we are writing to the result. + return !*matches; +} + +bool StartTensorCopyOp::bufferizesToMemoryRead( + OpOperand &opOperand, const bufferization::AnalysisState &state) { + assert(opOperand == getCopyMutable() && "have only one operand"); + + std::optional result = elidesAllocation(state.getOptions()); + // Conservative answer. + if (!result) + return true; + + // We only read from the buffer if we are copying. + return !*result; +} + +bool StartTensorCopyOp::bufferizesToMemoryWrite( + OpOperand &opOperand, const bufferization::AnalysisState &) { + assert(opOperand == getCopyMutable() && "have only one operand"); + + // We do not write into the buffer we are copying ever. + return false; +} + +AliasingValueList StartTensorCopyOp::getAliasingValues( + OpOperand &opOperand, const bufferization::AnalysisState &state) { + assert(opOperand == getCopyMutable() && "have only one operand"); + + std::optional result = elidesAllocation(state.getOptions()); + if (!result) + // Assume the worst case. + return {{getResult(), BufferRelation::Equivalent, /*isDefinite=*/false}}; + + // Always a brand-new allocation unless the input buffer is already in L1 and + // we elide the copy, in which case operand and result alias. + if (*result) + return {{getResult(), BufferRelation::Equivalent, /*isDefinite=*/true}}; + + return {}; +} + +bool StartTensorCopyOp::bufferizesToAllocation(Value value) { + assert(value == getResult() && "have only one result"); + + if (elidesAllocation() == true) + return false; + + // True is the conservative reply, according to the docs. + return true; +} + +FailureOr +StartTensorCopyOp::getBufferType(Value value, + const BufferizationOptions &options, + SmallVector &invocationStack) { + assert(value == getResult() && "have only one result"); + + bool contained = llvm::is_contained(invocationStack, value); + if (!contained) + if (elidesAllocation(options, &invocationStack) == true) + return bufferization::getBufferType(getCopy(), options, invocationStack); + + // Unless contained in the invocation stack (where we are free to impose the + // most optimal layout), we do not really impose a specific layout on the + // result. Contiguous is a good bet for now. + return getMemRefTypeWithStaticIdentityLayout(getResult().getType(), + getMemorySpaceAttr()); +} + +LogicalResult +StartTensorCopyOp::bufferize(RewriterBase &rewriter, + const BufferizationOptions &options) { + if (use_empty()) { + rewriter.eraseOp(*this); + return success(); + } + + FailureOr copyType = + bufferization::getBufferType(getCopy(), options); + if (failed(copyType)) + return failure(); + + FailureOr copyBuffer = getBuffer(rewriter, getCopy(), options); + if (failed(copyBuffer)) + return failure(); + + std::optional result = elidesAllocation(options); + if (!result) + return failure(); + + if (*result) { + Value token = rewriter.create(getLoc()); + replaceOpWithBufferizedValues(rewriter, getOperation(), + {*copyBuffer, token}); + return success(); + } + + FailureOr allocType = + bufferization::getBufferType(getResult(), options); + if (failed(allocType)) + return failure(); + + SmallVector copyBufferSizes = + memref::getMixedSizes(rewriter, getLoc(), *copyBuffer); + + // Compute the dynamic dimensions for the allocation. + SmallVector dynamicDims; + for (auto [index, shape, pad] : + llvm::enumerate(allocType->getShape(), getMixedHighPad())) { + if (!ShapedType::isDynamic(shape)) + continue; + + dynamicDims.push_back(affine::makeComposedAffineApply( + rewriter, getLoc(), + rewriter.getAffineDimExpr(0) + rewriter.getAffineDimExpr(1), + ArrayRef{copyBufferSizes[index], pad})); + } + + FailureOr alloc = options.createAlloc( + rewriter, getLoc(), llvm::cast(*allocType), + /*dynShape=*/dynamicDims); + if (failed(alloc)) + return failure(); + + // Zero out the entire buffer prior to overwriting it with the copied values. + // TODO: This could be optimized to only zero regions that won't be filled + // with the copied values at the cost of 2^rank transfers instead of two. + if (hasPadding() && !getUndefPadding()) + rewriter.create(getLoc(), *alloc); + + // Subview into the original memory without any padding. + // As we only add padding at the end of the dimensions, the offsets are always + // zero. + Value destination = rewriter.create( + getLoc(), *alloc, + /*offsets=*/ + SmallVector(allocType->getRank(), rewriter.getIndexAttr(0)), + copyBufferSizes, + /*strides=*/ + SmallVector(allocType->getRank(), + rewriter.getIndexAttr(1))); + Value token = + rewriter.create(getLoc(), *copyBuffer, destination); + + // Replace op. + replaceOpWithBufferizedValues(rewriter, getOperation(), {*alloc, token}); + return success(); +} + +//===----------------------------------------------------------------------===// +// WaitForTensorCopyOp +//===----------------------------------------------------------------------===// + +OpFoldResult WaitForTensorCopyOp::fold(FoldAdaptor adaptor) { + if (adaptor.getToken()) + return getTransferTensor(); + + return nullptr; +} + +//===----------------------------------------------------------------------===// +// WaitForTensorCopyOp::BufferizableOpInterface +//===----------------------------------------------------------------------===// + +bool WaitForTensorCopyOp::mustBufferizeInPlace( + OpOperand &opOperand, const bufferization::AnalysisState &state) { + return true; +} + +bool WaitForTensorCopyOp::bufferizesToMemoryRead( + OpOperand &opOperand, const bufferization::AnalysisState &state) { + if (opOperand == getTransferTensorMutable()) + return false; + + if (opOperand == getCopyMutable()) + return true; + + llvm_unreachable("unknown operand"); +} + +bool WaitForTensorCopyOp::bufferizesToMemoryWrite( + OpOperand &opOperand, const bufferization::AnalysisState &) { + if (opOperand == getTransferTensorMutable()) + return true; + + if (opOperand == getCopyMutable()) + return false; + + llvm_unreachable("unknown operand"); +} + +AliasingValueList WaitForTensorCopyOp::getAliasingValues( + OpOperand &opOperand, const bufferization::AnalysisState &state) { + if (opOperand == getCopyMutable()) + return {}; + + if (opOperand == getTransferTensorMutable()) + return {{getResult(), BufferRelation::Equivalent, /*isDefinite=*/true}}; + + llvm_unreachable("unknown operand"); +} + +LogicalResult +WaitForTensorCopyOp::bufferize(RewriterBase &rewriter, + const BufferizationOptions &options) { + FailureOr transferTensorBuffer = + getBuffer(rewriter, getTransferTensor(), options); + if (failed(transferTensorBuffer)) + return failure(); + + rewriter.create(getLoc(), getToken()); + replaceOpWithBufferizedValues(rewriter, getOperation(), + *transferTensorBuffer); + return success(); +} + +bool WaitForTensorCopyOp::isNotConflicting( + OpOperand *uRead, OpOperand *uWrite, + const bufferization::AnalysisState &state) { + if (*uRead == getCopyMutable() && *uWrite == getTransferTensorMutable()) + return true; + + return false; +} + +//===----------------------------------------------------------------------===// +// CompletedTokenOp +//===----------------------------------------------------------------------===// + +OpFoldResult CompletedTokenOp::fold(FoldAdaptor adaptor) { + return CompletedTokenAttr::get(getContext()); +} + +//===----------------------------------------------------------------------===// +// StartTransferOp +//===----------------------------------------------------------------------===// + +OpFoldResult StartTransferOp::fold(FoldAdaptor adaptor) { + if (getSource() != getDest()) + return nullptr; + + return CompletedTokenAttr::get(getContext()); +} + +//===----------------------------------------------------------------------===// +// WaitForTransfersOp +//===----------------------------------------------------------------------===// + +LogicalResult WaitForTransfersOp::fold(FoldAdaptor adaptor, + SmallVectorImpl &results) { + bool changed = false; + MutableOperandRange tokens = getTokensMutable(); + for (int i = tokens.size() - 1; i >= 0; i--) { + if (adaptor.getTokens()[i]) { + changed = true; + tokens.erase(i); + } + } + return success(changed); +} + +LogicalResult WaitForTransfersOp::canonicalize(WaitForTransfersOp op, + PatternRewriter &rewriter) { + if (!op.getTokens().empty()) + return failure(); + + rewriter.eraseOp(op); + return success(); +} diff --git a/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAOps.h b/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAOps.h new file mode 100644 index 0000000..a9f1f6f --- /dev/null +++ b/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAOps.h @@ -0,0 +1,14 @@ + +#pragma once + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#include "DMATypes.h" + +#define GET_OP_CLASSES +#include "Quidditch/Dialect/DMA/IR/DMAOps.h.inc" diff --git a/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAOps.td b/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAOps.td new file mode 100644 index 0000000..b5f23fd --- /dev/null +++ b/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAOps.td @@ -0,0 +1,206 @@ +#ifndef QUIDDITCH_DIALECT_SNITCH_DMAOPS +#define QUIDDITCH_DIALECT_SNITCH_DMAOPS + +include "Quidditch/Dialect/DMA/IR/DMADialect.td" +include "Quidditch/Dialect/DMA/IR/DMATypes.td" +include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td" +include "mlir/IR/CommonTypeConstraints.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +class DMA_Op traits = []> : + Op; + +def DMA_StartTensorCopyOp : DMA_Op<"start_tensor_copy", + [Pure, AllRanksMatch<["copy", "result"]>, + DeclareOpInterfaceMethods]> { + + let description = [{ + Operation starting a copy of a tensor to another memory space, optionally + adding padding and returning it as a new tensor. + The contained values of the resulting tensor is in an unspecified state. + See `wait_for_tensor_copy` to transform the tensor value into a state + equal to `$copy`. + + The operation may optionally add padding at the end of each dimension of + the tensor. Zero is used as the padding value. + The dimensions of the result tensor are computed using + `dims(copy)[i] + high_pad[i]`. + + This operation is a noop if `$copy` is already in the given memory space, + no padding is added, and bufferization can elide the copy. + }]; + + let arguments = (ins AnyRankedTensor:$copy, + AnyAttr:$memory_space, + Variadic:$high_pad, + OptionalAttr:$static_high_pad, + UnitAttr:$undef_padding + ); + + let results = (outs + AnyRankedTensor:$result, + DMA_TokenType:$token + ); + + let assemblyFormat = [{ + `of` $copy `to` $memory_space + ( `pad` `with` (`undef` $undef_padding^) : (`zero`)? `by` + custom($high_pad, $static_high_pad)^)? + custom(ref($static_high_pad), type($copy), type($result)) + attr-dict + }]; + + let builders = [ + OpBuilder<(ins "mlir::Value":$copy, "mlir::Attribute":$memorySpace), [{ + build($_builder, $_state, copy.getType(), + $_builder.getType(), copy, memorySpace, + /*high_pad=*/mlir::ValueRange(), /*static_high_pad=*/nullptr); + }]> + ]; + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + private: + std::optional + elidesAllocation(const mlir::bufferization::BufferizationOptions &options = {}, + llvm::SmallVector *invocationStack = nullptr); + public: + + bool hasPadding() { + return static_cast(getStaticHighPadAttr()); + } + + llvm::SmallVector getMixedHighPad(); + }]; + + let hasFolder = 1; +} + +def DMA_WaitForTensorCopyOp : DMA_Op<"wait_for_tensor_copy", + [AllTypesMatch<["transfer_tensor", "result"]>, Pure, + DeclareOpInterfaceMethods]> { + + let description = [{ + Operation asserting that a previous `start_tensor_copy` operation has finished. + Unless `token` is the result of an `completed_token` operation, + `transfer_tensor` and `token` must at runtime be a token and tensor yielded + by a `start_tensor_copy` operation and `copy` the original tensor used in + `start_tensor_copy`. + + Once this operation returns, the returned tensor's values are guaranteed + equal to the `copy` operand and in the memory space specified in + `start_tensor_copy`. + + Note: The additional `copy` operand is given as it is effectively read by + this operation. + This additionally guarantees that the bufferization frame work does not + perform a write to the underlying buffer of `copy` while the transfer is + in progress. + }]; + + let arguments = (ins + AnyRankedTensor:$transfer_tensor, + DMA_TokenType:$token, + AnyRankedTensor:$copy + ); + + let results = (outs + AnyRankedTensor:$result + ); + + let assemblyFormat = [{ + `of` $copy `:` type($copy) `to` $transfer_tensor `using` $token `->` type($transfer_tensor) attr-dict + }]; + + let hasFolder = 1; +} + +def DMA_StartTransferOp : DMA_Op<"start_transfer", + [MemoryEffects<[MemWrite]>, SameOperandsElementType, SameOperandsShape]> { + + let description = [{ + Operation performing a DMA transfer from one MemRef to another. + The shapes (including dynamic ones at runtime) of both MemRefs must be + identical with different strides and offsets allowed. + + The DMA operation is likely (but not guaranteed) to run asynchronous and + its completion only guaranteed by executing the `wait_for_transfers` + operation with the token returned by this operation or a later one. + }]; + + let arguments = (ins + Arg, "source", [MemRead]>:$source, + Arg, "destination", [MemWrite]>:$dest + ); + + let results = (outs DMA_TokenType:$token); + + let assemblyFormat = [{ + `from` $source `:` type($source) `to` $dest `:` type($dest) attr-dict + }]; + + let hasFolder = 1; +} + +def DMA_StartZeroMemTransferOp : DMA_Op<"start_zero_mem_transfer", + [MemoryEffects<[MemWrite]>]> { + + let description = [{ + + }]; + + let arguments = (ins + Arg, "zeroed buffer", [MemWrite]>:$filled + ); + + let results = (outs DMA_TokenType:$token); + + let assemblyFormat = [{ + $filled `:` type($filled) attr-dict + }]; +} + +def DMA_WaitForTransfersOp : DMA_Op<"wait_for_transfers"> { + + let description = [{ + Operation awaiting for DMA transfers denoted by its tokens to be finished. + }]; + + let arguments = (ins + Variadic:$tokens + ); + + let assemblyFormat = [{ + ($tokens^ `:` type($tokens))? attr-dict + }]; + + let hasFolder = 1; + let hasCanonicalizeMethod = 1; +} + +def DMA_CompletedTokenOp + : DMA_Op<"completed_token", [Pure, ConstantLike]> { + + let description = [{ + Op returning a special value representing a completed DMA transfer. + Passing this token to `wait_for_transfers` will always return immediately. + }]; + + let results = (outs DMA_TokenType:$token); + + let assemblyFormat = [{ + attr-dict + }]; + + let hasFolder = 1; +} + +#endif diff --git a/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMATypes.cpp b/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMATypes.cpp new file mode 100644 index 0000000..7d8c729 --- /dev/null +++ b/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMATypes.cpp @@ -0,0 +1,11 @@ +#include "DMATypes.h" + +#include "llvm/ADT/TypeSwitch.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" + +#include "DMADialect.h" + +#define GET_TYPEDEF_CLASSES +#include "Quidditch/Dialect/DMA/IR/DMATypes.cpp.inc" diff --git a/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMATypes.h b/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMATypes.h new file mode 100644 index 0000000..fde865f --- /dev/null +++ b/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMATypes.h @@ -0,0 +1,7 @@ + +#pragma once + +#include "mlir/IR/Types.h" + +#define GET_TYPEDEF_CLASSES +#include "Quidditch/Dialect/DMA/IR/DMATypes.h.inc" diff --git a/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMATypes.td b/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMATypes.td new file mode 100644 index 0000000..0288166 --- /dev/null +++ b/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMATypes.td @@ -0,0 +1,18 @@ +#ifndef QUIDDITCH_DIALECT_SNITCH_DMATYPES +#define QUIDDITCH_DIALECT_SNITCH_DMATYPES + +include "Quidditch/Dialect/DMA/IR/DMADialect.td" +include "mlir/IR/AttrTypeBase.td" + +class DMA_Type traits = []> : + TypeDef; + +def DMA_TokenType : DMA_Type<"Token"> { + let mnemonic = "token"; + + let description = [{ + Type representing a potentially active DMA transfer. + }]; +} + +#endif diff --git a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchAttrs.td b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchAttrs.td index 634cb21..35f27f5 100644 --- a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchAttrs.td +++ b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchAttrs.td @@ -8,16 +8,6 @@ include "mlir/IR/AttrTypeBase.td" class QuidditchSnitch_Attr traits = []> : AttrDef; -def QuidditchSnitch_CompletedTokenAttr : QuidditchSnitch_Attr<"CompletedToken"> { - - let mnemonic = "completed_token"; - - let description = [{ - Attribute representing an instance of a `!quidditch_snitch.dma_token` - signaling a complete transfer. - }]; -} - def QuidditchSnitch_L1EncodingAttr : QuidditchSnitch_Attr<"L1Encoding"> { let mnemonic = "l1_encoding"; diff --git a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.cpp b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.cpp index 2196493..bfe363f 100644 --- a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.cpp +++ b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.cpp @@ -60,13 +60,3 @@ void QuidditchSnitchDialect::initialize() { #include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchTypes.cpp.inc" >(); } - -Operation *QuidditchSnitchDialect::materializeConstant(OpBuilder &builder, - Attribute value, - Type type, - Location loc) { - if (isa(value)) - return builder.create(loc); - - return nullptr; -} diff --git a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.td b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.td index 4292a6c..eadd62e 100644 --- a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.td +++ b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.td @@ -14,8 +14,8 @@ def QuidditchSnitch_Dialect : Dialect { ); let useDefaultAttributePrinterParser = 1; - let useDefaultTypePrinterParser = 1; - let hasConstantMaterializer = 1; + let useDefaultTypePrinterParser = 0; + let hasConstantMaterializer = 0; } #endif diff --git a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.cpp b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.cpp index aca847c..143cabf 100644 --- a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.cpp +++ b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.cpp @@ -396,416 +396,6 @@ void MicrokernelFenceOp::replaceWithNoop(RewriterBase &rewriter) { rewriter.eraseOp(*this); } -//===----------------------------------------------------------------------===// -// StartTensorCopyOp -//===----------------------------------------------------------------------===// - -LogicalResult StartTensorCopyOp::verify() { - if (getStaticHighPadAttr()) - if (getStaticHighPadAttr().size() != getCopy().getType().getRank()) - return emitOpError("expected padding number for every dimension"); - - unsigned numDynamicPads = llvm::count( - getStaticHighPad().value_or(std::nullopt), ShapedType::kDynamic); - if (numDynamicPads != getHighPad().size()) - return emitOpError("expected ") - << numDynamicPads << " dynamic padding values"; - - return success(); -} - -LogicalResult StartTensorCopyOp::fold(FoldAdaptor adaptor, - SmallVectorImpl &results) { - if (hasPadding()) { - // Remove noop padding. - if (llvm::all_of(getStaticHighPadAttr().asArrayRef(), - [](int64_t value) { return value == 0; })) { - removeStaticHighPadAttr(); - return success(); - } - - // Fold dynamic indices with constant values into the static list. - { - bool changed = false; - SmallVector padding = - llvm::to_vector(getStaticHighPadAttr().asArrayRef()); - unsigned dynamicIndex = 0; - for (int64_t &value : padding) { - if (!ShapedType::isDynamic(value)) - continue; - - if (auto integer = dyn_cast_or_null( - adaptor.getHighPad()[dynamicIndex])) { - value = integer.getValue().getZExtValue(); - getHighPadMutable().erase(dynamicIndex); - changed = true; - } else { - dynamicIndex++; - } - } - if (changed) { - setStaticHighPad(padding); - return success(); - } - } - } - - auto waitOp = getCopy().getDefiningOp(); - if (!waitOp) - return failure(); - auto copyOp = waitOp.getTransferTensor().getDefiningOp(); - if (!copyOp) - return failure(); - - if (hasPadding() && - (copyOp.getStaticHighPadAttr() != getStaticHighPadAttr() || - copyOp.getHighPad() != getHighPad())) - return failure(); - - results.emplace_back(waitOp); - results.emplace_back(CompletedTokenAttr::get(getContext())); - return success(); -} - -SmallVector StartTensorCopyOp::getMixedHighPad() { - Builder builder(getContext()); - if (!hasPadding()) - return SmallVector(getResult().getType().getRank(), - builder.getIndexAttr(0)); - - return getMixedValues(getStaticHighPadAttr().asArrayRef(), getHighPad(), - builder); -} - -//===----------------------------------------------------------------------===// -// StartTensorCopyOp::BufferizableOpInterface -//===----------------------------------------------------------------------===// - -/// Returns whether the allocation can be elided entirely. -/// Returns an empty optional if it was not possible to determine. -std::optional StartTensorCopyOp::elidesAllocation( - const bufferization::BufferizationOptions &options, - SmallVector *invocationStack) { - // Padding cannot be elided in general, even if the copied buffer is in L1. - if (hasPadding()) - return false; - - FailureOr copyType = - invocationStack - ? bufferization::getBufferType(getCopy(), options, *invocationStack) - : bufferization::getBufferType(getCopy(), options); - if (failed(copyType)) - return std::nullopt; - - return isa_and_nonnull(copyType->getMemorySpace()); -} - -bool StartTensorCopyOp::resultBufferizesToMemoryWrite( - OpResult opResult, const bufferization::AnalysisState &state) { - assert(opResult == getResult() && "no other result"); - - std::optional matches = elidesAllocation(state.getOptions()); - // Conservative answer. - if (!matches) - return true; - - // No copy is performed unless the address space does not match. - // Copy in this context implies that we are writing to the result. - return !*matches; -} - -bool StartTensorCopyOp::bufferizesToMemoryRead( - OpOperand &opOperand, const bufferization::AnalysisState &state) { - assert(opOperand == getCopyMutable() && "have only one operand"); - - std::optional result = elidesAllocation(state.getOptions()); - // Conservative answer. - if (!result) - return true; - - // We only read from the buffer if we are copying. - return !*result; -} - -bool StartTensorCopyOp::bufferizesToMemoryWrite( - OpOperand &opOperand, const bufferization::AnalysisState &) { - assert(opOperand == getCopyMutable() && "have only one operand"); - - // We do not write into the buffer we are copying ever. - return false; -} - -AliasingValueList StartTensorCopyOp::getAliasingValues( - OpOperand &opOperand, const bufferization::AnalysisState &state) { - assert(opOperand == getCopyMutable() && "have only one operand"); - - std::optional result = elidesAllocation(state.getOptions()); - if (!result) - // Assume the worst case. - return {{getResult(), BufferRelation::Equivalent, /*isDefinite=*/false}}; - - // Always a brand-new allocation unless the input buffer is already in L1 and - // we elide the copy, in which case operand and result alias. - if (*result) - return {{getResult(), BufferRelation::Equivalent, /*isDefinite=*/true}}; - - return {}; -} - -bool StartTensorCopyOp::bufferizesToAllocation(Value value) { - assert(value == getResult() && "have only one result"); - - if (elidesAllocation() == true) - return false; - - // True is the conservative reply, according to the docs. - return true; -} - -FailureOr -StartTensorCopyOp::getBufferType(Value value, - const BufferizationOptions &options, - SmallVector &invocationStack) { - assert(value == getResult() && "have only one result"); - - bool contained = llvm::is_contained(invocationStack, value); - if (!contained) - if (elidesAllocation(options, &invocationStack) == true) - return bufferization::getBufferType(getCopy(), options, invocationStack); - - // Unless contained in the invocation stack (where we are free to impose the - // most optimal layout), we do not really impose a specific layout on the - // result. Contiguous is a good bet for now. - return getMemRefTypeWithStaticIdentityLayout( - getResult().getType(), L1EncodingAttr::get(getContext())); -} - -LogicalResult -StartTensorCopyOp::bufferize(RewriterBase &rewriter, - const BufferizationOptions &options) { - if (use_empty()) { - rewriter.eraseOp(*this); - return success(); - } - - FailureOr copyType = - bufferization::getBufferType(getCopy(), options); - if (failed(copyType)) - return failure(); - - FailureOr copyBuffer = getBuffer(rewriter, getCopy(), options); - if (failed(copyBuffer)) - return failure(); - - std::optional result = elidesAllocation(options); - if (!result) - return failure(); - - if (*result) { - Value token = rewriter.create(getLoc()); - replaceOpWithBufferizedValues(rewriter, getOperation(), - {*copyBuffer, token}); - return success(); - } - - FailureOr allocType = - bufferization::getBufferType(getResult(), options); - if (failed(allocType)) - return failure(); - - SmallVector copyBufferSizes = - memref::getMixedSizes(rewriter, getLoc(), *copyBuffer); - - // Compute the dynamic dimensions for the allocation. - SmallVector dynamicDims; - for (auto [index, shape, pad] : - llvm::enumerate(allocType->getShape(), getMixedHighPad())) { - if (!ShapedType::isDynamic(shape)) - continue; - - dynamicDims.push_back(affine::makeComposedAffineApply( - rewriter, getLoc(), - rewriter.getAffineDimExpr(0) + rewriter.getAffineDimExpr(1), - ArrayRef{copyBufferSizes[index], pad})); - } - - FailureOr alloc = options.createAlloc( - rewriter, getLoc(), llvm::cast(*allocType), - /*dynShape=*/dynamicDims); - if (failed(alloc)) - return failure(); - - // Zero out the entire buffer prior to overwriting it with the copied values. - // TODO: This could be optimized to only zero regions that won't be filled - // with the copied values at the cost of 2^rank transfers instead of two. - if (hasPadding() && !getUndefPadding()) - rewriter.create(getLoc(), *alloc); - - // Subview into the original memory without any padding. - // As we only add padding at the end of the dimensions, the offsets are always - // zero. - Value destination = rewriter.create( - getLoc(), *alloc, - /*offsets=*/ - SmallVector(allocType->getRank(), rewriter.getIndexAttr(0)), - copyBufferSizes, - /*strides=*/ - SmallVector(allocType->getRank(), - rewriter.getIndexAttr(1))); - Value token = - rewriter.create(getLoc(), *copyBuffer, destination); - - // Replace op. - replaceOpWithBufferizedValues(rewriter, getOperation(), {*alloc, token}); - return success(); -} - -//===----------------------------------------------------------------------===// -// WaitForTensorCopyOp -//===----------------------------------------------------------------------===// - -OpFoldResult WaitForTensorCopyOp::fold(FoldAdaptor adaptor) { - if (adaptor.getToken()) - return getTransferTensor(); - - return nullptr; -} - -//===----------------------------------------------------------------------===// -// WaitForTensorCopyOp::BufferizableOpInterface -//===----------------------------------------------------------------------===// - -bool WaitForTensorCopyOp::mustBufferizeInPlace( - OpOperand &opOperand, const bufferization::AnalysisState &state) { - return true; -} - -bool WaitForTensorCopyOp::bufferizesToMemoryRead( - OpOperand &opOperand, const bufferization::AnalysisState &state) { - if (opOperand == getTransferTensorMutable()) - return false; - - if (opOperand == getCopyMutable()) - return true; - - llvm_unreachable("unknown operand"); -} - -bool WaitForTensorCopyOp::bufferizesToMemoryWrite( - OpOperand &opOperand, const bufferization::AnalysisState &) { - if (opOperand == getTransferTensorMutable()) - return true; - - if (opOperand == getCopyMutable()) - return false; - - llvm_unreachable("unknown operand"); -} - -AliasingValueList WaitForTensorCopyOp::getAliasingValues( - OpOperand &opOperand, const bufferization::AnalysisState &state) { - if (opOperand == getCopyMutable()) - return {}; - - if (opOperand == getTransferTensorMutable()) - return {{getResult(), BufferRelation::Equivalent, /*isDefinite=*/true}}; - - llvm_unreachable("unknown operand"); -} - -LogicalResult -WaitForTensorCopyOp::bufferize(RewriterBase &rewriter, - const BufferizationOptions &options) { - FailureOr transferTensorBuffer = - getBuffer(rewriter, getTransferTensor(), options); - if (failed(transferTensorBuffer)) - return failure(); - - rewriter.create(getLoc(), getToken()); - replaceOpWithBufferizedValues(rewriter, getOperation(), - *transferTensorBuffer); - return success(); -} - -bool WaitForTensorCopyOp::isNotConflicting( - OpOperand *uRead, OpOperand *uWrite, - const bufferization::AnalysisState &state) { - if (*uRead == getCopyMutable() && *uWrite == getTransferTensorMutable()) - return true; - - return false; -} - -//===----------------------------------------------------------------------===// -// CompletedTokenOp -//===----------------------------------------------------------------------===// - -OpFoldResult CompletedTokenOp::fold(FoldAdaptor adaptor) { - return CompletedTokenAttr::get(getContext()); -} - -//===----------------------------------------------------------------------===// -// StartDMATransferOp -//===----------------------------------------------------------------------===// - -OpFoldResult StartDMATransferOp::fold(FoldAdaptor adaptor) { - if (getSource() != getDest()) - return nullptr; - - return CompletedTokenAttr::get(getContext()); -} - -//===----------------------------------------------------------------------===// -// StartDMATransferOp::DMACoreSpecializationOpInterface -//===----------------------------------------------------------------------===// - -void StartDMATransferOp::replaceWithNoop(RewriterBase &rewriter) { - rewriter.replaceOpWithNewOp(*this); -} - -//===----------------------------------------------------------------------===// -// StartZeroMemTransferOp::DMACoreSpecializationOpInterface -//===----------------------------------------------------------------------===// - -void StartZeroMemTransferOp::replaceWithNoop(RewriterBase &rewriter) { - rewriter.replaceOpWithNewOp(*this); -} - -//===----------------------------------------------------------------------===// -// WaitForDMATransfersOp -//===----------------------------------------------------------------------===// - -LogicalResult -WaitForDMATransfersOp::fold(FoldAdaptor adaptor, - SmallVectorImpl &results) { - bool changed = false; - MutableOperandRange tokens = getTokensMutable(); - for (int i = tokens.size() - 1; i >= 0; i--) { - if (adaptor.getTokens()[i]) { - changed = true; - tokens.erase(i); - } - } - return success(changed); -} - -LogicalResult WaitForDMATransfersOp::canonicalize(WaitForDMATransfersOp op, - PatternRewriter &rewriter) { - if (!op.getTokens().empty()) - return failure(); - - rewriter.eraseOp(op); - return success(); -} - -//===----------------------------------------------------------------------===// -// WaitForDMATransfersOp::DMACoreSpecializationOpInterface -//===----------------------------------------------------------------------===// - -void WaitForDMATransfersOp::replaceWithNoop(RewriterBase &rewriter) { - rewriter.eraseOp(*this); -} - //===----------------------------------------------------------------------===// // ComputeCoreIndexOp::ComputeCoreSpecializationOpInterface //===----------------------------------------------------------------------===// diff --git a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.td b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.td index 6c1b9a7..8612295 100644 --- a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.td +++ b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.td @@ -174,114 +174,6 @@ def QuidditchSnitch_MicrokernelFenceOp : QuidditchSnitch_Op<"microkernel_fence", }]; } -def QuidditchSnitch_StartTensorCopyOp : QuidditchSnitch_Op<"start_tensor_copy", - [Pure, AllRanksMatch<["copy", "result"]>, - DeclareOpInterfaceMethods]> { - - let description = [{ - Operation starting a copy of a tensor to L1 memory space, optionally adding - padding and returning it as a new tensor. - The contained values of the resulting tensor is in an unspecified state. - See `wait_for_tensor_copy` to transform the tensor value into a state - equal to `$copy`. - - The operation may optionally add padding at the end of each dimension of - the tensor. Zero is used as the padding value. - The dimensions of the result tensor are computed using - `dims(copy)[i] + high_pad[i]`. - - This operation is a noop if `$copy` is already in L1, no padding is added, - and bufferization can elide the copy. - }]; - - let arguments = (ins AnyRankedTensor:$copy, - Variadic:$high_pad, - OptionalAttr:$static_high_pad, - UnitAttr:$undef_padding - ); - - let results = (outs - AnyRankedTensor:$result, - QuidditchSnitch_DMATokenType:$token - ); - - let assemblyFormat = [{ - $copy `to` `L1` - ( `pad` `with` (`undef` $undef_padding^) : (`zero`)? `to` - custom($high_pad, $static_high_pad)^)? - `:` type($copy) `->` type($result) attr-dict - }]; - - let builders = [ - OpBuilder<(ins "mlir::Value":$copy), [{ - build($_builder, $_state, copy.getType(), - $_builder.getType(), copy, - /*high_pad=*/mlir::ValueRange(), /*static_high_pad=*/nullptr); - }]> - ]; - - let hasVerifier = 1; - - let extraClassDeclaration = [{ - private: - std::optional - elidesAllocation(const mlir::bufferization::BufferizationOptions &options = {}, - llvm::SmallVector *invocationStack = nullptr); - public: - - bool hasPadding() { - return static_cast(getStaticHighPadAttr()); - } - - llvm::SmallVector getMixedHighPad(); - }]; - - let hasFolder = 1; -} - -def QuidditchSnitch_WaitForTensorCopyOp : QuidditchSnitch_Op<"wait_for_tensor_copy", - [AllTypesMatch<["transfer_tensor", "result"]>, Pure, - DeclareOpInterfaceMethods]> { - - let description = [{ - Operation asserting that a previous `start_tensor_copy` operation has finished. - Unless `token` is the result of an `completed_token` operation, - `transfer_tensor` and `token` must at runtime be a token and tensor yielded - by a `start_tensor_copy` operation and `copy` the original tensor used in - `start_tensor_copy`. - - Once this operation returns, the returned tensor's values are guaranteed - equal to the `copy` operand and in L1 memory. - - Note: The additional `copy` operand is given as it is effectively read by - this operation. - This additionally guarantees that the bufferization frame work does not - perform a write to the underlying buffer of `copy` while the transfer is - in progress. - }]; - - let arguments = (ins - AnyRankedTensor:$transfer_tensor, - QuidditchSnitch_DMATokenType:$token, - AnyRankedTensor:$copy - ); - - let results = (outs - AnyRankedTensor:$result - ); - - let assemblyFormat = [{ - `of` $copy `:` type($copy) `to` $transfer_tensor `using` $token `->` type($transfer_tensor) attr-dict - }]; - - let hasFolder = 1; -} - def FlatI8MemRef : ConfinedType, [HasStaticShapePred, HasAnyRankOfPred<[1]>], "one-dimensional i8 MemRef of a static size">; @@ -294,107 +186,6 @@ def QuidditchSnitch_L1MemoryViewOp : QuidditchSnitch_Op<"l1_memory_view", }]; } -def QuidditchSnitch_StartDMATransferOp : QuidditchSnitch_Op<"start_dma_transfer", - [MemoryEffects<[MemWrite]>, SameOperandsElementType, SameOperandsShape, - QuidditchSnitch_DMACoreSpecializationOpInterface]> { - - let description = [{ - Operation performing a DMA transfer from one MemRef to another. - The shapes (including dynamic ones at runtime) of both MemRefs must be - identical with different strides and offsets allowed. - - The DMA operation is likely (but not guaranteed) to run asynchronous and - its completion only guaranteed by executing the `wait_for_dma_transfers` - operation with the token returned by this operation or a later one. - }]; - - let arguments = (ins - Arg, "source", [MemRead]>:$source, - Arg, "destination", [MemWrite]>:$dest - ); - - let results = (outs QuidditchSnitch_DMATokenType:$token); - - let assemblyFormat = [{ - `from` $source `:` type($source) `to` $dest `:` type($dest) attr-dict - }]; - - let hasFolder = 1; - - let extraClassDeclaration = [{ - void replaceWithNoop(mlir::RewriterBase& rewriter); - }]; -} - -def QuidditchSnitch_StartZeroMemTransferOp : QuidditchSnitch_Op<"start_zero_mem_transfer", - [MemoryEffects<[MemWrite]>, - QuidditchSnitch_DMACoreSpecializationOpInterface]> { - - let description = [{ - - }]; - - let arguments = (ins - Arg, "zeroed buffer", [MemWrite]>:$filled - ); - - let results = (outs QuidditchSnitch_DMATokenType:$token); - - let assemblyFormat = [{ - $filled `:` type($filled) attr-dict - }]; - - let extraClassDeclaration = [{ - void replaceWithNoop(mlir::RewriterBase& rewriter); - }]; -} - -def QuidditchSnitch_WaitForDMATransfersOp - : QuidditchSnitch_Op<"wait_for_dma_transfers", [ - QuidditchSnitch_DMACoreSpecializationOpInterface - ]> { - - let description = [{ - Operation awaiting for DMA transfers denoted by its tokens to be finished. - }]; - - let arguments = (ins - Variadic:$tokens - ); - - let assemblyFormat = [{ - ($tokens^ `:` type($tokens))? attr-dict - }]; - - let hasFolder = 1; - let hasCanonicalizeMethod = 1; - - let extraClassDeclaration = [{ - bool needsSynchronization() { - return true; - } - - void replaceWithNoop(mlir::RewriterBase& rewriter); - }]; -} - -def QuidditchSnitch_CompletedTokenOp - : QuidditchSnitch_Op<"completed_token", [Pure, ConstantLike]> { - - let description = [{ - Op returning a special value representing a completed DMA transfer. - Passing this token to `wait_for_dma_transfers` will always return immediately. - }]; - - let results = (outs QuidditchSnitch_DMATokenType:$token); - - let assemblyFormat = [{ - attr-dict - }]; - - let hasFolder = 1; -} - def QuidditchSnitch_BarrierOp : QuidditchSnitch_Op<"barrier"> { let assemblyFormat = [{ attr-dict diff --git a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchTypes.td b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchTypes.td index 42a3bfc..05a006c 100644 --- a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchTypes.td +++ b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchTypes.td @@ -7,12 +7,4 @@ include "mlir/IR/AttrTypeBase.td" class QuidditchSnitch_Type traits = []> : TypeDef; -def QuidditchSnitch_DMATokenType : QuidditchSnitch_Type<"DMAToken"> { - let mnemonic = "dma_token"; - - let description = [{ - Type representing a potentially active DMA transfer. - }]; -} - #endif diff --git a/codegen/compiler/src/Quidditch/Dialect/Snitch/Transforms/CMakeLists.txt b/codegen/compiler/src/Quidditch/Dialect/Snitch/Transforms/CMakeLists.txt index aca8165..7833461 100644 --- a/codegen/compiler/src/Quidditch/Dialect/Snitch/Transforms/CMakeLists.txt +++ b/codegen/compiler/src/Quidditch/Dialect/Snitch/Transforms/CMakeLists.txt @@ -23,6 +23,7 @@ iree_cc_library( "SpecializeDMACode.cpp" DEPS ::PassesIncGen + Quidditch::Dialect::DMA::IR::DMADialect Quidditch::Dialect::Snitch::IR::QuidditchSnitchDialect MLIRIR MLIRAffineDialect diff --git a/codegen/compiler/src/Quidditch/Dialect/Snitch/Transforms/Passes.td b/codegen/compiler/src/Quidditch/Dialect/Snitch/Transforms/Passes.td index 3224d91..015ae86 100644 --- a/codegen/compiler/src/Quidditch/Dialect/Snitch/Transforms/Passes.td +++ b/codegen/compiler/src/Quidditch/Dialect/Snitch/Transforms/Passes.td @@ -18,6 +18,7 @@ def PromotePadsToL1Pass : Pass<"quidditch-promote-pads-to-l1"> { let dependentDialects = [ "quidditch::Snitch::QuidditchSnitchDialect", + "quidditch::dma::DMADialect", ]; } @@ -28,6 +29,7 @@ def PromoteOperandsToL1Pass : Pass<"quidditch-promote-operands-to-l1"> { let dependentDialects = [ "quidditch::Snitch::QuidditchSnitchDialect", + "quidditch::dma::DMADialect", ]; } @@ -88,6 +90,7 @@ def LowerForallOpPass : Pass<"quidditch-lower-forall-op"> { def PipelineCopyComputePass : Pass<"quidditch-pipeline-copy-compute"> { let dependentDialects = [ "quidditch::Snitch::QuidditchSnitchDialect", + "quidditch::dma::DMADialect", ]; } diff --git a/codegen/compiler/src/Quidditch/Dialect/Snitch/Transforms/PipelineCopyCompute.cpp b/codegen/compiler/src/Quidditch/Dialect/Snitch/Transforms/PipelineCopyCompute.cpp index cb49f0d..68369f1 100644 --- a/codegen/compiler/src/Quidditch/Dialect/Snitch/Transforms/PipelineCopyCompute.cpp +++ b/codegen/compiler/src/Quidditch/Dialect/Snitch/Transforms/PipelineCopyCompute.cpp @@ -1,5 +1,7 @@ #include "Passes.h" +#include "Quidditch/Dialect/DMA/IR/DMADialect.h" +#include "Quidditch/Dialect/DMA/IR/DMAOps.h" #include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchAttrs.h" #include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.h" #include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.h" @@ -29,6 +31,7 @@ class PipelineCopyCompute using namespace mlir; using namespace mlir::iree_compiler; using namespace quidditch::Snitch; +using namespace quidditch::dma; /// Lifts an 'scf.for' op to a pipeline op with two stages. /// The body of the for loop gets placed in the second stage with all iter args diff --git a/codegen/compiler/src/Quidditch/Dialect/Snitch/Transforms/PromoteToL1.cpp b/codegen/compiler/src/Quidditch/Dialect/Snitch/Transforms/PromoteToL1.cpp index a99a932..e440f6c 100644 --- a/codegen/compiler/src/Quidditch/Dialect/Snitch/Transforms/PromoteToL1.cpp +++ b/codegen/compiler/src/Quidditch/Dialect/Snitch/Transforms/PromoteToL1.cpp @@ -1,5 +1,7 @@ #include "Passes.h" +#include "Quidditch/Dialect/DMA/IR/DMADialect.h" +#include "Quidditch/Dialect/DMA/IR/DMAOps.h" #include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchAttrs.h" #include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.h" #include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.h" @@ -50,6 +52,7 @@ class PromotePadsToL1 using namespace mlir; using namespace quidditch::Snitch; +using namespace quidditch::dma; void PromoteOperandsToL1::runOnOperation() { // Copy all tensors used as operands to compute ops into L1 memory. @@ -62,8 +65,9 @@ void PromoteOperandsToL1::runOnOperation() { auto builder = OpBuilder(computeOp); for (OpOperand *use : nonL1Uses) { - auto copyOp = builder.create(computeOp.getLoc(), - /*copy=*/use->get()); + auto copyOp = builder.create( + computeOp.getLoc(), + /*copy=*/use->get(), builder.getAttr()); auto waitOp = builder.create( computeOp.getLoc(), copyOp.getResult(), copyOp.getToken(), /*copy=*/use->get()); @@ -81,8 +85,9 @@ void PromoteAllocsToL1::runOnOperation() { } OpBuilder builder(tensorOp); - auto copyOp = builder.create(tensorOp.getLoc(), - tensorOp.getCopy()); + auto copyOp = + builder.create(tensorOp.getLoc(), tensorOp.getCopy(), + builder.getAttr()); auto waitOp = builder.create( tensorOp.getLoc(), copyOp.getResult(), copyOp.getToken(), /*copy=*/tensorOp.getCopy()); @@ -112,9 +117,9 @@ void PromotePadsToL1::runOnOperation() { OpBuilder builder(padOp); auto copyOp = builder.create( - padOp.getLoc(), padOp.getType(), builder.getType(), - padOp.getSource(), padOp.getHigh(), padOp.getStaticHighAttr(), - undefPadding); + padOp.getLoc(), padOp.getType(), builder.getType(), + padOp.getSource(), builder.getAttr(), padOp.getHigh(), + padOp.getStaticHighAttr(), undefPadding); auto waitOp = builder.create( padOp.getLoc(), copyOp.getResult(), copyOp.getToken(), /*copy=*/padOp.getSource()); diff --git a/codegen/compiler/src/Quidditch/Target/CMakeLists.txt b/codegen/compiler/src/Quidditch/Target/CMakeLists.txt index 9f4cc7e..bdca133 100644 --- a/codegen/compiler/src/Quidditch/Target/CMakeLists.txt +++ b/codegen/compiler/src/Quidditch/Target/CMakeLists.txt @@ -32,6 +32,7 @@ iree_cc_library( DEPS ::PassesIncGen Quidditch::Conversion::ConvertSnitchToLLVM + Quidditch::Conversion::ConvertDMAToLLVM Quidditch::Dialect::Snitch::IR::QuidditchSnitchDialect MLIRFuncDialect MLIRIR @@ -48,6 +49,7 @@ iree_cc_library( ::Passes Quidditch::Conversion::ConvertToRISCV Quidditch::Dialect::Snitch::Transforms::Passes + Quidditch::Dialect::DMA::Extensions::DMACoreSpecializationOpInterfaceImpl IREELinalgTransformDialect LLVMAnalysis LLVMBitReader diff --git a/codegen/compiler/src/Quidditch/Target/ConvertToLLVM.cpp b/codegen/compiler/src/Quidditch/Target/ConvertToLLVM.cpp index bdeba3e..3a4138e 100644 --- a/codegen/compiler/src/Quidditch/Target/ConvertToLLVM.cpp +++ b/codegen/compiler/src/Quidditch/Target/ConvertToLLVM.cpp @@ -4,6 +4,7 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +#include "Quidditch/Conversion/ConvertDMAToLLVM.h" #include "Quidditch/Conversion/ConvertSnitchToLLVM.h" #include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.h" #include "iree/compiler/Codegen/LLVMCPU/DispatchABI.h" @@ -1036,6 +1037,7 @@ void ConvertToLLVMPass::runOnOperation() { populateVectorToLLVMMatrixConversionPatterns(typeConverter, patterns); populateVectorToLLVMConversionPatterns(typeConverter, patterns, false); populateSnitchToLLVMConversionPatterns(module, typeConverter, patterns); + populateDMAToLLVMConversionPatterns(module, typeConverter, patterns); HALDispatchABI abi(&typeConverter); // clang-format off diff --git a/codegen/compiler/src/Quidditch/Target/QuidditchTarget.cpp b/codegen/compiler/src/Quidditch/Target/QuidditchTarget.cpp index 5e57741..dc51630 100644 --- a/codegen/compiler/src/Quidditch/Target/QuidditchTarget.cpp +++ b/codegen/compiler/src/Quidditch/Target/QuidditchTarget.cpp @@ -25,6 +25,9 @@ #include "mlir/Transforms/Passes.h" #include "Quidditch/Conversion/Passes.h" +#include "Quidditch/Dialect/DMA/Extensions/DMACoreSpecializationOpInterfaceImpl.h" +#include "Quidditch/Dialect/DMA/IR/DMADialect.h" +#include "Quidditch/Dialect/DMA/IR/DMAOps.h" #include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.h" #include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.h" #include "Quidditch/Dialect/Snitch/Transforms/Passes.h" @@ -129,9 +132,11 @@ class QuidditchTargetBackend final : public IREE::HAL::TargetBackend { void getDependentDialects(DialectRegistry ®istry) const override { mlir::registerBuiltinDialectTranslation(registry); mlir::registerLLVMDialectTranslation(registry); + quidditch::dma::registerDMACoreSpecializationOpInterface(registry); registry.insert(); + quidditch::Snitch::QuidditchSnitchDialect, + quidditch::dma::DMADialect>(); } void getDefaultExecutableTargets( @@ -208,14 +213,13 @@ class QuidditchTargetBackend final : public IREE::HAL::TargetBackend { return builder.create( loc, memRefType, dynamicSizes, builder.getI64IntegerAttr(alignment)); }; - BufferizationOptions::MemCpyFn memcpyFn = [](OpBuilder &builder, - Location loc, Value from, - Value to) { - Value token = - builder.create(loc, from, to); - builder.create(loc, token); - return success(); - }; + BufferizationOptions::MemCpyFn memcpyFn = + [](OpBuilder &builder, Location loc, Value from, Value to) { + Value token = + builder.create(loc, from, to); + builder.create(loc, token); + return success(); + }; FunctionLikeNest(modulePassManager) .addPass(createEliminateEmptyTensorsPass) diff --git a/codegen/tests/Conversion/ConvertSnitchToLLVM/dma_transfer.mlir b/codegen/tests/Conversion/ConvertDMAToLLVM/dma_transfer.mlir similarity index 79% rename from codegen/tests/Conversion/ConvertSnitchToLLVM/dma_transfer.mlir rename to codegen/tests/Conversion/ConvertDMAToLLVM/dma_transfer.mlir index dc09379..259049e 100644 --- a/codegen/tests/Conversion/ConvertSnitchToLLVM/dma_transfer.mlir +++ b/codegen/tests/Conversion/ConvertDMAToLLVM/dma_transfer.mlir @@ -8,14 +8,14 @@ // CHECK-SAME: %{{[[:alnum:]]+}} // CHECK-SAME: %{{[[:alnum:]]+}} // CHECK-SAME: %[[ARG1_PTR:[[:alnum:]]+]] -func.func private @test(%arg0 : memref, %arg1 : memref) -> !quidditch_snitch.dma_token { +func.func private @test(%arg0 : memref, %arg1 : memref) -> !dma.token { // CHECK: %[[ZERO:.*]] = llvm.mlir.zero // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ZERO]][%[[ARG0_SIZE]]] // CHECK: %[[SIZE:.*]] = llvm.ptrtoint %[[GEP]] // CHECK: %[[R:.*]] = llvm.call @snrt_dma_start_1d(%[[ARG1_PTR]], %[[ARG0_PTR]], %[[SIZE]]) - %0 = quidditch_snitch.start_dma_transfer from %arg0 : memref to %arg1 : memref + %0 = dma.start_transfer from %arg0 : memref to %arg1 : memref // CHECK: return %[[R]] - return %0 : !quidditch_snitch.dma_token + return %0 : !dma.token } // CHECK-LABEL: @test2 @@ -27,22 +27,22 @@ func.func private @test(%arg0 : memref, %arg1 : memref) -> !quiddi // CHECK-SAME: %{{[[:alnum:]]+}} // CHECK-SAME: %[[ARG1_ALIGNED_PTR:[[:alnum:]]+]] // CHECK-SAME: %[[ARG1_OFFSET:[[:alnum:]]+]] -func.func private @test2(%arg0 : memref, %arg1 : memref>) -> !quidditch_snitch.dma_token { +func.func private @test2(%arg0 : memref, %arg1 : memref>) -> !dma.token { // CHECK: %[[ARG1_PTR:.*]] = llvm.getelementptr %[[ARG1_ALIGNED_PTR]][%[[ARG1_OFFSET]]] // CHECK: %[[GEP:.*]] = llvm.getelementptr %{{.*}}[%[[ARG0_SIZE]]] // CHECK: %[[SIZE:.*]] = llvm.ptrtoint %[[GEP]] // CHECK: %[[R:.*]] = llvm.call @snrt_dma_start_1d(%[[ARG1_PTR]], %[[ARG0_PTR]], %[[SIZE]]) - %0 = quidditch_snitch.start_dma_transfer from %arg0 : memref to %arg1 : memref> + %0 = dma.start_transfer from %arg0 : memref to %arg1 : memref> // CHECK: llvm.call @snrt_dma_start_1d( - %1 = quidditch_snitch.start_dma_transfer from %arg1 : memref> to %arg0 : memref - return %0 : !quidditch_snitch.dma_token + %1 = dma.start_transfer from %arg1 : memref> to %arg0 : memref + return %0 : !dma.token } // CHECK-LABEL: @test3 -func.func private @test3(%arg0 : memref, %arg1 : memref>) -> !quidditch_snitch.dma_token { +func.func private @test3(%arg0 : memref, %arg1 : memref>) -> !dma.token { // CHECK: llvm.call @snrt_dma_start_1d( - %0 = quidditch_snitch.start_dma_transfer from %arg0 : memref to %arg1 : memref> - return %0 : !quidditch_snitch.dma_token + %0 = dma.start_transfer from %arg0 : memref to %arg1 : memref> + return %0 : !dma.token } // CHECK-LABEL: @dynamic_inner( @@ -61,15 +61,15 @@ func.func private @dynamic_inner(%subview_3 : memref<1x?xf64, strided<[161, 1], // CHECK-SAME: %{{[[:alnum:]]+}} // CHECK-SAME: %{{[[:alnum:]]+}} // CHECK-SAME: %[[BYTES]] - %12 = quidditch_snitch.start_dma_transfer from %subview_3 : memref<1x?xf64, strided<[161, 1], offset: ?>> to %subview_5 : memref<1x?xf64, strided<[81, 1]>> + %12 = dma.start_transfer from %subview_3 : memref<1x?xf64, strided<[161, 1], offset: ?>> to %subview_5 : memref<1x?xf64, strided<[81, 1]>> return } // CHECK-LABEL: @test4 -func.func private @test4(%arg0 : memref<1x4xf32>, %arg1 : memref<1x4xf32, strided<[40, 1], offset: ?>>) -> !quidditch_snitch.dma_token { +func.func private @test4(%arg0 : memref<1x4xf32>, %arg1 : memref<1x4xf32, strided<[40, 1], offset: ?>>) -> !dma.token { // CHECK: llvm.call @snrt_dma_start_1d( - %0 = quidditch_snitch.start_dma_transfer from %arg0 : memref<1x4xf32> to %arg1 : memref<1x4xf32, strided<[40, 1], offset: ?>> - return %0 : !quidditch_snitch.dma_token + %0 = dma.start_transfer from %arg0 : memref<1x4xf32> to %arg1 : memref<1x4xf32, strided<[40, 1], offset: ?>> + return %0 : !dma.token } // CHECK-LABEL: @test5 @@ -86,7 +86,7 @@ func.func private @test4(%arg0 : memref<1x4xf32>, %arg1 : memref<1x4xf32, stride // CHECK-SAME: %{{[[:alnum:]]+}} // CHECK-SAME: %{{[[:alnum:]]+}} // CHECK-SAME: %[[ARG1_STRIDE_N:[[:alnum:]]+]] -func.func private @test5(%arg0 : memref<2x4xf32>, %arg1 : memref<2x4xf32, strided<[8, 1], offset: 0>>) -> !quidditch_snitch.dma_token { +func.func private @test5(%arg0 : memref<2x4xf32>, %arg1 : memref<2x4xf32, strided<[8, 1], offset: 0>>) -> !dma.token { // CHECK-DAG: %[[ELEMENT_WIDTH:.*]] = llvm.mlir.constant(4 : i32) // CHECK-DAG: %[[FOUR_INDEX:.*]] = llvm.mlir.constant(4 : index) // CHECK-DAG: %[[TWO:.*]] = llvm.mlir.constant(2 : index) @@ -94,8 +94,8 @@ func.func private @test5(%arg0 : memref<2x4xf32>, %arg1 : memref<2x4xf32, stride // CHECK: %[[ARG0_STRIDE:.*]] = llvm.mul %[[FOUR_INDEX]], %[[ELEMENT_WIDTH]] // CHECK: %[[ARG1_STRIDE:.*]] = llvm.mul %[[ARG1_STRIDE_N]], %[[ELEMENT_WIDTH]] // CHECK: llvm.call @snrt_dma_start_2d(%[[ARG1_PTR]], %[[ARG0_PTR]], %[[INNER_SIZE]], %[[ARG1_STRIDE]], %[[ARG0_STRIDE]], %[[TWO]]) - %0 = quidditch_snitch.start_dma_transfer from %arg0 : memref<2x4xf32> to %arg1 : memref<2x4xf32, strided<[8, 1], offset: 0>> - return %0 : !quidditch_snitch.dma_token + %0 = dma.start_transfer from %arg0 : memref<2x4xf32> to %arg1 : memref<2x4xf32, strided<[8, 1], offset: 0>> + return %0 : !dma.token } // CHECK-LABEL: @test6 @@ -116,7 +116,7 @@ func.func private @test5(%arg0 : memref<2x4xf32>, %arg1 : memref<2x4xf32, stride // CHECK-SAME: %{{[[:alnum:]]+}} // CHECK-SAME: %[[ARG1_STRIDE0:[[:alnum:]]+]] // CHECK-SAME: %[[ARG1_STRIDE_N:[[:alnum:]]+]] -func.func private @test6(%arg0 : memref<3x2x4xf32>, %arg1 : memref<3x2x4xf32, strided<[16, 8, 1], offset: 2>>) -> !quidditch_snitch.dma_token { +func.func private @test6(%arg0 : memref<3x2x4xf32>, %arg1 : memref<3x2x4xf32, strided<[16, 8, 1], offset: 2>>) -> !dma.token { // CHECK-DAG: %[[ELEMENT_WIDTH:.*]] = llvm.mlir.constant(4 : i32) // CHECK-DAG: %[[ZERO32:.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK-DAG: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : i32 @@ -147,9 +147,9 @@ func.func private @test6(%arg0 : memref<3x2x4xf32>, %arg1 : memref<3x2x4xf32, st // CHECK: %[[INV:.*]] = llvm.add %[[IV1]], %[[ONE]] // CHECK: llvm.br ^[[BB1]](%[[INV]], %[[RES]] - %0 = quidditch_snitch.start_dma_transfer from %arg0 : memref<3x2x4xf32> to %arg1 : memref<3x2x4xf32, strided<[16, 8, 1], offset: 2>> + %0 = dma.start_transfer from %arg0 : memref<3x2x4xf32> to %arg1 : memref<3x2x4xf32, strided<[16, 8, 1], offset: 2>> // CHECK: return %[[IV2]] - return %0 : !quidditch_snitch.dma_token + return %0 : !dma.token } @@ -167,7 +167,7 @@ func.func private @test6(%arg0 : memref<3x2x4xf32>, %arg1 : memref<3x2x4xf32, st // CHECK-SAME: %{{[[:alnum:]]+}} // CHECK-SAME: %{{[[:alnum:]]+}} // CHECK-SAME: %[[ARG1_STRIDE_N:[[:alnum:]]+]] -func.func private @dynamic_strides(%arg0 : memref<2x4xf32>, %arg1 : memref<2x4xf32, strided<[?, 1]>>) -> !quidditch_snitch.dma_token { +func.func private @dynamic_strides(%arg0 : memref<2x4xf32>, %arg1 : memref<2x4xf32, strided<[?, 1]>>) -> !dma.token { // CHECK-DAG: %[[ELEMENT_WIDTH:.*]] = llvm.mlir.constant(4 : i32) // CHECK-DAG: %[[FOUR:.*]] = llvm.mlir.constant(4 : index) // CHECK-DAG: %[[TWO:.*]] = llvm.mlir.constant(2 : index) @@ -175,8 +175,8 @@ func.func private @dynamic_strides(%arg0 : memref<2x4xf32>, %arg1 : memref<2x4xf // CHECK: %[[ARG0_STRIDE:.*]] = llvm.mul %[[FOUR]], %[[ELEMENT_WIDTH]] // CHECK: %[[ARG1_STRIDE:.*]] = llvm.mul %[[ARG1_STRIDE_N]], %[[ELEMENT_WIDTH]] // CHECK: llvm.call @snrt_dma_start_2d(%[[ARG1_PTR]], %[[ARG0_PTR]], %[[INNER_SIZE]], %[[ARG1_STRIDE]], %[[ARG0_STRIDE]], %[[TWO]]) - %0 = quidditch_snitch.start_dma_transfer from %arg0 : memref<2x4xf32> to %arg1 : memref<2x4xf32, strided<[?, 1]>> - return %0 : !quidditch_snitch.dma_token + %0 = dma.start_transfer from %arg0 : memref<2x4xf32> to %arg1 : memref<2x4xf32, strided<[?, 1]>> + return %0 : !dma.token } // CHECK-LABEL: @contigious_dynamic_inner @@ -193,12 +193,12 @@ func.func private @dynamic_strides(%arg0 : memref<2x4xf32>, %arg1 : memref<2x4xf // CHECK-SAME: %{{[[:alnum:]]+}} // CHECK-SAME: %{{[[:alnum:]]+}} // CHECK-SAME: %[[ARG1_STRIDE_N:[[:alnum:]]+]] -func.func private @contigious_dynamic_inner(%arg0 : memref, %arg1 : memref>) -> !quidditch_snitch.dma_token { +func.func private @contigious_dynamic_inner(%arg0 : memref, %arg1 : memref>) -> !dma.token { // CHECK: %[[ELEMENT_WIDTH:.*]] = llvm.mlir.constant(4 : i32) // CHECK: %[[INNER_SIZE:.*]] = llvm.mul %[[ARG0_STRIDE_0]], %[[ELEMENT_WIDTH]] // CHECK: %[[ARG0_STRIDE:.*]] = llvm.mul %[[ARG0_STRIDE_N]], %[[ELEMENT_WIDTH]] // CHECK: %[[ARG1_STRIDE:.*]] = llvm.mul %[[ARG1_STRIDE_N]], %[[ELEMENT_WIDTH]] // CHECK: llvm.call @snrt_dma_start_2d(%[[ARG1_PTR]], %[[ARG0_PTR]], %[[INNER_SIZE]], %[[ARG1_STRIDE]], %[[ARG0_STRIDE]], %[[ARG0_SIZE]]) - %0 = quidditch_snitch.start_dma_transfer from %arg0 : memref to %arg1 : memref> - return %0 : !quidditch_snitch.dma_token + %0 = dma.start_transfer from %arg0 : memref to %arg1 : memref> + return %0 : !dma.token } diff --git a/codegen/tests/Conversion/ConvertSnitchToLLVM/dma_wait.mlir b/codegen/tests/Conversion/ConvertDMAToLLVM/dma_wait.mlir similarity index 79% rename from codegen/tests/Conversion/ConvertSnitchToLLVM/dma_wait.mlir rename to codegen/tests/Conversion/ConvertDMAToLLVM/dma_wait.mlir index 62e4128..47ad305 100644 --- a/codegen/tests/Conversion/ConvertSnitchToLLVM/dma_wait.mlir +++ b/codegen/tests/Conversion/ConvertDMAToLLVM/dma_wait.mlir @@ -2,7 +2,7 @@ // CHECK-LABEL: @test // CHECK-SAME: %[[ARG0:[[:alnum:]]+]] -func.func private @test(%arg0 : !quidditch_snitch.dma_token) { +func.func private @test(%arg0 : !dma.token) { // CHECK: llvm.br ^[[BODY:[[:alnum:]]+]] // CHECK: ^[[BODY]]: // CHECK-NEXT: %[[ID:.*]] = llvm.inline_asm has_side_effects ".insn r 0x2b, 0, 0b100, $0, zero, zero @@ -11,7 +11,7 @@ func.func private @test(%arg0 : !quidditch_snitch.dma_token) { // CHECK: %[[COND:.*]] = llvm.icmp "ult" %[[ID]], %[[ARG0]] // CHECK: llvm.cond_br %[[COND]], ^[[BODY]], ^[[CONT:[[:alnum:]]+]] // CHECK: ^[[CONT]]: - quidditch_snitch.wait_for_dma_transfers %arg0 : !quidditch_snitch.dma_token + dma.wait_for_transfers %arg0 : !dma.token // CHECK-NEXT: llvm.return return } diff --git a/codegen/tests/Conversion/ConvertSnitchToLLVM/zero_mem_transfer.mlir b/codegen/tests/Conversion/ConvertDMAToLLVM/zero_mem_transfer.mlir similarity index 88% rename from codegen/tests/Conversion/ConvertSnitchToLLVM/zero_mem_transfer.mlir rename to codegen/tests/Conversion/ConvertDMAToLLVM/zero_mem_transfer.mlir index 5dcadb8..e0553a2 100644 --- a/codegen/tests/Conversion/ConvertSnitchToLLVM/zero_mem_transfer.mlir +++ b/codegen/tests/Conversion/ConvertDMAToLLVM/zero_mem_transfer.mlir @@ -5,7 +5,7 @@ // CHECK-SAME: %[[PTR:[[:alnum:]]+]] // CHECK-SAME: %{{[[:alnum:]]+}} // CHECK-SAME: %[[DIM0:[[:alnum:]]+]] -func.func private @test(%arg0 : memref) -> !quidditch_snitch.dma_token { +func.func private @test(%arg0 : memref) -> !dma.token { // CHECK-DAG: %[[NULL:.*]] = llvm.mlir.zero // CHECK-DAG: %[[ZERO:.*]] = llvm.mlir.constant(0 : // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[NULL]][%[[DIM0]]] @@ -16,9 +16,9 @@ func.func private @test(%arg0 : memref) -> !quidditch_snitch.dma_token { // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[PTR]][%[[OFFSET]]] // CHECK: %[[REM:.*]] = llvm.urem %[[SIZE]], %[[ZERO_MEM_SIZE]] // CHECK: %[[TOKEN:.*]] = llvm.call @snrt_dma_start_1d(%[[GEP]], %[[ZERO_MEM]], %[[REM]]) - %0 = quidditch_snitch.start_zero_mem_transfer %arg0 : memref + %0 = dma.start_zero_mem_transfer %arg0 : memref // CHECK: return %[[TOKEN]] - return %0 : !quidditch_snitch.dma_token + return %0 : !dma.token } // CHECK-LABEL: @test1( @@ -30,7 +30,7 @@ func.func private @test(%arg0 : memref) -> !quidditch_snitch.dma_token { // CHECK-SAME: %[[DIM2:[[:alnum:]]+]] // CHECK-SAME: %[[STRIDE0:[[:alnum:]]+]] // CHECK-SAME: %[[STRIDE1:[[:alnum:]]+]] -func.func private @test1(%arg0 : memref>) -> !quidditch_snitch.dma_token { +func.func private @test1(%arg0 : memref>) -> !dma.token { // CHECK-DAG: %[[ZERO_INDEX:.*]] = llvm.mlir.constant(0 : index) // CHECK-DAG: %[[ZERO_I32:.*]] = llvm.mlir.constant(0 : i32) // CHECK-DAG: %[[ONE:.*]] = llvm.mlir.constant(1 : @@ -67,7 +67,7 @@ func.func private @test1(%arg0 : memref>) -> !quid // CHECK: llvm.br ^[[LOOP0]](%[[INC0]], %[[TOKEN1]] : // CHECK: ^[[EXIT0]]: - %0 = quidditch_snitch.start_zero_mem_transfer %arg0 : memref> + %0 = dma.start_zero_mem_transfer %arg0 : memref> // CHECK: return %[[TOKEN0]] - return %0 : !quidditch_snitch.dma_token + return %0 : !dma.token } diff --git a/codegen/tests/Conversion/ConvertSnitchToLLVM/completed_token.mlir b/codegen/tests/Conversion/ConvertSnitchToLLVM/completed_token.mlir index 5008ba2..45617b9 100644 --- a/codegen/tests/Conversion/ConvertSnitchToLLVM/completed_token.mlir +++ b/codegen/tests/Conversion/ConvertSnitchToLLVM/completed_token.mlir @@ -1,9 +1,9 @@ // RUN: quidditch-opt %s --quidditch-convert-to-llvm | FileCheck %s // CHECK-LABEL: @test -func.func private @test() -> !quidditch_snitch.dma_token { +func.func private @test() -> !dma.token { // CHECK: %[[T:.*]] = llvm.mlir.constant(0 : {{.*}}) // CHECK: return %[[T]] - %0 = quidditch_snitch.completed_token - return %0 : !quidditch_snitch.dma_token + %0 = dma.completed_token + return %0 : !dma.token } diff --git a/codegen/tests/Dialect/DMA/IR/bufferization.mlir b/codegen/tests/Dialect/DMA/IR/bufferization.mlir new file mode 100644 index 0000000..c6843fc --- /dev/null +++ b/codegen/tests/Dialect/DMA/IR/bufferization.mlir @@ -0,0 +1,125 @@ +// RUN: quidditch-opt %s --one-shot-bufferize | FileCheck %s + +// CHECK: #[[$MAP2:.*]] = affine_map<()[s0, s1] -> (s0 + s1)> + +// CHECK: func @copy_l1_buffer( +func.func @copy_l1_buffer(%arg0 : tensor<32xf32>) -> (tensor<32xf32>, !dma.token) { + // CHECK: %[[ARG0:.*]] = bufferization.to_memref + + // CHECK: %[[ALLOC:.*]] = memref.alloc() + // CHECK-SAME: : memref<32xf32, #quidditch_snitch.l1_encoding> + // CHECK: %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]] + // CHECK-SAME: to memref<32xf32, strided<[1]>, #quidditch_snitch.l1_encoding> + // CHECK: %[[TOKEN:.*]] = dma.start_transfer from %[[ARG0]] + // CHECK-SAME: to %[[SUBVIEW]] + // CHECK: %[[R:.*]] = bufferization.to_tensor %[[ALLOC]] + %r, %token = dma.start_tensor_copy of %arg0 to #quidditch_snitch.l1_encoding -> tensor<32xf32> + // CHECK: return %[[R]], %[[TOKEN]] + return %r, %token : tensor<32xf32>, !dma.token +} + +// CHECK: func @copy_l1_buffer_elided( +func.func @copy_l1_buffer_elided(%arg0 : tensor<32xf32>) -> tensor<32xf32> { + // CHECK: memref.alloc() + // CHECK-NOT: memref.alloc() + %r:2 = dma.start_tensor_copy of %arg0 to #quidditch_snitch.l1_encoding -> tensor<32xf32> + %r2 = dma.wait_for_tensor_copy of %arg0 : tensor<32xf32> to %r#0 using %r#1 -> tensor<32xf32> + %r3:2 = dma.start_tensor_copy of %r2 to #quidditch_snitch.l1_encoding -> tensor<32xf32> + %r4 = dma.wait_for_tensor_copy of %r2 : tensor<32xf32> to %r3#0 using %r3#1 -> tensor<32xf32> + // CHECK: return + return %r4 : tensor<32xf32> +} + +// CHECK: func @copy_l1_buffer_alloca_elided( +func.func @copy_l1_buffer_alloca_elided() -> tensor<32xf32> { + // CHECK: memref.alloc() + // CHECK-NOT: memref.alloc() + %r = bufferization.alloc_tensor() {memory_space = #quidditch_snitch.l1_encoding} : tensor<32xf32> + %r2:2 = dma.start_tensor_copy of %r to #quidditch_snitch.l1_encoding : tensor<32xf32> -> tensor<32xf32> + // CHECK: return + return %r2#0 : tensor<32xf32> +} + +// CHECK: func @scf_for_copy_l1_buffer( +func.func @scf_for_copy_l1_buffer() -> tensor<32xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + // CHECK: %[[MEMREF:.*]] = memref.alloc + %r = bufferization.alloc_tensor() {memory_space = #quidditch_snitch.l1_encoding} : tensor<32xf32> + %r2:2 = dma.start_tensor_copy of %r to #quidditch_snitch.l1_encoding : tensor<32xf32> -> tensor<32xf32> + // CHECK-NEXT: dma.completed_token + // CHECK-NEXT: %[[R:.*]] = scf.for + // CHECK-SAME: iter_args(%[[ITER:.*]] = %[[MEMREF]]) + // CHECK-NEXT: dma.completed_token + // CHECK-NEXT: scf.yield %[[ITER]] + // CHECK: bufferization.to_tensor %[[R]] + %r3 = scf.for %i = %c0 to %c1 step %c1 iter_args(%iter = %r2#0) -> (tensor<32xf32>) { + %r4:2 = dma.start_tensor_copy of %iter to #quidditch_snitch.l1_encoding -> tensor<32xf32> + scf.yield %r4#0 : tensor<32xf32> + } + return %r3 : tensor<32xf32> +} + +// CHECK: func @copy_l1_buffer_dynamic_dims( +func.func @copy_l1_buffer_dynamic_dims(%arg0 : tensor) -> tensor { + // CHECK: %[[ARG0:.*]] = bufferization.to_memref + // CHECK: %[[ZERO:.*]] = arith.constant 0 + // CHECK: %[[DIM_IN:.*]] = memref.dim %[[ARG0]], %[[ZERO]] + // CHECK: %[[DIM:.*]] = affine.apply #{{.*}}()[%[[DIM_IN]]] + // CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM]]) + // CHECK-SAME: : memref + // CHECK: %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]] + // CHECK-SAME: to memref, #quidditch_snitch.l1_encoding> + // CHECK: dma.start_transfer from %[[ARG0]] + // CHECK-SAME: to %[[SUBVIEW]] + // CHECK: %[[R:.*]] = bufferization.to_tensor %[[ALLOC]] + %r:2 = dma.start_tensor_copy of %arg0 to #quidditch_snitch.l1_encoding -> tensor + // CHECK: return %[[R]] + return %r#0 : tensor +} + +// CHECK-LABEL: @tensor_copy_pad +// CHECK-SAME: %[[ARG0:[[:alnum:]]+]] +// CHECK-SAME: %[[PAD0:[[:alnum:]]+]] +// CHECK-SAME: %[[PAD1:[[:alnum:]]+]] +func.func @tensor_copy_pad(%arg0 : tensor, %pad0 : index, %pad1 : index) -> (tensor, !dma.token) { + // CHECK: %[[COPY:.*]] = bufferization.to_memref %[[ARG0]] + // CHECK: %[[ZERO:.*]] = arith.constant 0 + // CHECK: %[[DIM0:.*]] = memref.dim %[[COPY]], %[[ZERO]] + // CHECK: %[[ONE:.*]] = arith.constant 1 + // CHECK: %[[DIM1:.*]] = memref.dim %[[COPY]], %[[ONE]] + // CHECK: %[[NEW_DIM0:.*]] = affine.apply #[[$MAP2]]()[%[[DIM0]], %[[PAD0]]] + // CHECK: %[[NEW_DIM1:.*]] = affine.apply #[[$MAP2]]()[%[[DIM1]], %[[PAD1]]] + // CHECK: %[[ALLOC:.*]] = memref.alloc(%[[NEW_DIM0]], %[[NEW_DIM1]]) + // CHECK: start_zero_mem_transfer %[[ALLOC]] + // CHECK: %[[UNPADDED:.*]] = memref.subview %[[ALLOC]][0, 0] [%[[DIM0]], %[[DIM1]]] [1, 1] + // CHECK: %[[TOKEN:.*]] = dma.start_transfer from %[[COPY]] + // CHECK-SAME: to %[[UNPADDED]] + %r, %t = dma.start_tensor_copy of %arg0 to #quidditch_snitch.l1_encoding pad with zero by [%pad0, %pad1] : tensor -> tensor + // CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[ALLOC]] + // CHECK: return %[[TENSOR]], %[[TOKEN]] + return %r, %t : tensor, !dma.token +} + +// CHECK-LABEL: @tensor_copy_pad_undef +// CHECK-SAME: %[[ARG0:[[:alnum:]]+]] +// CHECK-SAME: %[[PAD0:[[:alnum:]]+]] +// CHECK-SAME: %[[PAD1:[[:alnum:]]+]] +func.func @tensor_copy_pad_undef(%arg0 : tensor, %pad0 : index, %pad1 : index) -> (tensor, !dma.token) { + // CHECK: %[[COPY:.*]] = bufferization.to_memref %[[ARG0]] + // CHECK: %[[ZERO:.*]] = arith.constant 0 + // CHECK: %[[DIM0:.*]] = memref.dim %[[COPY]], %[[ZERO]] + // CHECK: %[[ONE:.*]] = arith.constant 1 + // CHECK: %[[DIM1:.*]] = memref.dim %[[COPY]], %[[ONE]] + // CHECK: %[[NEW_DIM0:.*]] = affine.apply #[[$MAP2]]()[%[[DIM0]], %[[PAD0]]] + // CHECK: %[[NEW_DIM1:.*]] = affine.apply #[[$MAP2]]()[%[[DIM1]], %[[PAD1]]] + // CHECK: %[[ALLOC:.*]] = memref.alloc(%[[NEW_DIM0]], %[[NEW_DIM1]]) + // CHECK-NOT: start_zero_mem_transfer + // CHECK: %[[UNPADDED:.*]] = memref.subview %[[ALLOC]][0, 0] [%[[DIM0]], %[[DIM1]]] [1, 1] + // CHECK-NEXT: %[[TOKEN:.*]] = dma.start_transfer from %[[COPY]] + // CHECK-SAME: to %[[UNPADDED]] + %r, %t = dma.start_tensor_copy of %arg0 to #quidditch_snitch.l1_encoding pad with undef by [%pad0, %pad1] : tensor -> tensor + // CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[ALLOC]] + // CHECK: return %[[TENSOR]], %[[TOKEN]] + return %r, %t : tensor, !dma.token +} diff --git a/codegen/tests/Dialect/DMA/IR/canonicalization.mlir b/codegen/tests/Dialect/DMA/IR/canonicalization.mlir new file mode 100644 index 0000000..60998f0 --- /dev/null +++ b/codegen/tests/Dialect/DMA/IR/canonicalization.mlir @@ -0,0 +1,111 @@ +// RUN: quidditch-opt %s --canonicalize --split-input-file --allow-unregistered-dialect | FileCheck %s + +// CHECK-LABEL: @wait_gets_removed +func.func @wait_gets_removed() { + // CHECK-NEXT: return + %0 = dma.completed_token + dma.wait_for_transfers %0 : !dma.token + return +} + +// CHECK-LABEL: @noop_transfer +func.func @noop_transfer(%arg0 : memref) -> !dma.token { + // CHECK-NEXT: %[[R:.*]] = dma.completed_token + // CHECK-NEXT: return %[[R]] + %0 = dma.start_transfer from %arg0 : memref to %arg0 : memref + return %0 : !dma.token +} + +// CHECK-LABEL: @tensor_wait_gets_removed +// CHECK-SAME: %[[ARG0:[[:alnum:]]+]] +// CHECK-SAME: %[[ARG1:[[:alnum:]]+]] +func.func @tensor_wait_gets_removed(%arg0 : tensor, %arg1 : tensor) -> tensor { + // CHECK-NEXT: return %[[ARG1]] + %t = dma.completed_token + %0 = dma.wait_for_tensor_copy of %arg0 : tensor to %arg1 using %t -> tensor + return %0 : tensor +} + +// CHECK-LABEL: @tensor_noop_transfer +// CHECK-SAME: %[[ARG0:[[:alnum:]]+]] +func.func @tensor_noop_transfer(%arg0 : tensor) -> (tensor, !dma.token) { + // CHECK: %[[T2:.*]] = dma.completed_token + // CHECK: %[[R:.*]], %[[T:.*]] = dma.start_tensor_copy of %[[ARG0]] + %r, %t = dma.start_tensor_copy of %arg0 to #quidditch_snitch.l1_encoding -> tensor + // CHECK: %[[R2:.*]] = dma.wait_for_tensor_copy of %[[ARG0]] + // CHECK-SAME: to %[[R]] using %[[T]] + %0 = dma.wait_for_tensor_copy of %arg0 : tensor to %r using %t -> tensor + + // CHECK-NOT: wait_for_tensor_copy + %r2, %t2 = dma.start_tensor_copy of %0 to #quidditch_snitch.l1_encoding -> tensor + + // CHECK: return %[[R2]], %[[T2]] + return %r2, %t2 : tensor, !dma.token +} + +// CHECK-LABEL: @tensor_noop_pad +func.func @tensor_noop_pad(%arg0 : tensor) -> (tensor, !dma.token) { + // CHECK: %[[R:.*]], %[[T:.*]] = dma.start_tensor_copy + // CHECK-NOT: pad with + %r, %t = dma.start_tensor_copy of %arg0 to #quidditch_snitch.l1_encoding pad with zero by [0] : tensor -> tensor + // CHECK-NEXT: return %[[R]], %[[T]] + return %r, %t : tensor, !dma.token +} + +// CHECK-LABEL: @tensor_pad_constant +func.func @tensor_pad_constant(%arg0 : tensor) -> (tensor, !dma.token) { + %zero = arith.constant 0 : index + // CHECK: %[[R:.*]], %[[T:.*]] = dma.start_tensor_copy + // CHECK-NOT: pad with + %r, %t = dma.start_tensor_copy of %arg0 to #quidditch_snitch.l1_encoding pad with zero by [%zero] : tensor -> tensor + // CHECK-NEXT: return %[[R]], %[[T]] + return %r, %t : tensor, !dma.token +} + +// CHECK-LABEL: @tensor_noop_transfer_pad +// CHECK-SAME: %[[ARG0:[[:alnum:]]+]] +func.func @tensor_noop_transfer_pad(%arg0 : tensor) -> (tensor, !dma.token) { + // CHECK: %[[T2:.*]] = dma.completed_token + // CHECK: %[[R:.*]], %[[T:.*]] = dma.start_tensor_copy of %[[ARG0]] + %r, %t = dma.start_tensor_copy of %arg0 to #quidditch_snitch.l1_encoding pad with zero by [1] : tensor -> tensor + // CHECK: %[[R2:.*]] = dma.wait_for_tensor_copy of %[[ARG0]] + // CHECK-SAME: to %[[R]] using %[[T]] + %0 = dma.wait_for_tensor_copy of %arg0 : tensor to %r using %t -> tensor + + // CHECK-NOT: wait_for_tensor_copy + %r2, %t2 = dma.start_tensor_copy of %0 to #quidditch_snitch.l1_encoding -> tensor + + // CHECK: return %[[R2]], %[[T2]] + return %r2, %t2 : tensor, !dma.token +} + +// CHECK-LABEL: @tensor_noop_transfer_pad_neg +// CHECK-SAME: %[[ARG0:[[:alnum:]]+]] +func.func @tensor_noop_transfer_pad_neg(%arg0 : tensor) -> (tensor, !dma.token) { + // CHECK: start_tensor_copy + // CHECK: wait_for_tensor_copy + // CHECK: %[[R:.*]], %[[T:.*]] = dma.start_tensor_copy + // CHECK: return %[[R]], %[[T]] + + %r, %t = dma.start_tensor_copy of %arg0 to #quidditch_snitch.l1_encoding -> tensor + %0 = dma.wait_for_tensor_copy of %arg0 : tensor to %r using %t -> tensor + %r2, %t2 = dma.start_tensor_copy of %0 to #quidditch_snitch.l1_encoding pad with zero by [1] : tensor -> tensor + return %r2, %t2 : tensor, !dma.token +} + +// CHECK-LABEL: @tensor_noop_transfer_same_padding +// CHECK-SAME: %[[ARG0:[[:alnum:]]+]] +func.func @tensor_noop_transfer_same_padding(%arg0 : tensor) -> (tensor, !dma.token) { + // CHECK: %[[T2:.*]] = dma.completed_token + // CHECK: %[[R:.*]], %[[T:.*]] = dma.start_tensor_copy of %[[ARG0]] + %r, %t = dma.start_tensor_copy of %arg0 to #quidditch_snitch.l1_encoding pad with zero by [1] : tensor -> tensor + // CHECK: %[[R2:.*]] = dma.wait_for_tensor_copy of %[[ARG0]] + // CHECK-SAME: to %[[R]] using %[[T]] + %0 = dma.wait_for_tensor_copy of %arg0 : tensor to %r using %t -> tensor + + // CHECK-NOT: wait_for_tensor_copy + %r2, %t2 = dma.start_tensor_copy of %0 to #quidditch_snitch.l1_encoding pad with zero by [1] : tensor -> tensor + + // CHECK: return %[[R2]], %[[T2]] + return %r2, %t2 : tensor, !dma.token +} diff --git a/codegen/tests/Dialect/DMA/IR/roundtrip.mlir b/codegen/tests/Dialect/DMA/IR/roundtrip.mlir new file mode 100644 index 0000000..de6d37b --- /dev/null +++ b/codegen/tests/Dialect/DMA/IR/roundtrip.mlir @@ -0,0 +1,11 @@ +// RUN: quidditch-opt %s --verify-roundtrip + +func.func @test(%arg0 : memref) { + dma.wait_for_transfers + return +} + +func.func @test3(%arg0 : tensor) -> (tensor, !dma.token) { + %0:2 = dma.start_tensor_copy of %arg0 to #quidditch_snitch.l1_encoding -> tensor + return %0#0, %0#1 : tensor, !dma.token +} diff --git a/codegen/tests/Dialect/Snitch/IR/bufferization.mlir b/codegen/tests/Dialect/Snitch/IR/bufferization.mlir index 97d07c3..97fce22 100644 --- a/codegen/tests/Dialect/Snitch/IR/bufferization.mlir +++ b/codegen/tests/Dialect/Snitch/IR/bufferization.mlir @@ -1,83 +1,5 @@ // RUN: quidditch-opt %s --one-shot-bufferize | FileCheck %s -// CHECK: #[[$MAP2:.*]] = affine_map<()[s0, s1] -> (s0 + s1)> - -// CHECK: func @copy_l1_buffer( -func.func @copy_l1_buffer(%arg0 : tensor<32xf32>) -> (tensor<32xf32>, !quidditch_snitch.dma_token) { - // CHECK: %[[ARG0:.*]] = bufferization.to_memref - - // CHECK: %[[ALLOC:.*]] = memref.alloc() - // CHECK-SAME: : memref<32xf32, #quidditch_snitch.l1_encoding> - // CHECK: %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]] - // CHECK-SAME: to memref<32xf32, strided<[1]>, #quidditch_snitch.l1_encoding> - // CHECK: %[[TOKEN:.*]] = quidditch_snitch.start_dma_transfer from %[[ARG0]] - // CHECK-SAME: to %[[SUBVIEW]] - // CHECK: %[[R:.*]] = bufferization.to_tensor %[[ALLOC]] - %r, %token = quidditch_snitch.start_tensor_copy %arg0 to L1 : tensor<32xf32> -> tensor<32xf32> - // CHECK: return %[[R]], %[[TOKEN]] - return %r, %token : tensor<32xf32>, !quidditch_snitch.dma_token -} - -// CHECK: func @copy_l1_buffer_elided( -func.func @copy_l1_buffer_elided(%arg0 : tensor<32xf32>) -> tensor<32xf32> { - // CHECK: memref.alloc() - // CHECK-NOT: memref.alloc() - %r:2 = quidditch_snitch.start_tensor_copy %arg0 to L1 : tensor<32xf32> -> tensor<32xf32> - %r2 = quidditch_snitch.wait_for_tensor_copy of %arg0 : tensor<32xf32> to %r#0 using %r#1 -> tensor<32xf32> - %r3:2 = quidditch_snitch.start_tensor_copy %r2 to L1 : tensor<32xf32> -> tensor<32xf32> - %r4 = quidditch_snitch.wait_for_tensor_copy of %r2 : tensor<32xf32> to %r3#0 using %r3#1 -> tensor<32xf32> - // CHECK: return - return %r4 : tensor<32xf32> -} - -// CHECK: func @copy_l1_buffer_alloca_elided( -func.func @copy_l1_buffer_alloca_elided() -> tensor<32xf32> { - // CHECK: memref.alloc() - // CHECK-NOT: memref.alloc() - %r = bufferization.alloc_tensor() {memory_space = #quidditch_snitch.l1_encoding} : tensor<32xf32> - %r2:2 = quidditch_snitch.start_tensor_copy %r to L1 : tensor<32xf32> -> tensor<32xf32> - // CHECK: return - return %r2#0 : tensor<32xf32> -} - -// CHECK: func @scf_for_copy_l1_buffer( -func.func @scf_for_copy_l1_buffer() -> tensor<32xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - // CHECK: %[[MEMREF:.*]] = memref.alloc - %r = bufferization.alloc_tensor() {memory_space = #quidditch_snitch.l1_encoding} : tensor<32xf32> - %r2:2 = quidditch_snitch.start_tensor_copy %r to L1 : tensor<32xf32> -> tensor<32xf32> - // CHECK-NEXT: quidditch_snitch.completed_token - // CHECK-NEXT: %[[R:.*]] = scf.for - // CHECK-SAME: iter_args(%[[ITER:.*]] = %[[MEMREF]]) - // CHECK-NEXT: quidditch_snitch.completed_token - // CHECK-NEXT: scf.yield %[[ITER]] - // CHECK: bufferization.to_tensor %[[R]] - %r3 = scf.for %i = %c0 to %c1 step %c1 iter_args(%iter = %r2#0) -> (tensor<32xf32>) { - %r4:2 = quidditch_snitch.start_tensor_copy %iter to L1 : tensor<32xf32> -> tensor<32xf32> - scf.yield %r4#0 : tensor<32xf32> - } - return %r3 : tensor<32xf32> -} - -// CHECK: func @copy_l1_buffer_dynamic_dims( -func.func @copy_l1_buffer_dynamic_dims(%arg0 : tensor) -> tensor { - // CHECK: %[[ARG0:.*]] = bufferization.to_memref - // CHECK: %[[ZERO:.*]] = arith.constant 0 - // CHECK: %[[DIM_IN:.*]] = memref.dim %[[ARG0]], %[[ZERO]] - // CHECK: %[[DIM:.*]] = affine.apply #{{.*}}()[%[[DIM_IN]]] - // CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM]]) - // CHECK-SAME: : memref - // CHECK: %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]] - // CHECK-SAME: to memref, #quidditch_snitch.l1_encoding> - // CHECK: quidditch_snitch.start_dma_transfer from %[[ARG0]] - // CHECK-SAME: to %[[SUBVIEW]] - // CHECK: %[[R:.*]] = bufferization.to_tensor %[[ALLOC]] - %r:2 = quidditch_snitch.start_tensor_copy %arg0 to L1 : tensor -> tensor - // CHECK: return %[[R]] - return %r#0 : tensor -} - // CHECK-LABEL: @pipeline_op( func.func @pipeline_op(%arg0_dim : index) -> tensor { // CHECK-DAG: %[[C0:.*]] = arith.constant 0 @@ -142,49 +64,3 @@ func.func @sync_tensor() -> tensor<32xf32> { // CHECK: return %[[R]] return %r : tensor<32xf32> } - -// CHECK-LABEL: @tensor_copy_pad -// CHECK-SAME: %[[ARG0:[[:alnum:]]+]] -// CHECK-SAME: %[[PAD0:[[:alnum:]]+]] -// CHECK-SAME: %[[PAD1:[[:alnum:]]+]] -func.func @tensor_copy_pad(%arg0 : tensor, %pad0 : index, %pad1 : index) -> (tensor, !quidditch_snitch.dma_token) { - // CHECK: %[[COPY:.*]] = bufferization.to_memref %[[ARG0]] - // CHECK: %[[ZERO:.*]] = arith.constant 0 - // CHECK: %[[DIM0:.*]] = memref.dim %[[COPY]], %[[ZERO]] - // CHECK: %[[ONE:.*]] = arith.constant 1 - // CHECK: %[[DIM1:.*]] = memref.dim %[[COPY]], %[[ONE]] - // CHECK: %[[NEW_DIM0:.*]] = affine.apply #[[$MAP2]]()[%[[DIM0]], %[[PAD0]]] - // CHECK: %[[NEW_DIM1:.*]] = affine.apply #[[$MAP2]]()[%[[DIM1]], %[[PAD1]]] - // CHECK: %[[ALLOC:.*]] = memref.alloc(%[[NEW_DIM0]], %[[NEW_DIM1]]) - // CHECK: start_zero_mem_transfer %[[ALLOC]] - // CHECK: %[[UNPADDED:.*]] = memref.subview %[[ALLOC]][0, 0] [%[[DIM0]], %[[DIM1]]] [1, 1] - // CHECK: %[[TOKEN:.*]] = quidditch_snitch.start_dma_transfer from %[[COPY]] - // CHECK-SAME: to %[[UNPADDED]] - %r, %t = quidditch_snitch.start_tensor_copy %arg0 to L1 pad with zero to [%pad0, %pad1] : tensor -> tensor - // CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[ALLOC]] - // CHECK: return %[[TENSOR]], %[[TOKEN]] - return %r, %t : tensor, !quidditch_snitch.dma_token -} - -// CHECK-LABEL: @tensor_copy_pad_undef -// CHECK-SAME: %[[ARG0:[[:alnum:]]+]] -// CHECK-SAME: %[[PAD0:[[:alnum:]]+]] -// CHECK-SAME: %[[PAD1:[[:alnum:]]+]] -func.func @tensor_copy_pad_undef(%arg0 : tensor, %pad0 : index, %pad1 : index) -> (tensor, !quidditch_snitch.dma_token) { - // CHECK: %[[COPY:.*]] = bufferization.to_memref %[[ARG0]] - // CHECK: %[[ZERO:.*]] = arith.constant 0 - // CHECK: %[[DIM0:.*]] = memref.dim %[[COPY]], %[[ZERO]] - // CHECK: %[[ONE:.*]] = arith.constant 1 - // CHECK: %[[DIM1:.*]] = memref.dim %[[COPY]], %[[ONE]] - // CHECK: %[[NEW_DIM0:.*]] = affine.apply #[[$MAP2]]()[%[[DIM0]], %[[PAD0]]] - // CHECK: %[[NEW_DIM1:.*]] = affine.apply #[[$MAP2]]()[%[[DIM1]], %[[PAD1]]] - // CHECK: %[[ALLOC:.*]] = memref.alloc(%[[NEW_DIM0]], %[[NEW_DIM1]]) - // CHECK-NOT: start_zero_mem_transfer - // CHECK: %[[UNPADDED:.*]] = memref.subview %[[ALLOC]][0, 0] [%[[DIM0]], %[[DIM1]]] [1, 1] - // CHECK-NEXT: %[[TOKEN:.*]] = quidditch_snitch.start_dma_transfer from %[[COPY]] - // CHECK-SAME: to %[[UNPADDED]] - %r, %t = quidditch_snitch.start_tensor_copy %arg0 to L1 pad with undef to [%pad0, %pad1] : tensor -> tensor - // CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[ALLOC]] - // CHECK: return %[[TENSOR]], %[[TOKEN]] - return %r, %t : tensor, !quidditch_snitch.dma_token -} diff --git a/codegen/tests/Dialect/Snitch/IR/canonicalization.mlir b/codegen/tests/Dialect/Snitch/IR/canonicalization.mlir index f049168..49aa8ad 100644 --- a/codegen/tests/Dialect/Snitch/IR/canonicalization.mlir +++ b/codegen/tests/Dialect/Snitch/IR/canonicalization.mlir @@ -36,22 +36,6 @@ func.func @identical_argument(%arg0 : i32) { return } -// CHECK-LABEL: @wait_gets_removed -func.func @wait_gets_removed() { - // CHECK-NEXT: return - %0 = quidditch_snitch.completed_token - quidditch_snitch.wait_for_dma_transfers %0 : !quidditch_snitch.dma_token - return -} - -// CHECK-LABEL: @noop_transfer -func.func @noop_transfer(%arg0 : memref) -> !quidditch_snitch.dma_token { - // CHECK-NEXT: %[[R:.*]] = quidditch_snitch.completed_token - // CHECK-NEXT: return %[[R]] - %0 = quidditch_snitch.start_dma_transfer from %arg0 : memref to %arg0 : memref - return %0 : !quidditch_snitch.dma_token -} - // CHECK-LABEL: @pipeline_dead_block_arg( func.func @pipeline_dead_block_arg(%tensor : tensor) { %c0 = arith.constant 0 : index @@ -91,97 +75,3 @@ func.func @pipeline_invariant(%tensor : tensor) { } return } - -// CHECK-LABEL: @tensor_wait_gets_removed -// CHECK-SAME: %[[ARG0:[[:alnum:]]+]] -// CHECK-SAME: %[[ARG1:[[:alnum:]]+]] -func.func @tensor_wait_gets_removed(%arg0 : tensor, %arg1 : tensor) -> tensor { - // CHECK-NEXT: return %[[ARG1]] - %t = quidditch_snitch.completed_token - %0 = quidditch_snitch.wait_for_tensor_copy of %arg0 : tensor to %arg1 using %t -> tensor - return %0 : tensor -} - -// CHECK-LABEL: @tensor_noop_transfer -// CHECK-SAME: %[[ARG0:[[:alnum:]]+]] -func.func @tensor_noop_transfer(%arg0 : tensor) -> (tensor, !quidditch_snitch.dma_token) { - // CHECK: %[[T2:.*]] = quidditch_snitch.completed_token - // CHECK: %[[R:.*]], %[[T:.*]] = quidditch_snitch.start_tensor_copy %[[ARG0]] - %r, %t = quidditch_snitch.start_tensor_copy %arg0 to L1 : tensor -> tensor - // CHECK: %[[R2:.*]] = quidditch_snitch.wait_for_tensor_copy of %[[ARG0]] - // CHECK-SAME: to %[[R]] using %[[T]] - %0 = quidditch_snitch.wait_for_tensor_copy of %arg0 : tensor to %r using %t -> tensor - - // CHECK-NOT: wait_for_tensor_copy - %r2, %t2 = quidditch_snitch.start_tensor_copy %0 to L1 : tensor -> tensor - - // CHECK: return %[[R2]], %[[T2]] - return %r2, %t2 : tensor, !quidditch_snitch.dma_token -} - -// CHECK-LABEL: @tensor_noop_pad -func.func @tensor_noop_pad(%arg0 : tensor) -> (tensor, !quidditch_snitch.dma_token) { - // CHECK: %[[R:.*]], %[[T:.*]] = quidditch_snitch.start_tensor_copy - // CHECK-NOT: pad with - %r, %t = quidditch_snitch.start_tensor_copy %arg0 to L1 pad with zero to [0] : tensor -> tensor - // CHECK-NEXT: return %[[R]], %[[T]] - return %r, %t : tensor, !quidditch_snitch.dma_token -} - -// CHECK-LABEL: @tensor_pad_constant -func.func @tensor_pad_constant(%arg0 : tensor) -> (tensor, !quidditch_snitch.dma_token) { - %zero = arith.constant 0 : index - // CHECK: %[[R:.*]], %[[T:.*]] = quidditch_snitch.start_tensor_copy - // CHECK-NOT: pad with - %r, %t = quidditch_snitch.start_tensor_copy %arg0 to L1 pad with zero to [%zero] : tensor -> tensor - // CHECK-NEXT: return %[[R]], %[[T]] - return %r, %t : tensor, !quidditch_snitch.dma_token -} - -// CHECK-LABEL: @tensor_noop_transfer_pad -// CHECK-SAME: %[[ARG0:[[:alnum:]]+]] -func.func @tensor_noop_transfer_pad(%arg0 : tensor) -> (tensor, !quidditch_snitch.dma_token) { - // CHECK: %[[T2:.*]] = quidditch_snitch.completed_token - // CHECK: %[[R:.*]], %[[T:.*]] = quidditch_snitch.start_tensor_copy %[[ARG0]] - %r, %t = quidditch_snitch.start_tensor_copy %arg0 to L1 pad with zero to [1] : tensor -> tensor - // CHECK: %[[R2:.*]] = quidditch_snitch.wait_for_tensor_copy of %[[ARG0]] - // CHECK-SAME: to %[[R]] using %[[T]] - %0 = quidditch_snitch.wait_for_tensor_copy of %arg0 : tensor to %r using %t -> tensor - - // CHECK-NOT: wait_for_tensor_copy - %r2, %t2 = quidditch_snitch.start_tensor_copy %0 to L1 : tensor -> tensor - - // CHECK: return %[[R2]], %[[T2]] - return %r2, %t2 : tensor, !quidditch_snitch.dma_token -} - -// CHECK-LABEL: @tensor_noop_transfer_pad_neg -// CHECK-SAME: %[[ARG0:[[:alnum:]]+]] -func.func @tensor_noop_transfer_pad_neg(%arg0 : tensor) -> (tensor, !quidditch_snitch.dma_token) { - // CHECK: start_tensor_copy - // CHECK: wait_for_tensor_copy - // CHECK: %[[R:.*]], %[[T:.*]] = quidditch_snitch.start_tensor_copy - // CHECK: return %[[R]], %[[T]] - - %r, %t = quidditch_snitch.start_tensor_copy %arg0 to L1 : tensor -> tensor - %0 = quidditch_snitch.wait_for_tensor_copy of %arg0 : tensor to %r using %t -> tensor - %r2, %t2 = quidditch_snitch.start_tensor_copy %0 to L1 pad with zero to [1] : tensor -> tensor - return %r2, %t2 : tensor, !quidditch_snitch.dma_token -} - -// CHECK-LABEL: @tensor_noop_transfer_same_padding -// CHECK-SAME: %[[ARG0:[[:alnum:]]+]] -func.func @tensor_noop_transfer_same_padding(%arg0 : tensor) -> (tensor, !quidditch_snitch.dma_token) { - // CHECK: %[[T2:.*]] = quidditch_snitch.completed_token - // CHECK: %[[R:.*]], %[[T:.*]] = quidditch_snitch.start_tensor_copy %[[ARG0]] - %r, %t = quidditch_snitch.start_tensor_copy %arg0 to L1 pad with zero to [1] : tensor -> tensor - // CHECK: %[[R2:.*]] = quidditch_snitch.wait_for_tensor_copy of %[[ARG0]] - // CHECK-SAME: to %[[R]] using %[[T]] - %0 = quidditch_snitch.wait_for_tensor_copy of %arg0 : tensor to %r using %t -> tensor - - // CHECK-NOT: wait_for_tensor_copy - %r2, %t2 = quidditch_snitch.start_tensor_copy %0 to L1 pad with zero to [1] : tensor -> tensor - - // CHECK: return %[[R2]], %[[T2]] - return %r2, %t2 : tensor, !quidditch_snitch.dma_token -} diff --git a/codegen/tests/Dialect/Snitch/IR/roundtrip.mlir b/codegen/tests/Dialect/Snitch/IR/roundtrip.mlir index 0f22a4d..e30a6d7 100644 --- a/codegen/tests/Dialect/Snitch/IR/roundtrip.mlir +++ b/codegen/tests/Dialect/Snitch/IR/roundtrip.mlir @@ -5,11 +5,5 @@ func.func @test(%arg0 : memref) { ^bb0(%arg1 : memref): } - quidditch_snitch.wait_for_dma_transfers return } - -func.func @test3(%arg0 : tensor) -> (tensor, !quidditch_snitch.dma_token) { - %0:2 = quidditch_snitch.start_tensor_copy %arg0 to L1 : tensor -> tensor - return %0#0, %0#1 : tensor, !quidditch_snitch.dma_token -} diff --git a/codegen/tests/Dialect/Snitch/Transforms/lower-pipeline.mlir b/codegen/tests/Dialect/Snitch/Transforms/lower-pipeline.mlir index 34d73dc..59254c0 100644 --- a/codegen/tests/Dialect/Snitch/Transforms/lower-pipeline.mlir +++ b/codegen/tests/Dialect/Snitch/Transforms/lower-pipeline.mlir @@ -34,7 +34,7 @@ func.func @test( // CHECK-NEXT: yield %[[ALLOCA0]] // CHECK: default // CHECK-NEXT: yield %[[ALLOCA1]] - // CHECK: %[[TOKEN:.*]] = quidditch_snitch.start_dma_transfer from %{{.*}} to %[[ALLOCA]] + // CHECK: %[[TOKEN:.*]] = dma.start_transfer from %{{.*}} to %[[ALLOCA]] // Full pipeline. // CHECK: %[[NEW_LB:.*]] = arith.addi %[[LB]], %[[STEP]] @@ -47,18 +47,18 @@ func.func @test( %subview_3 = memref.subview %9[%arg1, %arg0] [40, 100] [1, 1] : memref<1200x400xf64, strided<[400, 1], offset: ?>> to memref<40x100xf64, strided<[400, 1], offset: ?>> %alloca_4 = memref.alloca() {alignment = 64 : i64} : memref<40x100xf64, #quidditch_snitch.l1_encoding> - %16 = quidditch_snitch.start_dma_transfer from %subview_3 : memref<40x100xf64, strided<[400, 1], offset: ?>> to %alloca_4 : memref<40x100xf64, #quidditch_snitch.l1_encoding> - quidditch_snitch.pipeline_yield %alloca_4, %16 : memref<40x100xf64, #quidditch_snitch.l1_encoding>, !quidditch_snitch.dma_token + %16 = dma.start_transfer from %subview_3 : memref<40x100xf64, strided<[400, 1], offset: ?>> to %alloca_4 : memref<40x100xf64, #quidditch_snitch.l1_encoding> + quidditch_snitch.pipeline_yield %alloca_4, %16 : memref<40x100xf64, #quidditch_snitch.l1_encoding>, !dma.token }, { - ^bb0(%arg1: index, %arg2: memref<40x100xf64, #quidditch_snitch.l1_encoding>, %arg3: !quidditch_snitch.dma_token): + ^bb0(%arg1: index, %arg2: memref<40x100xf64, #quidditch_snitch.l1_encoding>, %arg3: !dma.token): // CHECK: %[[STAGE1_IV:.*]] = affine.apply #[[$MAP3]](%[[IV]]) // CHECK: memref.subview %{{.*}}[0, %[[STAGE1_IV]]] - // CHECK: wait_for_dma_transfers %[[YIELDED1]] + // CHECK: wait_for_transfers %[[YIELDED1]] // CHECK: linalg.matmul_transpose_b ins(%{{.*}}, %[[YIELDED0]] : {{.*}}) // CHECK: yield %[[NEXT_YIELDED]], %{{.*}} : %subview_3 = memref.subview %alloca[0, %arg1] [1, 40] [1, 1] : memref<1x1200xf64, #quidditch_snitch.l1_encoding> to memref<1x40xf64, strided<[1200, 1], offset: ?>, #quidditch_snitch.l1_encoding> - quidditch_snitch.wait_for_dma_transfers %arg3 : !quidditch_snitch.dma_token + dma.wait_for_transfers %arg3 : !dma.token linalg.matmul_transpose_b ins(%alloca2, %arg2 : memref<1x100xf64, #quidditch_snitch.l1_encoding>, memref<40x100xf64, #quidditch_snitch.l1_encoding>) outs(%out : memref<1x40xf64, #quidditch_snitch.l1_encoding>) @@ -66,7 +66,7 @@ func.func @test( // CHECK: %[[IV:.*]] = affine.apply #[[$MAP4]]() // CHECK: %[[STAGE1_IV:.*]] = affine.apply #[[$MAP5]]() // CHECK: memref.subview %{{.*}}[0, %[[STAGE1_IV]]] - // CHECK: wait_for_dma_transfers %[[LAST]]#1 + // CHECK: wait_for_transfers %[[LAST]]#1 // CHECK: linalg.matmul_transpose_b ins(%{{.*}}, %[[LAST]]#0 : {{.*}}) return } diff --git a/codegen/tests/Dialect/Snitch/Transforms/pipeline-copy-compute.mlir b/codegen/tests/Dialect/Snitch/Transforms/pipeline-copy-compute.mlir index b9c70aa..0190eaf 100644 --- a/codegen/tests/Dialect/Snitch/Transforms/pipeline-copy-compute.mlir +++ b/codegen/tests/Dialect/Snitch/Transforms/pipeline-copy-compute.mlir @@ -16,21 +16,21 @@ func.func @test(%arg0: index, %extracted_slice : tensor<1x100xf64>, %14 : tensor // CHECK: pipeline %[[C0]] to %[[C1200]] step %[[C40]] inits(%[[EMPTY]]) %24 = scf.for %arg2 = %c0 to %c1200 step %c40 iter_args(%arg3 = %arg1) -> (tensor<1x1200xf64>) { // CHECK: ^{{.*}}(%[[IV:.*]]: index, %[[ITER:[[:alnum:]]+]]: - // CHECK: %[[RESULT0:.*]], %[[TOKEN0:.*]] = quidditch_snitch.start_tensor_copy %[[ARG1]] + // CHECK: %[[RESULT0:.*]], %[[TOKEN0:.*]] = dma.start_tensor_copy of %[[ARG1]] // CHECK: %[[SLICE1:.*]] = tensor.extract_slice %[[ARG2]][%[[IV]], %[[ARG0]]] - // CHECK: %[[RESULT1:.*]], %[[TOKEN1:.*]] = quidditch_snitch.start_tensor_copy %[[SLICE1]] + // CHECK: %[[RESULT1:.*]], %[[TOKEN1:.*]] = dma.start_tensor_copy of %[[SLICE1]] // CHECK: %[[SLICE2:.*]] = tensor.extract_slice %[[ITER]][0, %[[IV]]] - // CHECK: %[[RESULT2:.*]], %[[TOKEN2:.*]] = quidditch_snitch.start_tensor_copy %[[SLICE2]] + // CHECK: %[[RESULT2:.*]], %[[TOKEN2:.*]] = dma.start_tensor_copy of %[[SLICE2]] // CHECK: pipeline_yield %[[ITER]], %[[RESULT0:.*]], %[[TOKEN0]], %[[SLICE1]], %[[RESULT1]], %[[TOKEN1]], %[[SLICE2]], %[[RESULT2]], %[[TOKEN2]] %extracted_slice_6 = tensor.extract_slice %14[%arg2, %arg0] [40, 100] [1, 1] : tensor<1200x400xf64> to tensor<40x100xf64> %extracted_slice_7 = tensor.extract_slice %arg3[0, %arg2] [1, 40] [1, 1] : tensor<1x1200xf64> to tensor<1x40xf64> - %result_8, %token_9 = quidditch_snitch.start_tensor_copy %extracted_slice to L1 : tensor<1x100xf64> -> tensor<1x100xf64> - %25 = quidditch_snitch.wait_for_tensor_copy of %extracted_slice : tensor<1x100xf64> to %result_8 using %token_9 -> tensor<1x100xf64> - %result_10, %token_11 = quidditch_snitch.start_tensor_copy %extracted_slice_6 to L1 : tensor<40x100xf64> -> tensor<40x100xf64> - %26 = quidditch_snitch.wait_for_tensor_copy of %extracted_slice_6 : tensor<40x100xf64> to %result_10 using %token_11 -> tensor<40x100xf64> - %result_12, %token_13 = quidditch_snitch.start_tensor_copy %extracted_slice_7 to L1 : tensor<1x40xf64> -> tensor<1x40xf64> - %27 = quidditch_snitch.wait_for_tensor_copy of %extracted_slice_7 : tensor<1x40xf64> to %result_12 using %token_13 -> tensor<1x40xf64> + %result_8, %token_9 = dma.start_tensor_copy of %extracted_slice to #quidditch_snitch.l1_encoding : tensor<1x100xf64> -> tensor<1x100xf64> + %25 = dma.wait_for_tensor_copy of %extracted_slice : tensor<1x100xf64> to %result_8 using %token_9 -> tensor<1x100xf64> + %result_10, %token_11 = dma.start_tensor_copy of %extracted_slice_6 to #quidditch_snitch.l1_encoding : tensor<40x100xf64> -> tensor<40x100xf64> + %26 = dma.wait_for_tensor_copy of %extracted_slice_6 : tensor<40x100xf64> to %result_10 using %token_11 -> tensor<40x100xf64> + %result_12, %token_13 = dma.start_tensor_copy of %extracted_slice_7 to #quidditch_snitch.l1_encoding : tensor<1x40xf64> -> tensor<1x40xf64> + %27 = dma.wait_for_tensor_copy of %extracted_slice_7 : tensor<1x40xf64> to %result_12 using %token_13 -> tensor<1x40xf64> // CHECK: ^{{.*}}( // CHECK-SAME: %[[IV:[[:alnum:]]+]] @@ -43,13 +43,13 @@ func.func @test(%arg0: index, %extracted_slice : tensor<1x100xf64>, %14 : tensor // CHECK-SAME: %[[SLICE2:[[:alnum:]]+]] // CHECK-SAME: %[[RESULT2:[[:alnum:]]+]] // CHECK-SAME: %[[TOKEN2:[[:alnum:]]+]] - // CHECK: %[[OPA:.*]] = quidditch_snitch.wait_for_tensor_copy of %[[ARG1]] + // CHECK: %[[OPA:.*]] = dma.wait_for_tensor_copy of %[[ARG1]] // CHECK-SAME: to %[[RESULT0]] // CHECK-SAME: using %[[TOKEN0]] - // CHECK: %[[OPB:.*]] = quidditch_snitch.wait_for_tensor_copy of %[[SLICE1]] + // CHECK: %[[OPB:.*]] = dma.wait_for_tensor_copy of %[[SLICE1]] // CHECK-SAME: to %[[RESULT1]] // CHECK-SAME: using %[[TOKEN1]] - // CHECK: %[[OPC:.*]] = quidditch_snitch.wait_for_tensor_copy of %[[SLICE2]] + // CHECK: %[[OPC:.*]] = dma.wait_for_tensor_copy of %[[SLICE2]] // CHECK-SAME: to %[[RESULT2]] // CHECK-SAME: using %[[TOKEN2]] // CHECK: %[[RES:.*]] = linalg.matmul_transpose_b diff --git a/codegen/tests/Dialect/Snitch/Transforms/promote-operands-to-l1.mlir b/codegen/tests/Dialect/Snitch/Transforms/promote-operands-to-l1.mlir index 1c4c6de..a71201f 100644 --- a/codegen/tests/Dialect/Snitch/Transforms/promote-operands-to-l1.mlir +++ b/codegen/tests/Dialect/Snitch/Transforms/promote-operands-to-l1.mlir @@ -6,16 +6,16 @@ func.func @test(%a : tensor<32x32xf32>, %b : tensor<32x32xf32>) -> tensor<32x32xf32> { // CHECK: %[[E:.*]] = bufferization.alloc_tensor %e = bufferization.alloc_tensor() : tensor<32x32xf32> - // CHECK: %[[A1:.*]], %[[TOKEN:.*]] = quidditch_snitch.start_tensor_copy %[[A]] to L1 - // CHECK: %[[A2:.*]] = quidditch_snitch.wait_for_tensor_copy of %[[A]] + // CHECK: %[[A1:.*]], %[[TOKEN:.*]] = dma.start_tensor_copy of %[[A]] to #quidditch_snitch.l1_encoding + // CHECK: %[[A2:.*]] = dma.wait_for_tensor_copy of %[[A]] // CHECK-SAME: to %[[A1]] // CEHCK-SAME: using %[[TOKEN]] - // CHECK: %[[B1:.*]], %[[TOKEN:.*]] = quidditch_snitch.start_tensor_copy %[[B]] to L1 - // CHECK: %[[B2:.*]] = quidditch_snitch.wait_for_tensor_copy of %[[B]] + // CHECK: %[[B1:.*]], %[[TOKEN:.*]] = dma.start_tensor_copy of %[[B]] to #quidditch_snitch.l1_encoding + // CHECK: %[[B2:.*]] = dma.wait_for_tensor_copy of %[[B]] // CHECK-SAME: to %[[B1]] // CHECK-SAME: using %[[TOKEN]] - // CHECK: %[[E1:.*]], %[[TOKEN:.*]] = quidditch_snitch.start_tensor_copy %[[E]] to L1 - // CHECK: %[[E2:.*]] = quidditch_snitch.wait_for_tensor_copy of %[[E]] + // CHECK: %[[E1:.*]], %[[TOKEN:.*]] = dma.start_tensor_copy of %[[E]] to #quidditch_snitch.l1_encoding + // CHECK: %[[E2:.*]] = dma.wait_for_tensor_copy of %[[E]] // CHECK-SAME: to %[[E1]] // CHECK-SAME: using %[[TOKEN]] // CHECK: linalg.matmul ins(%[[A2]], %[[B2]] : {{.*}}) outs(%[[E2]] : {{.*}}) diff --git a/codegen/tests/Dialect/Snitch/Transforms/promote-pads-to-l1.mlir b/codegen/tests/Dialect/Snitch/Transforms/promote-pads-to-l1.mlir index 4359b9e..2265a44 100644 --- a/codegen/tests/Dialect/Snitch/Transforms/promote-pads-to-l1.mlir +++ b/codegen/tests/Dialect/Snitch/Transforms/promote-pads-to-l1.mlir @@ -4,9 +4,9 @@ // CHECK-SAME: %[[A:[[:alnum:]]+]]: tensor<32x32xf32> func.func @test_zero_f32(%a : tensor<32x32xf32>) -> tensor<33x33xf32> { %c = arith.constant 0.0 : f32 - // CHECK: %[[R:.*]], %[[T:.*]] = quidditch_snitch.start_tensor_copy %[[A]] - // CHECK-SAME: pad with zero to [1, 1] - // CHECK: %[[R2:.*]] = quidditch_snitch.wait_for_tensor_copy of %[[A]] + // CHECK: %[[R:.*]], %[[T:.*]] = dma.start_tensor_copy of %[[A]] + // CHECK-SAME: pad with zero by [1, 1] + // CHECK: %[[R2:.*]] = dma.wait_for_tensor_copy of %[[A]] // CHECK-SAME: to %[[R]] // CHECK-SAME: using %[[T]] %0 = tensor.pad %a low[0, 0] high[1, 1] { @@ -21,9 +21,9 @@ func.func @test_zero_f32(%a : tensor<32x32xf32>) -> tensor<33x33xf32> { // CHECK-SAME: %[[A:[[:alnum:]]+]]: tensor<32x32xf32> func.func @test_poison(%a : tensor<32x32xf32>) -> tensor<33x33xf32> { %c = ub.poison : f32 - // CHECK: %[[R:.*]], %[[T:.*]] = quidditch_snitch.start_tensor_copy %[[A]] - // CHECK-SAME: pad with undef to [1, 1] - // CHECK: %[[R2:.*]] = quidditch_snitch.wait_for_tensor_copy of %[[A]] + // CHECK: %[[R:.*]], %[[T:.*]] = dma.start_tensor_copy of %[[A]] + // CHECK-SAME: pad with undef by [1, 1] + // CHECK: %[[R2:.*]] = dma.wait_for_tensor_copy of %[[A]] // CHECK-SAME: to %[[R]] // CHECK-SAME: using %[[T]] %0 = tensor.pad %a low[0, 0] high[1, 1] { diff --git a/codegen/tests/Dialect/Snitch/Transforms/specialize-dma-code.mlir b/codegen/tests/Dialect/Snitch/Transforms/specialize-dma-code.mlir index b579ec9..351b2d7 100644 --- a/codegen/tests/Dialect/Snitch/Transforms/specialize-dma-code.mlir +++ b/codegen/tests/Dialect/Snitch/Transforms/specialize-dma-code.mlir @@ -14,12 +14,12 @@ func.func @test(%a : memref<32xf32>, %b : memref<32xf32>, %cond : i1) { %a_l1 = memref.view %view[%c0][] : memref<512xi8> to memref<32xf32> %b_l1 = memref.view %view[%c256][] : memref<512xi8> to memref<32xf32> - // CHECK-NEXT: quidditch_snitch.completed_token - // CHECK-NEXT: quidditch_snitch.completed_token + // CHECK-NEXT: dma.completed_token + // CHECK-NEXT: dma.completed_token // CHECK-NEXT: quidditch_snitch.barrier - quidditch_snitch.start_dma_transfer from %a : memref<32xf32> to %a_l1 : memref<32xf32> - %t = quidditch_snitch.start_dma_transfer from %b : memref<32xf32> to %b_l1 : memref<32xf32> - quidditch_snitch.wait_for_dma_transfers %t : !quidditch_snitch.dma_token + dma.start_transfer from %a : memref<32xf32> to %a_l1 : memref<32xf32> + %t = dma.start_transfer from %b : memref<32xf32> to %b_l1 : memref<32xf32> + dma.wait_for_transfers %t : !dma.token // CHECK-NEXT: microkernel // CHECK: } @@ -31,31 +31,31 @@ func.func @test(%a : memref<32xf32>, %b : memref<32xf32>, %cond : i1) { // CHECK-NEXT: quidditch_snitch.microkernel_fence // CHECK-NEXT: quidditch_snitch.barrier - // CHECK-NEXT: quidditch_snitch.completed_token - %t2 = quidditch_snitch.start_dma_transfer from %b_l1 : memref<32xf32> to %b : memref<32xf32> + // CHECK-NEXT: dma.completed_token + %t2 = dma.start_transfer from %b_l1 : memref<32xf32> to %b : memref<32xf32> // CHECK-NEXT: quidditch_snitch.barrier - quidditch_snitch.wait_for_dma_transfers %t2 : !quidditch_snitch.dma_token + dma.wait_for_transfers %t2 : !dma.token // CHECK: scf.if - %r:2 = scf.if %cond -> (!quidditch_snitch.dma_token, index) { - // CHECK-NEXT: %[[C:.*]] = quidditch_snitch.completed_token - %t3 = quidditch_snitch.start_dma_transfer from %b_l1 : memref<32xf32> to %b : memref<32xf32> + %r:2 = scf.if %cond -> (!dma.token, index) { + // CHECK-NEXT: %[[C:.*]] = dma.completed_token + %t3 = dma.start_transfer from %b_l1 : memref<32xf32> to %b : memref<32xf32> // CHECK-NEXT: %[[I:.*]] = quidditch_snitch.compute_core_index %i = quidditch_snitch.compute_core_index // CHECK-NEXT: yield %[[C]], %[[I]] - scf.yield %t3, %i : !quidditch_snitch.dma_token, index + scf.yield %t3, %i : !dma.token, index } else { // CHECK-NEXT: else - // CHECK-NEXT: %[[C:.*]] = quidditch_snitch.completed_token - %c = quidditch_snitch.completed_token + // CHECK-NEXT: %[[C:.*]] = dma.completed_token + %c = dma.completed_token // CHECK-NEXT: %[[I:.*]] = arith.constant %i = arith.constant 1 : index // CHECK-NEXT: yield %[[C]], %[[I]] - scf.yield %c, %i : !quidditch_snitch.dma_token, index + scf.yield %c, %i : !dma.token, index } // CHECK: quidditch_snitch.barrier - quidditch_snitch.wait_for_dma_transfers %r#0 : !quidditch_snitch.dma_token + dma.wait_for_transfers %r#0 : !dma.token // CHECK-NEXT: return return } @@ -67,25 +67,25 @@ func.func @test(%a : memref<32xf32>, %b : memref<32xf32>, %cond : i1) { // CHECK: memref.view // CHECK-NEXT: memref.view -// CHECK-NEXT: quidditch_snitch.start_dma_transfer -// CHECK-NEXT: quidditch_snitch.start_dma_transfer -// CHECK-NEXT: quidditch_snitch.wait_for_dma_transfers +// CHECK-NEXT: dma.start_transfer +// CHECK-NEXT: dma.start_transfer +// CHECK-NEXT: dma.wait_for_transfers // CHECK-NEXT: quidditch_snitch.barrier // CHECK-NEXT: quidditch_snitch.barrier -// CHECK-NEXT: quidditch_snitch.start_dma_transfer -// CHECK-NEXT: quidditch_snitch.wait_for_dma_transfers +// CHECK-NEXT: dma.start_transfer +// CHECK-NEXT: dma.wait_for_transfers // CHECK-NEXT: quidditch_snitch.barrier // CHECK-NEXT: scf.if -// CHECK-NEXT: quidditch_snitch.start_dma_transfer +// CHECK-NEXT: dma.start_transfer // CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0 // CHECK-NEXT: yield %{{.*}}, %[[ZERO]] : // CHECK-NEXT: else // CHECK-NEXT: completed_token // CHECK-NEXT: arith.constant // CHECK-NEXT: yield -// CHECK: quidditch_snitch.wait_for_dma_transfers +// CHECK: dma.wait_for_transfers // CHECK-NEXT: quidditch_snitch.barrier // CHECK-NEXT: return diff --git a/codegen/tools/CMakeLists.txt b/codegen/tools/CMakeLists.txt index 275568b..e3da2fb 100644 --- a/codegen/tools/CMakeLists.txt +++ b/codegen/tools/CMakeLists.txt @@ -4,6 +4,7 @@ target_link_libraries(quidditch-opt MLIROptLib Quidditch::Conversion::ConvertSnitchToLLVM Quidditch::Conversion::ConvertToRISCV + Quidditch::Dialect::DMA::Extensions::DMACoreSpecializationOpInterfaceImpl Quidditch::Dialect::Snitch::IR::QuidditchSnitchDialect Quidditch::Dialect::Snitch::Transforms::Passes Quidditch::Target::Passes diff --git a/codegen/tools/quidditch-opt.cpp b/codegen/tools/quidditch-opt.cpp index a949287..68cd9e9 100644 --- a/codegen/tools/quidditch-opt.cpp +++ b/codegen/tools/quidditch-opt.cpp @@ -2,6 +2,8 @@ #include #include "Quidditch/Conversion/Passes.h" +#include "Quidditch/Dialect/DMA/Extensions/DMACoreSpecializationOpInterfaceImpl.h" +#include "Quidditch/Dialect/DMA/IR/DMADialect.h" #include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.h" #include "Quidditch/Dialect/Snitch/Transforms/Passes.h" #include "Quidditch/Target/Passes.h" @@ -26,8 +28,10 @@ int main(int argc, char **argv) { // Be lazy and support all upstream dialects as input dialects. DialectRegistry registry; + quidditch::dma::registerDMACoreSpecializationOpInterface(registry); iree_compiler::registerAllDialects(registry); - registry.insert(); + registry.insert(); quidditch::registerPasses(); quidditch::registerConversionPasses();