Skip to content

Commit

Permalink
better dma transfers
Browse files Browse the repository at this point in the history
  • Loading branch information
zero9178 committed Aug 27, 2024
1 parent 8303280 commit 500eddc
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 116 deletions.
165 changes: 74 additions & 91 deletions codegen/compiler/src/Quidditch/Conversion/ConvertSnitchToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,70 +178,26 @@ struct StartDMATransferOp2DLowering
MemRefType sourceMemRef = op.getSource().getType();
MemRefType destMemRef = op.getDest().getType();

StridedLayoutAttr sourceStridesAttr =
dyn_cast_or_null<StridedLayoutAttr>(sourceMemRef.getLayout());
if (!sourceStridesAttr) {
if (sourceMemRef.getLayout() && !sourceMemRef.getLayout().isIdentity())
return failure();

sourceStridesAttr = identityStride(sourceMemRef);
}

StridedLayoutAttr destStridesAttr =
dyn_cast_or_null<StridedLayoutAttr>(destMemRef.getLayout());
if (!destStridesAttr) {
if (destMemRef.getLayout() && !destMemRef.getLayout().isIdentity())
return failure();

destStridesAttr = identityStride(destMemRef);
}

// Compute the size of the contiguous inner loop common to both MemRefs and
// "shave" it off the ends of the shapes and strides. The remaining shapes
// and strides are considered our outer dimensions.
int64_t innerSize = 1;
ArrayRef<int64_t> shape = sourceMemRef.getShape();
ArrayRef<int64_t> sourceStrides = sourceStridesAttr.getStrides();
ArrayRef<int64_t> destStrides = destStridesAttr.getStrides();
assert(shape.size() == sourceStrides.size() &&
sourceStrides.size() == destStrides.size());
for (; shape.size() > 1; shape = shape.drop_back(),
sourceStrides = sourceStrides.drop_back(),
destStrides = destStrides.drop_back()) {
int64_t dim = shape.back();
if (dim == 1)
continue;

int64_t sourceStride = sourceStrides.back();
int64_t destStride = destStrides.back();
if (sourceStride != destStride)
break;

if (innerSize != sourceStride)
break;

if (ShapedType::isDynamic(dim))
break;

innerSize *= dim;
}

MemRefDescriptor sourceDescriptor(adaptor.getSource());
MemRefDescriptor destDescriptor(adaptor.getDest());

Value source = sourceDescriptor.bufferPtr(
rewriter, op->getLoc(), *getTypeConverter(), op.getSource().getType());
Value dest = destDescriptor.bufferPtr(
rewriter, op->getLoc(), *getTypeConverter(), op.getDest().getType());
FailureOr<size_t> sourceNonContiguous =
getNumNonContiguousOuterDims(sourceMemRef);
FailureOr<size_t> destNonContiguous =
getNumNonContiguousOuterDims(destMemRef);
if (failed(sourceNonContiguous) || failed(destNonContiguous))
return failure();
size_t sharedNonContiguous =
std::max(*sourceNonContiguous, *destNonContiguous);
if (sharedNonContiguous == 0)
return failure();

Value elementSize = rewriter.create<LLVM::ConstantOp>(
op->getLoc(),
rewriter.getI32IntegerAttr(llvm::divideCeil(
op.getSource().getType().getElementTypeBitWidth(), 8)));
Value contiguousSize = rewriter.create<LLVM::ConstantOp>(
op->getLoc(), rewriter.getI32IntegerAttr(innerSize));
contiguousSize =
rewriter.create<LLVM::MulOp>(op->getLoc(), contiguousSize, elementSize);
SmallVector<OpFoldResult> sizes =
memref::getMixedSizes(rewriter, op->getLoc(), op.getSource());

