From 500eddc5ebf3988705797121ebaef8cb38c42062 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Markus=20B=C3=B6ck?=
Date: Tue, 27 Aug 2024 16:53:01 +0100
Subject: [PATCH] Better DMA transfers

---
 .../Conversion/ConvertSnitchToLLVM.cpp        | 165 ++++++++----------
 .../ConvertSnitchToLLVM/dma_transfer.mlir     |  81 ++++++---
 2 files changed, 130 insertions(+), 116 deletions(-)

diff --git a/codegen/compiler/src/Quidditch/Conversion/ConvertSnitchToLLVM.cpp b/codegen/compiler/src/Quidditch/Conversion/ConvertSnitchToLLVM.cpp
index 19ec1b3..e42af9b 100644
--- a/codegen/compiler/src/Quidditch/Conversion/ConvertSnitchToLLVM.cpp
+++ b/codegen/compiler/src/Quidditch/Conversion/ConvertSnitchToLLVM.cpp
@@ -178,70 +178,26 @@ struct StartDMATransferOp2DLowering
     MemRefType sourceMemRef = op.getSource().getType();
     MemRefType destMemRef = op.getDest().getType();
 
-    StridedLayoutAttr sourceStridesAttr =
-        dyn_cast_or_null<StridedLayoutAttr>(sourceMemRef.getLayout());
-    if (!sourceStridesAttr) {
-      if (sourceMemRef.getLayout() && !sourceMemRef.getLayout().isIdentity())
-        return failure();
-
-      sourceStridesAttr = identityStride(sourceMemRef);
-    }
-
-    StridedLayoutAttr destStridesAttr =
-        dyn_cast_or_null<StridedLayoutAttr>(destMemRef.getLayout());
-    if (!destStridesAttr) {
-      if (destMemRef.getLayout() && !destMemRef.getLayout().isIdentity())
-        return failure();
-
-      destStridesAttr = identityStride(destMemRef);
-    }
-
     // Compute the size of the contiguous inner loop common to both MemRefs and
     // "shave" it off the ends of the shapes and strides. The remaining shapes
     // and strides are considered our outer dimensions.
-    int64_t innerSize = 1;
-    ArrayRef<int64_t> shape = sourceMemRef.getShape();
-    ArrayRef<int64_t> sourceStrides = sourceStridesAttr.getStrides();
-    ArrayRef<int64_t> destStrides = destStridesAttr.getStrides();
-    assert(shape.size() == sourceStrides.size() &&
-           sourceStrides.size() == destStrides.size());
-    for (; shape.size() > 1; shape = shape.drop_back(),
-                             sourceStrides = sourceStrides.drop_back(),
-                             destStrides = destStrides.drop_back()) {
-      int64_t dim = shape.back();
-      if (dim == 1)
-        continue;
-
-      int64_t sourceStride = sourceStrides.back();
-      int64_t destStride = destStrides.back();
-      if (sourceStride != destStride)
-        break;
-
-      if (innerSize != sourceStride)
-        break;
-
-      if (ShapedType::isDynamic(dim))
-        break;
-
-      innerSize *= dim;
-    }
-
-    MemRefDescriptor sourceDescriptor(adaptor.getSource());
-    MemRefDescriptor destDescriptor(adaptor.getDest());
-
-    Value source = sourceDescriptor.bufferPtr(
-        rewriter, op->getLoc(), *getTypeConverter(), op.getSource().getType());
-    Value dest = destDescriptor.bufferPtr(
-        rewriter, op->getLoc(), *getTypeConverter(), op.getDest().getType());
+    FailureOr<size_t> sourceNonContiguous =
+        getNumNonContiguousOuterDims(sourceMemRef);
+    FailureOr<size_t> destNonContiguous =
+        getNumNonContiguousOuterDims(destMemRef);
+    if (failed(sourceNonContiguous) || failed(destNonContiguous))
+      return failure();
+    size_t sharedNonContiguous =
+        std::max(*sourceNonContiguous, *destNonContiguous);
+    if (sharedNonContiguous == 0)
+      return failure();
 
     Value elementSize = rewriter.create<LLVM::ConstantOp>(
         op->getLoc(), rewriter.getI32IntegerAttr(llvm::divideCeil(
                           op.getSource().getType().getElementTypeBitWidth(), 8)));
-    Value contiguousSize = rewriter.create<LLVM::ConstantOp>(
-        op->getLoc(), rewriter.getI32IntegerAttr(innerSize));
-    contiguousSize =
-        rewriter.create<LLVM::MulOp>(op->getLoc(), contiguousSize, elementSize);
+    SmallVector<OpFoldResult> sizes =
+        memref::getMixedSizes(rewriter, op->getLoc(), op.getSource());
 
     // Build a loop nest iterating over all outer dimensions - 1, adjusting the
     // source and destination pointers accordingly. The inner-most outer
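
The rewrite hinges on getNumNonContiguousOuterDims, whose definition sits outside this hunk: counting from the innermost dimension outwards, a dimension is folded into the contiguous inner run while its static stride equals the running product of the static inner sizes, and every dimension left over counts as non-contiguous. A sketch of the idea on the shapes the test6 change below exercises (the destination layout is read off that test's checks, so treat it as an assumption):

    // Source: memref<3x2x4xf32>, identity layout -> 0 non-contiguous outer dims.
    // Dest:   strided<[16, 8, 1]>  -> dim 2 folds (stride 1, run of 4 elements),
    //         but stride 8 of dim 1 != 4, so dims 0 and 1 stay non-contiguous (2).
    // sharedNonContiguous = max(0, 2) = 2: one scf loop over dim 0, the 2D DMA's
    // repetition count covers dim 1, and the inner transfer moves 4 elements.
    %token = quidditch_snitch.start_dma_transfer from %src : memref<3x2x4xf32>
        to %dst : memref<3x2x4xf32, strided<[16, 8, 1], offset: 2>>
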
@@ -251,59 +207,83 @@ struct StartDMATransferOp2DLowering
     SmallVector<Value> steps;
     Value zeroIndex = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
     Value oneIndex = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 1);
-    for (size_t index : llvm::seq(shape.size() - 1)) {
+    for (size_t index : llvm::seq(sharedNonContiguous - 1)) {
       lowerBounds.push_back(zeroIndex);
       steps.push_back(oneIndex);
-      Value dim = typeConverter->materializeSourceConversion(
-          rewriter, op->getLoc(), rewriter.getIndexType(),
-          sourceDescriptor.size(rewriter, op->getLoc(), index));
-      upperBounds.push_back(dim);
+      upperBounds.push_back(getValueOrCreateConstantIndexOp(
+          rewriter, op->getLoc(), sizes[index]));
     }
 
-    Type tokenType = typeConverter->convertType(op.getType());
+    Value contiguousSize;
+    for (auto index :
+         llvm::seq<size_t>(sharedNonContiguous, sourceMemRef.getRank())) {
+      Value dim =
+          getValueOrCreateConstantIndexOp(rewriter, op->getLoc(), sizes[index]);
+      if (!contiguousSize) {
+        contiguousSize = dim;
+        continue;
+      }
+      contiguousSize =
+          rewriter.create<arith::MulIOp>(op->getLoc(), contiguousSize, dim);
+    }
+    contiguousSize = typeConverter->materializeTargetConversion(
+        rewriter, op->getLoc(), getIndexType(), contiguousSize);
+    contiguousSize =
+        rewriter.create<LLVM::MulOp>(op->getLoc(), contiguousSize, elementSize);
+
     Value completedToken = rewriter.create<CompletedTokenOp>(op->getLoc());
-    completedToken = typeConverter->materializeTargetConversion(
-        rewriter, op->getLoc(), tokenType, completedToken);
 
     scf::LoopNest loopNest = scf::buildLoopNest(
         rewriter, op->getLoc(), lowerBounds, upperBounds, steps, completedToken,
         [&](OpBuilder &builder, Location loc, ValueRange ivs,
            ValueRange iterArgs) -> scf::ValueVector {
-          auto linearizeOffset = [&](MemRefDescriptor descriptor) {
-            Value offset =
-                rewriter.create<LLVM::ZeroOp>(loc, rewriter.getI32Type());
-            for (auto [index, iv] : llvm::enumerate(ivs)) {
-              Value increment = rewriter.create<LLVM::MulOp>(
-                  loc,
-                  typeConverter->materializeTargetConversion(
-                      builder, op->getLoc(),
-                      typeConverter->convertType(iv.getType()), iv),
-                  descriptor.stride(builder, loc, index));
-              offset = rewriter.create<LLVM::AddOp>(loc, offset, increment);
-            }
-            return offset;
-          };
-
-          Value sourceAdjusted = rewriter.create<LLVM::GEPOp>(
-              loc, source.getType(),
-              typeConverter->convertType(sourceMemRef.getElementType()), source,
-              linearizeOffset(sourceDescriptor));
-          Value destAdjusted = rewriter.create<LLVM::GEPOp>(
-              loc, dest.getType(),
-              typeConverter->convertType(destMemRef.getElementType()), dest,
-              linearizeOffset(destDescriptor));
+          SmallVector<OpFoldResult> offsets = getAsOpFoldResult(ivs);
+          SmallVector<OpFoldResult> subSizes(sharedNonContiguous - 1,
+                                             rewriter.getIndexAttr(1));
+          for (unsigned i : llvm::seq<unsigned>(sharedNonContiguous - 1,
+                                                sourceMemRef.getRank())) {
+            offsets.push_back(rewriter.getIndexAttr(0));
+            subSizes.push_back(sizes[i]);
+          }
+          SmallVector<OpFoldResult> strides(sourceMemRef.getRank(),
+                                            rewriter.getIndexAttr(1));
+
+          TypedValue<MemRefType> sourceMemRefSlice =
+              rewriter.create<memref::SubViewOp>(loc, op.getSource(), offsets,
+                                                 subSizes, strides);
+          TypedValue<MemRefType> destMemRefSlice =
+              rewriter.create<memref::SubViewOp>(loc, op.getDest(), offsets,
+                                                 subSizes, strides);
+
+          auto sourceDescriptor =
+              MemRefDescriptor(typeConverter->materializeTargetConversion(
+                  rewriter, op->getLoc(),
+                  typeConverter->convertType(sourceMemRefSlice.getType()),
+                  sourceMemRefSlice));
+          auto destDescriptor =
+              MemRefDescriptor(typeConverter->materializeTargetConversion(
+                  rewriter, op->getLoc(),
+                  typeConverter->convertType(destMemRefSlice.getType()),
+                  destMemRefSlice));
+
+          Value sourceAdjusted = sourceDescriptor.bufferPtr(
+              rewriter, op->getLoc(), *getTypeConverter(),
+              sourceMemRefSlice.getType());
+          Value destAdjusted = destDescriptor.bufferPtr(
+              rewriter, op->getLoc(), *getTypeConverter(),
+              destMemRefSlice.getType());
 
           Value sourceStride =
-              sourceDescriptor.stride(builder, loc, sourceStrides.size() - 1);
+              sourceDescriptor.stride(builder, loc, sharedNonContiguous - 1);
           sourceStride = rewriter.create<LLVM::MulOp>(
               op->getLoc(), sourceStride, elementSize);
           Value destStride =
-              destDescriptor.stride(builder, loc, destStrides.size() - 1);
+              destDescriptor.stride(builder, loc, sharedNonContiguous - 1);
           destStride = rewriter.create<LLVM::MulOp>(op->getLoc(), destStride,
                                                     elementSize);
           Value outerLoopSize =
-              sourceDescriptor.size(builder, loc, shape.size() - 1);
+              sourceDescriptor.size(builder, loc, sharedNonContiguous - 1);
 
           return {builder
                       .create<LLVM::CallOp>(loc, dmaStart2DFunc,
                                             ValueRange{
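
Instead of hand-linearizing an LLVM-level offset inside the loop, each iteration now takes a unit slice along the looped dimensions with memref.subview and lets the standard descriptor lowering recover the adjusted base pointer and strides. For the 3x2x4 example above, one iteration's source slice looks like this (SSA names are illustrative):

    // offsets = [%iv, 0, 0], sizes = [1, 2, 4], strides = [1, 1, 1]: unit
    // extent along the looped dim 0, full extent on the two inner dims.
    %src_slice = memref.subview %src[%iv, 0, 0] [1, 2, 4] [1, 1, 1]
        : memref<3x2x4xf32> to memref<1x2x4xf32, strided<[8, 4, 1], offset: ?>>
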
@@ -317,7 +297,10 @@ struct StartDMATransferOp2DLowering
                       .getResult()};
         });
 
-    rewriter.replaceOp(op, loopNest.results.front());
+    Type tokenType = typeConverter->convertType(op.getType());
+    rewriter.replaceOp(
+        op, typeConverter->materializeTargetConversion(
+                rewriter, op->getLoc(), tokenType, loopNest.results.front()));
     return success();
   }
 };
diff --git a/codegen/tests/Conversion/ConvertSnitchToLLVM/dma_transfer.mlir b/codegen/tests/Conversion/ConvertSnitchToLLVM/dma_transfer.mlir
index ba481fc..3001bf8 100644
--- a/codegen/tests/Conversion/ConvertSnitchToLLVM/dma_transfer.mlir
+++ b/codegen/tests/Conversion/ConvertSnitchToLLVM/dma_transfer.mlir
@@ -56,9 +56,9 @@ func.func private @test4(%arg0 : memref<1x4xf32>, %arg1 : memref<1x4xf32, stride
 // CHECK-SAME: %{{[[:alnum:]]+}}
 // CHECK-SAME: %[[ARG0_PTR:[[:alnum:]]+]]
 // CHECK-SAME: %{{[[:alnum:]]+}}
-// CHECK-SAME: %[[ARG0_SIZE:[[:alnum:]]+]]
 // CHECK-SAME: %{{[[:alnum:]]+}}
-// CHECK-SAME: %[[ARG0_STRIDE_N:[[:alnum:]]+]]
+// CHECK-SAME: %{{[[:alnum:]]+}}
+// CHECK-SAME: %{{[[:alnum:]]+}}
 // CHECK-SAME: %{{[[:alnum:]]+}}
 // CHECK-SAME: %{{[[:alnum:]]+}}
 // CHECK-SAME: %[[ARG1_PTR:[[:alnum:]]+]]
@@ -67,11 +67,13 @@ func.func private @test5(%arg0 : memref<2x4xf32>, %arg1 : memref<2x4xf32, stride
 // CHECK-SAME: %{{[[:alnum:]]+}}
 // CHECK-SAME: %[[ARG1_STRIDE_N:[[:alnum:]]+]]
 func.func private @test5(%arg0 : memref<2x4xf32>, %arg1 : memref<2x4xf32, strided<[8, 1], offset: 0>>) -> !quidditch_snitch.dma_token {
-  // CHECK: %[[ELEMENT_WIDTH:.*]] = llvm.mlir.constant(4 : i32)
-  // CHECK: %[[INNER_SIZE:.*]] = llvm.mul %[[ELEMENT_WIDTH]], %[[ELEMENT_WIDTH]]
-  // CHECK: %[[ARG0_STRIDE:.*]] = llvm.mul %[[ARG0_STRIDE_N]], %[[ELEMENT_WIDTH]]
+  // CHECK-DAG: %[[ELEMENT_WIDTH:.*]] = llvm.mlir.constant(4 : i32)
+  // CHECK-DAG: %[[FOUR_INDEX:.*]] = llvm.mlir.constant(4 : index)
+  // CHECK-DAG: %[[TWO:.*]] = llvm.mlir.constant(2 : index)
+  // CHECK: %[[INNER_SIZE:.*]] = llvm.mul %[[FOUR_INDEX]], %[[ELEMENT_WIDTH]]
+  // CHECK: %[[ARG0_STRIDE:.*]] = llvm.mul %[[FOUR_INDEX]], %[[ELEMENT_WIDTH]]
   // CHECK: %[[ARG1_STRIDE:.*]] = llvm.mul %[[ARG1_STRIDE_N]], %[[ELEMENT_WIDTH]]
-  // CHECK: llvm.call @snrt_dma_start_2d(%[[ARG1_PTR]], %[[ARG0_PTR]], %[[INNER_SIZE]], %[[ARG1_STRIDE]], %[[ARG0_STRIDE]], %[[ARG0_SIZE]])
+  // CHECK: llvm.call @snrt_dma_start_2d(%[[ARG1_PTR]], %[[ARG0_PTR]], %[[INNER_SIZE]], %[[ARG1_STRIDE]], %[[ARG0_STRIDE]], %[[TWO]])
   %0 = quidditch_snitch.start_dma_transfer from %arg0 : memref<2x4xf32> to %arg1 : memref<2x4xf32, strided<[8, 1], offset: 0>>
   return %0 : !quidditch_snitch.dma_token
 }
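
For fully static cases like test5, every descriptor field folds to a constant: the contiguous inner run is one row of 4 f32 elements (16 bytes), repeated 2 times, with a 16-byte source row stride and a 32-byte destination row stride (8 elements). Assuming the Snitch runtime's snrt_dma_start_2d(dst, src, size, dst_stride, src_stride, repeat) argument order and a 32-bit size_t, the call the checks match boils down to (constant and pointer names hypothetical):

    // Sizes and strides are in bytes; repetitions count rows; the returned
    // i32 is the transfer id backing the dma_token.
    %txid = llvm.call @snrt_dma_start_2d(%dst, %src, %c16, %c32, %c16, %c2)
        : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
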
@@ -80,14 +82,14 @@ func.func private @test5(%arg0 : memref<2x4xf32>, %arg1 : memref<2x4xf32, stride
 // CHECK-SAME: %{{[[:alnum:]]+}}
 // CHECK-SAME: %[[ARG0_PTR:[[:alnum:]]+]]
 // CHECK-SAME: %{{[[:alnum:]]+}}
-// CHECK-SAME: %[[ARG0_SIZE:[[:alnum:]]+]]
-// CHECK-SAME: %[[DIM1:[[:alnum:]]+]]
+// CHECK-SAME: %{{[[:alnum:]]+}}
+// CHECK-SAME: %{{[[:alnum:]]+}}
 // CHECK-SAME: %{{[[:alnum:]]+}}
 // CHECK-SAME: %[[ARG0_STRIDE0:[[:alnum:]]+]]
 // CHECK-SAME: %[[ARG0_STRIDE_N:[[:alnum:]]+]]
 // CHECK-SAME: %{{[[:alnum:]]+}}
 // CHECK-SAME: %{{[[:alnum:]]+}}
-// CHECK-SAME: %[[ARG1_ALIGNED_PTR:[[:alnum:]]+]]
+// CHECK-SAME: %[[ARG1_PTR:[[:alnum:]]+]]
 // CHECK-SAME: %{{[[:alnum:]]+}}
 // CHECK-SAME: %{{[[:alnum:]]+}}
 // CHECK-SAME: %{{[[:alnum:]]+}}
@@ -98,27 +100,30 @@ func.func private @test6(%arg0 : memref<3x2x4xf32>, %arg1 : memref<3x2x4xf32, st
   // CHECK-DAG: %[[ELEMENT_WIDTH:.*]] = llvm.mlir.constant(4 : i32)
   // CHECK-DAG: %[[ZERO32:.*]] = llvm.mlir.constant(0 : i32) : i32
   // CHECK-DAG: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : i32
+  // CHECK-DAG: %[[EIGHT:.*]] = llvm.mlir.constant(8 : {{.*}}) : i32
+  // CHECK-DAG: %[[SIXTEEN:.*]] = llvm.mlir.constant(16 : {{.*}}) : i32
   // CHECK-DAG: %[[ONE:.*]] = llvm.mlir.constant(1 : {{.*}}) : i32
-  // CHECK: %[[ARG1_PTR:.*]] = llvm.getelementptr %[[ARG1_ALIGNED_PTR]][2]
-  // CHECK: %[[INNER_SIZE:.*]] = llvm.mul %[[ELEMENT_WIDTH]], %[[ELEMENT_WIDTH]]
+  // CHECK-DAG: %[[TWO:.*]] = llvm.mlir.constant(2 : {{.*}}) : i32
+  // CHECK-DAG: %[[THREE:.*]] = llvm.mlir.constant(3 : {{.*}}) : i32
+  // CHECK-DAG: %[[FOUR_INDEX:.*]] = llvm.mlir.constant(4 : index)
+
+  // CHECK: %[[INNER_SIZE:.*]] = llvm.mul %[[FOUR_INDEX]], %[[ELEMENT_WIDTH]]
   // CHECK: llvm.br ^[[BB1:.*]](%[[ZERO]], %[[ZERO32]]
   // CHECK: ^[[BB1]](%[[IV1:.*]]: i32, %[[IV2:.*]]: i32):
-  // CHECK: %[[COND:.*]] = llvm.icmp "slt" %[[IV1]], %[[ARG0_SIZE]]
+  // CHECK: %[[COND:.*]] = llvm.icmp "slt" %[[IV1]], %[[THREE]]
   // CHECK: llvm.cond_br %[[COND]], ^[[BODY:.*]], ^[[EXIT:[[:alnum:]]+]]
   // CHECK: ^[[BODY]]:
-  // CHECK: %[[MUL:.*]] = llvm.mul %[[IV1]], %[[ARG0_STRIDE0]]
-  // CHECK: %[[ARG0_OFFSET1:.*]] = llvm.add %[[MUL]], %[[ZERO32]]
-  // CHECK: %[[ARG0_ADJUSTED:.*]] = llvm.getelementptr %[[ARG0_PTR]][%[[ARG0_OFFSET1]]]
+  // CHECK: %[[MUL1:.*]] = llvm.mul %[[IV1]], %[[EIGHT]]
+  // CHECK: %[[MUL2:.*]] = llvm.mul %[[IV1]], %[[SIXTEEN]]
+  // CHECK: %[[ARG0_OFFSET1:.*]] = llvm.add %[[MUL2]], %[[TWO]]
+  // CHECK: %[[ARG0_ADJUSTED:.*]] = llvm.getelementptr %[[ARG0_PTR]][%[[MUL1]]]
+  // CHECK: %[[ARG1_ADJUSTED:.*]] = llvm.getelementptr %[[ARG1_PTR]][%[[ARG0_OFFSET1]]]
 
-  // CHECK: %[[MUL:.*]] = llvm.mul %[[IV1]], %[[ARG1_STRIDE0]]
-  // CHECK: %[[ARG1_OFFSET1:.*]] = llvm.add %[[MUL]], %[[ZERO32]]
-  // CHECK: %[[ARG1_ADJUSTED:.*]] = llvm.getelementptr %[[ARG1_PTR]][%[[ARG1_OFFSET1]]]
-
-  // CHECK: %[[ARG0_STRIDE:.*]] = llvm.mul %[[ARG0_STRIDE_N]], %[[ELEMENT_WIDTH]]
-  // CHECK: %[[ARG1_STRIDE:.*]] = llvm.mul %[[ARG1_STRIDE_N]], %[[ELEMENT_WIDTH]]
-  // CHECK: %[[RES:.*]] = llvm.call @snrt_dma_start_2d(%[[ARG1_ADJUSTED]], %[[ARG0_ADJUSTED]], %[[INNER_SIZE]], %[[ARG1_STRIDE]], %[[ARG0_STRIDE]], %[[DIM1]])
+  // CHECK: %[[ARG0_STRIDE:.*]] = llvm.mul %[[FOUR_INDEX]], %[[ELEMENT_WIDTH]]
+  // CHECK: %[[ARG1_STRIDE:.*]] = llvm.mul %[[EIGHT]], %[[ELEMENT_WIDTH]]
+  // CHECK: %[[RES:.*]] = llvm.call @snrt_dma_start_2d(%[[ARG1_ADJUSTED]], %[[ARG0_ADJUSTED]], %[[INNER_SIZE]], %[[ARG1_STRIDE]], %[[ARG0_STRIDE]], %[[TWO]])
   // CHECK: %[[INV:.*]] = llvm.add %[[IV1]], %[[ONE]]
   // CHECK: llvm.br ^[[BB1]](%[[INV]], %[[RES]]
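
The token plumbing also changed: the completed token now seeds the loop's iter_arg chain in its dialect form, and only the loop nest's final result is materialized to the converted token type when the op is replaced. In scf form, before the conversion to the llvm.br blocks these checks match, the structure is roughly as follows (the printed spellings of CompletedTokenOp and the token type are assumptions):

    %init = quidditch_snitch.completed_token
    %token = scf.for %iv = %c0 to %c3 step %c1
        iter_args(%tok = %init) -> (!quidditch_snitch.dma_token) {
      // ... take the subviews and start the 2D transfer for this slice,
      // then yield that transfer's token in place of %tok ...
      scf.yield %tok : !quidditch_snitch.dma_token
    }
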
@@ -142,12 +147,38 @@
 // CHECK-SAME: %{{[[:alnum:]]+}}
 // CHECK-SAME: %{{[[:alnum:]]+}}
 // CHECK-SAME: %[[ARG1_STRIDE_N:[[:alnum:]]+]]
-func.func private @dynamic_strides(%arg0 : memref<2x4xf32>, %arg1 : memref<2x4xf32, strided<[?, 1], offset: 0>>) -> !quidditch_snitch.dma_token {
+func.func private @dynamic_strides(%arg0 : memref<2x4xf32>, %arg1 : memref<2x4xf32, strided<[?, 1]>>) -> !quidditch_snitch.dma_token {
+  // CHECK-DAG: %[[ELEMENT_WIDTH:.*]] = llvm.mlir.constant(4 : i32)
+  // CHECK-DAG: %[[FOUR:.*]] = llvm.mlir.constant(4 : index)
+  // CHECK-DAG: %[[TWO:.*]] = llvm.mlir.constant(2 : index)
+  // CHECK: %[[INNER_SIZE:.*]] = llvm.mul %[[FOUR]], %[[ELEMENT_WIDTH]]
+  // CHECK: %[[ARG0_STRIDE:.*]] = llvm.mul %[[FOUR]], %[[ELEMENT_WIDTH]]
+  // CHECK: %[[ARG1_STRIDE:.*]] = llvm.mul %[[ARG1_STRIDE_N]], %[[ELEMENT_WIDTH]]
+  // CHECK: llvm.call @snrt_dma_start_2d(%[[ARG1_PTR]], %[[ARG0_PTR]], %[[INNER_SIZE]], %[[ARG1_STRIDE]], %[[ARG0_STRIDE]], %[[TWO]])
+  %0 = quidditch_snitch.start_dma_transfer from %arg0 : memref<2x4xf32> to %arg1 : memref<2x4xf32, strided<[?, 1]>>
+  return %0 : !quidditch_snitch.dma_token
+}
+
+// CHECK-LABEL: @contiguous_dynamic_inner
+// CHECK-SAME: %{{[[:alnum:]]+}}
+// CHECK-SAME: %[[ARG0_PTR:[[:alnum:]]+]]
+// CHECK-SAME: %{{[[:alnum:]]+}}
+// CHECK-SAME: %[[ARG0_SIZE:[[:alnum:]]+]]
+// CHECK-SAME: %[[ARG0_STRIDE_0:[[:alnum:]]+]]
+// CHECK-SAME: %[[ARG0_STRIDE_N:[[:alnum:]]+]]
+// CHECK-SAME: %{{[[:alnum:]]+}}
+// CHECK-SAME: %{{[[:alnum:]]+}}
+// CHECK-SAME: %[[ARG1_PTR:[[:alnum:]]+]]
+// CHECK-SAME: %{{[[:alnum:]]+}}
+// CHECK-SAME: %{{[[:alnum:]]+}}
+// CHECK-SAME: %{{[[:alnum:]]+}}
+// CHECK-SAME: %[[ARG1_STRIDE_N:[[:alnum:]]+]]
+func.func private @contiguous_dynamic_inner(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32, strided<[?, 1]>>) -> !quidditch_snitch.dma_token {
   // CHECK: %[[ELEMENT_WIDTH:.*]] = llvm.mlir.constant(4 : i32)
-  // CHECK: %[[INNER_SIZE:.*]] = llvm.mul %[[ELEMENT_WIDTH]], %[[ELEMENT_WIDTH]]
+  // CHECK: %[[INNER_SIZE:.*]] = llvm.mul %[[ARG0_STRIDE_0]], %[[ELEMENT_WIDTH]]
   // CHECK: %[[ARG0_STRIDE:.*]] = llvm.mul %[[ARG0_STRIDE_N]], %[[ELEMENT_WIDTH]]
   // CHECK: %[[ARG1_STRIDE:.*]] = llvm.mul %[[ARG1_STRIDE_N]], %[[ELEMENT_WIDTH]]
   // CHECK: llvm.call @snrt_dma_start_2d(%[[ARG1_PTR]], %[[ARG0_PTR]], %[[INNER_SIZE]], %[[ARG1_STRIDE]], %[[ARG0_STRIDE]], %[[ARG0_SIZE]])
-  %0 = quidditch_snitch.start_dma_transfer from %arg0 : memref<2x4xf32> to %arg1 : memref<2x4xf32, strided<[?, 1], offset: 0>>
+  %0 = quidditch_snitch.start_dma_transfer from %arg0 : memref<?x?xf32> to %arg1 : memref<?x?xf32, strided<[?, 1]>>
   return %0 : !quidditch_snitch.dma_token
 }
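
One case intentionally keeps failing this pattern: when both sides are fully contiguous, getNumNonContiguousOuterDims returns 0 for source and destination, sharedNonContiguous is 0, and the lowering bails with failure(). A transfer like the following is therefore presumably picked up by the plain 1D DMA lowering instead of the 2D one patched here:

    // Identity layouts on both sides: no outer loop and no 2D repetition are
    // needed, so the 2D pattern declines to match.
    %t = quidditch_snitch.start_dma_transfer from %a : memref<2x4xf32>
        to %b : memref<2x4xf32>
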