Skip to content

Commit

Permalink
[SnitchToLLVM] Properly detect contiguous MemRefs (#56)
Browse files Browse the repository at this point in the history
Contiguous MemRefs can be copied using just `snrt_dma_start_1d`, which
simplifies some of the required setup for the DMA. This PR properly
implements the detection of such contiguous MemRefs by first checking
for an identity layout and, otherwise, checking whether the strided layout
describes an identity layout in all but the offset.
  • Loading branch information
zero9178 authored Jun 27, 2024
1 parent 60afb0c commit dad35a1
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 15 deletions.
53 changes: 38 additions & 15 deletions codegen/compiler/src/Quidditch/Conversion/ConvertSnitchToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,34 @@ struct L1MemoryViewOpLowering : ConvertOpToLLVMPattern<L1MemoryViewOp> {
}
};

/// Returns true if this MemRef type is known to have a fully contiguous layout.
///
/// A MemRef is treated as contiguous if it either has no layout / an identity
/// layout, or has a strided layout whose static strides equal the strides an
/// identity layout would have for the same shape. The offset of the strided
/// layout is not inspected. The outermost dimension may be dynamic since its
/// extent does not contribute to any stride; any other dynamic dimension makes
/// the function conservatively return false.
bool isContiguous(MemRefType memRefType) {
  MemRefLayoutAttrInterface layout = memRefType.getLayout();
  if (!layout || layout.isIdentity())
    return true;

  // It is impossible to statically determine contiguity with dynamic strides.
  auto strided = dyn_cast<StridedLayoutAttr>(layout);
  if (!strided || llvm::any_of(strided.getStrides(), ShapedType::isDynamic))
    return false;

  ArrayRef<int64_t> shape = memRefType.getShape();
  ArrayRef<int64_t> strides = strided.getStrides();
  // A rank-0 MemRef holds a single element and is trivially contiguous.
  // Returning early also avoids calling `drop_front`/`front` on empty arrays.
  if (strides.empty())
    return true;

  // Calculate what the strides would be if it had an identity layout and check
  // that they match. Both the shape and the strides are walked from the
  // innermost to the outermost dimension so that the stride of dimension `i`
  // is compared against the product of all dimension sizes after `i`.
  int64_t currentIdentityStride = 1;
  for (auto [dim, stride] :
       llvm::zip_equal(llvm::reverse(shape.drop_front()),
                       llvm::reverse(strides.drop_front()))) {
    if (currentIdentityStride != stride)
      return false;

    // The stride of the next outer dimension depends on this dimension's
    // size, which is unknown at compile time if dynamic.
    if (ShapedType::isDynamic(dim))
      return false;
    currentIdentityStride *= dim;
  }
  // The outermost stride is the product of all inner dimension sizes.
  return currentIdentityStride == strides.front();
}

struct StartDMATransferOp1DLowering
: ConvertOpToLLVMPattern<StartDMATransferOp> {

Expand All @@ -70,21 +98,11 @@ struct StartDMATransferOp1DLowering
StartDMATransferOp1DLowering(LLVM::LLVMFuncOp dmaStart1DFunc,
const LLVMTypeConverter &converter)
: ConvertOpToLLVMPattern(converter, /*benefit=*/2),
dmaStart1DFunc(dmaStart1DFunc) {
setHasBoundedRewriteRecursion();
}
dmaStart1DFunc(dmaStart1DFunc) {}

LogicalResult match(StartDMATransferOp op) const override {
MemRefLayoutAttrInterface sourceLayout =
op.getSource().getType().getLayout();
MemRefLayoutAttrInterface destLayout = op.getDest().getType().getLayout();
if (sourceLayout && !sourceLayout.isIdentity())
return failure();

if (destLayout && !destLayout.isIdentity())
return failure();

return success();
return success(isContiguous(op.getSource().getType()) &&
isContiguous(op.getDest().getType()));
}

void rewrite(StartDMATransferOp op, OpAdaptor adaptor,
Expand All @@ -107,8 +125,13 @@ struct StartDMATransferOp1DLowering
SmallVector<Value> sizes;
SmallVector<Value> strides;
Value totalSize;
getMemRefDescriptorSizes(op->getLoc(), op.getSource().getType(),
dynamicSizes, rewriter, sizes, strides, totalSize);
getMemRefDescriptorSizes(
op->getLoc(),
// Offsets are not considered an identity layout.
// Get rid of the layout entirely for the size calculation.
MemRefType::get(sourceMemRef.getShape(), sourceMemRef.getElementType(),
nullptr, sourceMemRef.getMemorySpace()),
dynamicSizes, rewriter, sizes, strides, totalSize);

rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, dmaStart1DFunc,
ValueRange{
Expand Down
17 changes: 17 additions & 0 deletions codegen/tests/Conversion/ConvertSnitchToLLVM/dma_transfer_1d.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,20 @@ func.func @test(%arg0 : memref<?xf32>, %arg1 : memref<?xf32>) -> !quidditch_snit
// CHECK: return %[[C]]
return %0 : !quidditch_snitch.dma_token
}

// 1D transfers between an identity-layout memref and a unit-stride memref with
// a dynamic offset are expected to lower to the 1D DMA runtime call, in both
// directions.
// CHECK-LABEL: @test2
func.func @test2(%arg0 : memref<?xf32>, %arg1 : memref<?xf32, strided<[1], offset: ?>>) -> !quidditch_snitch.dma_token {
// Identity-layout source, dynamic-offset destination.
// CHECK: llvm.call @snrt_dma_start_1d(
%0 = quidditch_snitch.start_dma_transfer from %arg0 : memref<?xf32> to %arg1 : memref<?xf32, strided<[1], offset: ?>>
// Dynamic-offset source, identity-layout destination.
// CHECK: llvm.call @snrt_dma_start_1d(
%1 = quidditch_snitch.start_dma_transfer from %arg1 : memref<?xf32, strided<[1], offset: ?>> to %arg0 : memref<?xf32>
return %0 : !quidditch_snitch.dma_token
}

// A 2D memref whose static strides [4, 1] match the row-major strides of its
// ?x4 shape is expected to lower to the 1D DMA runtime call even though its
// offset is dynamic.
// CHECK-LABEL: @test3
func.func @test3(%arg0 : memref<?x4xf32>, %arg1 : memref<?x4xf32, strided<[4, 1], offset: ?>>) -> !quidditch_snitch.dma_token {
// CHECK: llvm.call @snrt_dma_start_1d(
%0 = quidditch_snitch.start_dma_transfer from %arg0 : memref<?x4xf32> to %arg1 : memref<?x4xf32, strided<[4, 1], offset: ?>>
return %0 : !quidditch_snitch.dma_token
}

0 comments on commit dad35a1

Please sign in to comment.