// Build a loop nest iterating over all outer dimensions - 1 and adjusts the
// source and destination pointers accordingly. The inner-most outer
Expand All @@ -251,59 +207,83 @@ struct StartDMATransferOp2DLowering
SmallVector<Value> steps;
Value zeroIndex = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
Value oneIndex = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 1);
for (size_t index : llvm::seq(shape.size() - 1)) {
for (size_t index : llvm::seq(sharedNonContiguous - 1)) {
lowerBounds.push_back(zeroIndex);
steps.push_back(oneIndex);
Value dim = typeConverter->materializeSourceConversion(
rewriter, op->getLoc(), rewriter.getIndexType(),
sourceDescriptor.size(rewriter, op->getLoc(), index));
upperBounds.push_back(dim);
upperBounds.push_back(getValueOrCreateConstantIndexOp(
rewriter, op->getLoc(), sizes[index]));
}

Type tokenType = typeConverter->convertType(op.getType());
Value contiguousSize;
for (auto index :
llvm::seq<size_t>(sharedNonContiguous, sourceMemRef.getRank())) {
Value dim =
getValueOrCreateConstantIndexOp(rewriter, op->getLoc(), sizes[index]);
if (!contiguousSize) {
contiguousSize = dim;
continue;
}
contiguousSize =
rewriter.create<arith::MulIOp>(op->getLoc(), contiguousSize, dim);
}
contiguousSize = typeConverter->materializeTargetConversion(
rewriter, op->getLoc(), getIndexType(), contiguousSize);
contiguousSize =
rewriter.create<LLVM::MulOp>(op->getLoc(), contiguousSize, elementSize);

Value completedToken = rewriter.create<CompletedTokenOp>(op->getLoc());
completedToken = typeConverter->materializeTargetConversion(
rewriter, op->getLoc(), tokenType, completedToken);

scf::LoopNest loopNest = scf::buildLoopNest(
rewriter, op->getLoc(), lowerBounds, upperBounds, steps, completedToken,
[&](OpBuilder &builder, Location loc, ValueRange ivs,
ValueRange iterArgs) -> scf::ValueVector {
auto linearizeOffset = [&](MemRefDescriptor descriptor) {
Value offset =
rewriter.create<LLVM::ZeroOp>(loc, rewriter.getI32Type());
for (auto [index, iv] : llvm::enumerate(ivs)) {
Value increment = rewriter.create<LLVM::MulOp>(
loc,
typeConverter->materializeTargetConversion(
builder, op->getLoc(),
typeConverter->convertType(iv.getType()), iv),
descriptor.stride(builder, loc, index));
offset = rewriter.create<LLVM::AddOp>(loc, offset, increment);
}
return offset;
};

Value sourceAdjusted = rewriter.create<LLVM::GEPOp>(
loc, source.getType(),
typeConverter->convertType(sourceMemRef.getElementType()), source,
linearizeOffset(sourceDescriptor));
Value destAdjusted = rewriter.create<LLVM::GEPOp>(
loc, dest.getType(),
typeConverter->convertType(destMemRef.getElementType()), dest,
linearizeOffset(destDescriptor));
SmallVector<OpFoldResult> offsets = ivs;
SmallVector<OpFoldResult> subSizes(sharedNonContiguous - 1,
rewriter.getIndexAttr(1));
for (unsigned i : llvm::seq<unsigned>(sharedNonContiguous - 1,
sourceMemRef.getRank())) {
offsets.push_back(rewriter.getIndexAttr(0));
subSizes.push_back(sizes[i]);
}
SmallVector<OpFoldResult> strides(sourceMemRef.getRank(),
rewriter.getIndexAttr(1));

TypedValue<MemRefType> sourceMemRefSlice =
rewriter.create<memref::SubViewOp>(loc, op.getSource(), offsets,
subSizes, strides);
TypedValue<MemRefType> destMemRefSlice =
rewriter.create<memref::SubViewOp>(loc, op.getDest(), offsets,
subSizes, strides);

auto sourceDescriptor =
MemRefDescriptor(typeConverter->materializeTargetConversion(
rewriter, op->getLoc(),
typeConverter->convertType(sourceMemRefSlice.getType()),
sourceMemRefSlice));
auto destDescriptor =
MemRefDescriptor(typeConverter->materializeTargetConversion(
rewriter, op->getLoc(),
typeConverter->convertType(destMemRefSlice.getType()),
destMemRefSlice));

Value sourceAdjusted = sourceDescriptor.bufferPtr(
rewriter, op->getLoc(), *getTypeConverter(),
sourceMemRefSlice.getType());
Value destAdjusted = destDescriptor.bufferPtr(
rewriter, op->getLoc(), *getTypeConverter(),
destMemRefSlice.getType());

Value sourceStride =
sourceDescriptor.stride(builder, loc, sourceStrides.size() - 1);
sourceDescriptor.stride(builder, loc, sharedNonContiguous - 1);
sourceStride = rewriter.create<LLVM::MulOp>(
op->getLoc(), sourceStride, elementSize);
Value destStride =
destDescriptor.stride(builder, loc, destStrides.size() - 1);
destDescriptor.stride(builder, loc, sharedNonContiguous - 1);
destStride = rewriter.create<LLVM::MulOp>(op->getLoc(), destStride,
elementSize);

Value outerLoopSize =
sourceDescriptor.size(builder, loc, shape.size() - 1);
sourceDescriptor.size(builder, loc, sharedNonContiguous - 1);
return {builder
.create<LLVM::CallOp>(loc, dmaStart2DFunc,
ValueRange{
Expand All @@ -317,7 +297,10 @@ struct StartDMATransferOp2DLowering
.getResult()};
});

rewriter.replaceOp(op, loopNest.results.front());
Type tokenType = typeConverter->convertType(op.getType());
rewriter.replaceOp(
op, typeConverter->materializeTargetConversion(
rewriter, op->getLoc(), tokenType, loopNest.results.front()));
return success();
}
};
Expand Down
81 changes: 56 additions & 25 deletions codegen/tests/Conversion/ConvertSnitchToLLVM/dma_transfer.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ func.func private @test4(%arg0 : memref<1x4xf32>, %arg1 : memref<1x4xf32, stride
// CHECK-SAME: %{{[[:alnum:]]+}}
// CHECK-SAME: %[[ARG0_PTR:[[:alnum:]]+]]
// CHECK-SAME: %{{[[:alnum:]]+}}
// CHECK-SAME: %[[ARG0_SIZE:[[:alnum:]]+]]
// CHECK-SAME: %{{[[:alnum:]]+}}
// CHECK-SAME: %[[ARG0_STRIDE_N:[[:alnum:]]+]]
// CHECK-SAME: %{{[[:alnum:]]+}}
// CHECK-SAME: %{{[[:alnum:]]+}}
// CHECK-SAME: %{{[[:alnum:]]+}}
// CHECK-SAME: %{{[[:alnum:]]+}}
// CHECK-SAME: %[[ARG1_PTR:[[:alnum:]]+]]
Expand All @@ -67,11 +67,13 @@ func.func private @test4(%arg0 : memref<1x4xf32>, %arg1 : memref<1x4xf32, stride
// CHECK-SAME: %{{[[:alnum:]]+}}
// CHECK-SAME: %[[ARG1_STRIDE_N:[[:alnum:]]+]]
func.func private @test5(%arg0 : memref<2x4xf32>, %arg1 : memref<2x4xf32, strided<[8, 1], offset: 0>>) -> !quidditch_snitch.dma_token {
// CHECK: %[[ELEMENT_WIDTH:.*]] = llvm.mlir.constant(4 : i32)
// CHECK: %[[INNER_SIZE:.*]] = llvm.mul %[[ELEMENT_WIDTH]], %[[ELEMENT_WIDTH]]
// CHECK: %[[ARG0_STRIDE:.*]] = llvm.mul %[[ARG0_STRIDE_N]], %[[ELEMENT_WIDTH]]
// CHECK-DAG: %[[ELEMENT_WIDTH:.*]] = llvm.mlir.constant(4 : i32)
// CHECK-DAG: %[[FOUR_INDEX:.*]] = llvm.mlir.constant(4 : index)
// CHECK-DAG: %[[TWO:.*]] = llvm.mlir.constant(2 : index)
// CHECK: %[[INNER_SIZE:.*]] = llvm.mul %[[FOUR_INDEX]], %[[ELEMENT_WIDTH]]
// CHECK: %[[ARG0_STRIDE:.*]] = llvm.mul %[[FOUR_INDEX]], %[[ELEMENT_WIDTH]]
// CHECK: %[[ARG1_STRIDE:.*]] = llvm.mul %[[ARG1_STRIDE_N]], %[[ELEMENT_WIDTH]]
// CHECK: llvm.call @snrt_dma_start_2d(%[[ARG1_PTR]], %[[ARG0_PTR]], %[[INNER_SIZE]], %[[ARG1_STRIDE]], %[[ARG0_STRIDE]], %[[ARG0_SIZE]])
// CHECK: llvm.call @snrt_dma_start_2d(%[[ARG1_PTR]], %[[ARG0_PTR]], %[[INNER_SIZE]], %[[ARG1_STRIDE]], %[[ARG0_STRIDE]], %[[TWO]])
%0 = quidditch_snitch.start_dma_transfer from %arg0 : memref<2x4xf32> to %arg1 : memref<2x4xf32, strided<[8, 1], offset: 0>>
return %0 : !quidditch_snitch.dma_token
}
Expand All @@ -80,14 +82,14 @@ func.func private @test5(%arg0 : memref<2x4xf32>, %arg1 : memref<2x4xf32, stride
// CHECK-SAME: %{{[[:alnum:]]+}}
// CHECK-SAME: %[[ARG0_PTR:[[:alnum:]]+]]
// CHECK-SAME: %{{[[:alnum:]]+}}
// CHECK-SAME: %[[ARG0_SIZE:[[:alnum:]]+]]
// CHECK-SAME: %[[DIM1:[[:alnum:]]+]]
// CHECK-SAME: %{{[[:alnum:]]+}}
// CHECK-SAME: %{{[[:alnum:]]+}}
// CHECK-SAME: %{{[[:alnum:]]+}}
// CHECK-SAME: %[[ARG0_STRIDE0:[[:alnum:]]+]]
// CHECK-SAME: %[[ARG0_STRIDE_N:[[:alnum:]]+]]
// CHECK-SAME: %{{[[:alnum:]]+}}
// CHECK-SAME: %{{[[:alnum:]]+}}
// CHECK-SAME: %[[ARG1_ALIGNED_PTR:[[:alnum:]]+]]
// CHECK-SAME: %[[ARG1_PTR:[[:alnum:]]+]]
// CHECK-SAME: %{{[[:alnum:]]+}}
// CHECK-SAME: %{{[[:alnum:]]+}}
// CHECK-SAME: %{{[[:alnum:]]+}}
Expand All @@ -98,27 +100,30 @@ func.func private @test6(%arg0 : memref<3x2x4xf32>, %arg1 : memref<3x2x4xf32, st
// CHECK-DAG: %[[ELEMENT_WIDTH:.*]] = llvm.mlir.constant(4 : i32)
// CHECK-DAG: %[[ZERO32:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-DAG: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : i32
// CHECK-DAG: %[[EIGHT:.*]] = llvm.mlir.constant(8 : {{.*}}) : i32
// CHECK-DAG: %[[SIXTEEN:.*]] = llvm.mlir.constant(16 : {{.*}}) : i32
// CHECK-DAG: %[[ONE:.*]] = llvm.mlir.constant(1 : {{.*}}) : i32
// CHECK: %[[ARG1_PTR:.*]] = llvm.getelementptr %[[ARG1_ALIGNED_PTR]][2]
// CHECK: %[[INNER_SIZE:.*]] = llvm.mul %[[ELEMENT_WIDTH]], %[[ELEMENT_WIDTH]]
// CHECK-DAG: %[[TWO:.*]] = llvm.mlir.constant(2 : {{.*}}) : i32
// CHECK-DAG: %[[THREE:.*]] = llvm.mlir.constant(3 : {{.*}}) : i32
// CHECK-DAG: %[[FOUR_INDEX:.*]] = llvm.mlir.constant(4 : index)

// CHECK: %[[INNER_SIZE:.*]] = llvm.mul %[[FOUR_INDEX]], %[[ELEMENT_WIDTH]]
// CHECK: llvm.br ^[[BB1:.*]](%[[ZERO]], %[[ZERO32]]

// CHECK: ^[[BB1]](%[[IV1:.*]]: i32, %[[IV2:.*]]: i32):
// CHECK: %[[COND:.*]] = llvm.icmp "slt" %[[IV1]], %[[ARG0_SIZE]]
// CHECK: %[[COND:.*]] = llvm.icmp "slt" %[[IV1]], %[[THREE]]
// CHECK: llvm.cond_br %[[COND]], ^[[BODY:.*]], ^[[EXIT:[[:alnum:]]+]]

// CHECK: ^[[BODY]]:
// CHECK: %[[MUL:.*]] = llvm.mul %[[IV1]], %[[ARG0_STRIDE0]]
// CHECK: %[[ARG0_OFFSET1:.*]] = llvm.add %[[MUL]], %[[ZERO32]]
// CHECK: %[[ARG0_ADJUSTED:.*]] = llvm.getelementptr %[[ARG0_PTR]][%[[ARG0_OFFSET1]]]
// CHECK: %[[MUL1:.*]] = llvm.mul %[[IV1]], %[[EIGHT]]
// CHECK: %[[MUL2:.*]] = llvm.mul %[[IV1]], %[[SIXTEEN]]
// CHECK: %[[ARG0_OFFSET1:.*]] = llvm.add %[[MUL2]], %[[TWO]]
// CHECK: %[[ARG0_ADJUSTED:.*]] = llvm.getelementptr %[[ARG0_PTR]][%[[MUL1]]]
// CHECK: %[[ARG1_ADJUSTED:.*]] = llvm.getelementptr %[[ARG1_PTR]][%[[ARG0_OFFSET1]]]

// CHECK: %[[MUL:.*]] = llvm.mul %[[IV1]], %[[ARG1_STRIDE0]]
// CHECK: %[[ARG1_OFFSET1:.*]] = llvm.add %[[MUL]], %[[ZERO32]]
// CHECK: %[[ARG1_ADJUSTED:.*]] = llvm.getelementptr %[[ARG1_PTR]][%[[ARG1_OFFSET1]]]

// CHECK: %[[ARG0_STRIDE:.*]] = llvm.mul %[[ARG0_STRIDE_N]], %[[ELEMENT_WIDTH]]
// CHECK: %[[ARG1_STRIDE:.*]] = llvm.mul %[[ARG1_STRIDE_N]], %[[ELEMENT_WIDTH]]
// CHECK: %[[RES:.*]] = llvm.call @snrt_dma_start_2d(%[[ARG1_ADJUSTED]], %[[ARG0_ADJUSTED]], %[[INNER_SIZE]], %[[ARG1_STRIDE]], %[[ARG0_STRIDE]], %[[DIM1]])
// CHECK: %[[ARG0_STRIDE:.*]] = llvm.mul %[[FOUR_INDEX]], %[[ELEMENT_WIDTH]]
// CHECK: %[[ARG1_STRIDE:.*]] = llvm.mul %[[EIGHT]], %[[ELEMENT_WIDTH]]
// CHECK: %[[RES:.*]] = llvm.call @snrt_dma_start_2d(%[[ARG1_ADJUSTED]], %[[ARG0_ADJUSTED]], %[[INNER_SIZE]], %[[ARG1_STRIDE]], %[[ARG0_STRIDE]], %[[TWO]])
// CHECK: %[[INV:.*]] = llvm.add %[[IV1]], %[[ONE]]
// CHECK: llvm.br ^[[BB1]](%[[INV]], %[[RES]]

Expand All @@ -142,12 +147,38 @@ func.func private @test6(%arg0 : memref<3x2x4xf32>, %arg1 : memref<3x2x4xf32, st
// CHECK-SAME: %{{[[:alnum:]]+}}
// CHECK-SAME: %{{[[:alnum:]]+}}
// CHECK-SAME: %[[ARG1_STRIDE_N:[[:alnum:]]+]]
func.func private @dynamic_strides(%arg0 : memref<2x4xf32>, %arg1 : memref<2x4xf32, strided<[?, 1], offset: 0>>) -> !quidditch_snitch.dma_token {
func.func private @dynamic_strides(%arg0 : memref<2x4xf32>, %arg1 : memref<2x4xf32, strided<[?, 1]>>) -> !quidditch_snitch.dma_token {
// CHECK-DAG: %[[ELEMENT_WIDTH:.*]] = llvm.mlir.constant(4 : i32)
// CHECK-DAG: %[[FOUR:.*]] = llvm.mlir.constant(4 : index)
// CHECK-DAG: %[[TWO:.*]] = llvm.mlir.constant(2 : index)
// CHECK: %[[INNER_SIZE:.*]] = llvm.mul %[[FOUR]], %[[ELEMENT_WIDTH]]
// CHECK: %[[ARG0_STRIDE:.*]] = llvm.mul %[[FOUR]], %[[ELEMENT_WIDTH]]
// CHECK: %[[ARG1_STRIDE:.*]] = llvm.mul %[[ARG1_STRIDE_N]], %[[ELEMENT_WIDTH]]
// CHECK: llvm.call @snrt_dma_start_2d(%[[ARG1_PTR]], %[[ARG0_PTR]], %[[INNER_SIZE]], %[[ARG1_STRIDE]], %[[ARG0_STRIDE]], %[[TWO]])
%0 = quidditch_snitch.start_dma_transfer from %arg0 : memref<2x4xf32> to %arg1 : memref<2x4xf32, strided<[?, 1]>>
return %0 : !quidditch_snitch.dma_token
}

// CHECK-LABEL: @contigious_dynamic_inner
// CHECK-SAME: %{{[[:alnum:]]+}}
// CHECK-SAME: %[[ARG0_PTR:[[:alnum:]]+]]
// CHECK-SAME: %{{[[:alnum:]]+}}
// CHECK-SAME: %[[ARG0_SIZE:[[:alnum:]]+]]
// CHECK-SAME: %[[ARG0_STRIDE_0:[[:alnum:]]+]]
// CHECK-SAME: %[[ARG0_STRIDE_N:[[:alnum:]]+]]
// CHECK-SAME: %{{[[:alnum:]]+}}
// CHECK-SAME: %{{[[:alnum:]]+}}
// CHECK-SAME: %[[ARG1_PTR:[[:alnum:]]+]]
// CHECK-SAME: %{{[[:alnum:]]+}}
// CHECK-SAME: %{{[[:alnum:]]+}}
// CHECK-SAME: %{{[[:alnum:]]+}}
// CHECK-SAME: %[[ARG1_STRIDE_N:[[:alnum:]]+]]
func.func private @contigious_dynamic_inner(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32, strided<[?, 1]>>) -> !quidditch_snitch.dma_token {
// CHECK: %[[ELEMENT_WIDTH:.*]] = llvm.mlir.constant(4 : i32)
// CHECK: %[[INNER_SIZE:.*]] = llvm.mul %[[ELEMENT_WIDTH]], %[[ELEMENT_WIDTH]]
// CHECK: %[[INNER_SIZE:.*]] = llvm.mul %[[ARG0_STRIDE_0]], %[[ELEMENT_WIDTH]]
// CHECK: %[[ARG0_STRIDE:.*]] = llvm.mul %[[ARG0_STRIDE_N]], %[[ELEMENT_WIDTH]]
// CHECK: %[[ARG1_STRIDE:.*]] = llvm.mul %[[ARG1_STRIDE_N]], %[[ELEMENT_WIDTH]]
// CHECK: llvm.call @snrt_dma_start_2d(%[[ARG1_PTR]], %[[ARG0_PTR]], %[[INNER_SIZE]], %[[ARG1_STRIDE]], %[[ARG0_STRIDE]], %[[ARG0_SIZE]])
%0 = quidditch_snitch.start_dma_transfer from %arg0 : memref<2x4xf32> to %arg1 : memref<2x4xf32, strided<[?, 1], offset: 0>>
%0 = quidditch_snitch.start_dma_transfer from %arg0 : memref<?x?xf32> to %arg1 : memref<?x?xf32, strided<[?, 1]>>
return %0 : !quidditch_snitch.dma_token
}

0 comments on commit 500eddc

Please sign in to comment.