From 812da1a034267a79594daabe9c9e0e516ba35760 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Markus=20B=C3=B6ck?= <markus.boeck02@gmail.com> Date: Mon, 19 Aug 2024 11:22:36 +0200 Subject: [PATCH] [quidditch_snitch] Add padding capabilities to `start_tensor_copy` (#116) We occasionally encounter shapes that are challenging to tile due to their prime factors involved. Attempting to distribute these (e.g. to compute cores or vector lanes) when the number of required tiles is not a factor of the dimension leads to generating dynamic dimensions which the microkernel compilation is unable to deal with. Similarly, once we are on `f32`, we are required to vectorize the kernel and have a restriction that the tile size of e.g. a matvec is a multiple of 4, 8 etc. This PR therefore introduces optional padding to the `start_dma_transfer` op that can be added at the end of each tensor dimension. When tiled, the padding can be chosen to guarantee that a tensor is always of a given static shape, solving the issue noted above. For now, the value used for padding is always zero which works for any matmul, elementwise operation and convolution. Note that the padding option is not yet used in the pipeline but will be lowered to from `tensor.pad` operations in a future PR. --- .../Dialect/Snitch/IR/QuidditchSnitchOps.cpp | 124 +++++++++++++++--- .../Dialect/Snitch/IR/QuidditchSnitchOps.td | 55 ++++++-- .../Dialect/Snitch/IR/bufferization.mlir | 54 ++++++-- .../Dialect/Snitch/IR/canonicalization.mlir | 27 +++- .../tests/Dialect/Snitch/IR/roundtrip.mlir | 2 +- .../Transforms/pipeline-copy-compute.mlir | 24 ++-- .../Transforms/promote-operands-to-l1.mlir | 12 +- 7 files changed, 242 insertions(+), 56 deletions(-) diff --git a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.cpp b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.cpp index b45b24f..a14b822 100644 --- a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.cpp +++ b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.cpp @@ -1,6 +1,7 @@ #include "QuidditchSnitchOps.h" #include "llvm/ADT/ScopeExit.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -399,34 +400,98 @@ void MicrokernelFenceOp::replaceWithNoop(RewriterBase &rewriter) { // StartTensorCopyOp //===----------------------------------------------------------------------===// +LogicalResult StartTensorCopyOp::verify() { + if (getStaticHighPadAttr()) + if (getStaticHighPadAttr().size() != getCopy().getType().getRank()) + return emitOpError("expected padding number for every dimension"); + + unsigned numDynamicPads = llvm::count( + getStaticHighPad().value_or(std::nullopt), ShapedType::kDynamic); + if (numDynamicPads != getHighPad().size()) + return emitOpError("expected ") + << numDynamicPads << " dynamic padding values"; + + return success(); +} + LogicalResult StartTensorCopyOp::fold(FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) { + if (hasPadding()) { + // Remove noop padding. + if (llvm::all_of(getStaticHighPadAttr().asArrayRef(), + [](int64_t value) { return value == 0; })) { + removeStaticHighPadAttr(); + return success(); + } + + // Fold dynamic indices with constant values into the static list. + { + bool changed = false; + SmallVector<int64_t> padding = + llvm::to_vector(getStaticHighPadAttr().asArrayRef()); + unsigned dynamicIndex = 0; + for (int64_t &value : padding) { + if (!ShapedType::isDynamic(value)) + continue; + + if (auto integer = dyn_cast_or_null<IntegerAttr>( + adaptor.getHighPad()[dynamicIndex])) { + value = integer.getValue().getZExtValue(); + getHighPadMutable().erase(dynamicIndex); + changed = true; + } else { + dynamicIndex++; + } + } + if (changed) { + setStaticHighPad(padding); + return success(); + } + } + } + auto waitOp = getCopy().getDefiningOp<WaitForTensorCopyOp>(); if (!waitOp) return failure(); auto copyOp = waitOp.getTransferTensor().getDefiningOp<StartTensorCopyOp>(); if (!copyOp) return failure(); + if (copyOp.getStaticHighPadAttr() != getStaticHighPadAttr() || + copyOp.getHighPad() != getHighPad()) + return failure(); results.emplace_back(waitOp); results.emplace_back(CompletedTokenAttr::get(getContext())); return success(); } +SmallVector<OpFoldResult> StartTensorCopyOp::getMixedHighPad() { + Builder builder(getContext()); + if (!hasPadding()) + return SmallVector<OpFoldResult>(getResult().getType().getRank(), + builder.getIndexAttr(0)); + + return getMixedValues(getStaticHighPadAttr().asArrayRef(), getHighPad(), + builder); +} + //===----------------------------------------------------------------------===// // StartTensorCopyOp::BufferizableOpInterface //===----------------------------------------------------------------------===// -/// Returns whether 'copy' is already in L1 memory. +/// Returns whether the allocation can be elided entirely. /// Returns an empty optional if it was not possible to determine. -static std::optional<bool> -isInL1Memory(Value copy, - const bufferization::BufferizationOptions &options = {}, - SmallVector<Value> *invocationStack = nullptr) { +std::optional<bool> StartTensorCopyOp::elidesAllocation( + const bufferization::BufferizationOptions &options, + SmallVector<Value> *invocationStack) { + // Padding cannot be elided in general, even if the copied buffer is in L1. + if (hasPadding()) + return false; + FailureOr<BaseMemRefType> copyType = invocationStack - ? bufferization::getBufferType(copy, options, *invocationStack) - : bufferization::getBufferType(copy, options); + ? bufferization::getBufferType(getCopy(), options, *invocationStack) + : bufferization::getBufferType(getCopy(), options); if (failed(copyType)) return std::nullopt; @@ -437,7 +502,7 @@ bool StartTensorCopyOp::resultBufferizesToMemoryWrite( OpResult opResult, const bufferization::AnalysisState &state) { assert(opResult == getResult() && "no other result"); - std::optional<bool> matches = isInL1Memory(getCopy(), state.getOptions()); + std::optional<bool> matches = elidesAllocation(state.getOptions()); // Conservative answer. if (!matches) return true; @@ -451,7 +516,7 @@ bool StartTensorCopyOp::bufferizesToMemoryRead( OpOperand &opOperand, const bufferization::AnalysisState &state) { assert(opOperand == getCopyMutable() && "have only one operand"); - std::optional<bool> result = isInL1Memory(getCopy(), state.getOptions()); + std::optional<bool> result = elidesAllocation(state.getOptions()); // Conservative answer. if (!result) return true; @@ -472,7 +537,7 @@ AliasingValueList StartTensorCopyOp::getAliasingValues( OpOperand &opOperand, const bufferization::AnalysisState &state) { assert(opOperand == getCopyMutable() && "have only one operand"); - std::optional<bool> result = isInL1Memory(getCopy(), state.getOptions()); + std::optional<bool> result = elidesAllocation(state.getOptions()); if (!result) // Assume the worst case. return {{getResult(), BufferRelation::Equivalent, /*isDefinite=*/false}}; @@ -488,7 +553,7 @@ AliasingValueList StartTensorCopyOp::getAliasingValues( bool StartTensorCopyOp::bufferizesToAllocation(Value value) { assert(value == getResult() && "have only one result"); - if (isInL1Memory(getCopy()) == true) + if (elidesAllocation() == true) return false; // True is the conservative reply, according to the docs. @@ -503,7 +568,7 @@ StartTensorCopyOp::getBufferType(Value value, bool contained = llvm::is_contained(invocationStack, value); if (!contained) - if (isInL1Memory(getCopy(), options, &invocationStack) == true) + if (elidesAllocation(options, &invocationStack) == true) return bufferization::getBufferType(getCopy(), options, invocationStack); // Unless contained in the invocation stack (where we are free to impose the @@ -530,7 +595,7 @@ StartTensorCopyOp::bufferize(RewriterBase &rewriter, if (failed(copyBuffer)) return failure(); - std::optional<bool> result = isInL1Memory(getCopy(), options); + std::optional<bool> result = elidesAllocation(options); if (!result) return failure(); @@ -546,12 +611,20 @@ StartTensorCopyOp::bufferize(RewriterBase &rewriter, if (failed(allocType)) return failure(); + SmallVector<OpFoldResult> copyBufferSizes = + memref::getMixedSizes(rewriter, getLoc(), *copyBuffer); + + // Compute the dynamic dimensions for the allocation. SmallVector<Value> dynamicDims; - for (auto [index, shape] : llvm::enumerate(allocType->getShape())) { + for (auto [index, shape, pad] : + llvm::enumerate(allocType->getShape(), getMixedHighPad())) { if (!ShapedType::isDynamic(shape)) continue; - dynamicDims.push_back( - rewriter.create<memref::DimOp>(getLoc(), *copyBuffer, index)); + + dynamicDims.push_back(affine::makeComposedAffineApply( + rewriter, getLoc(), + rewriter.getAffineDimExpr(0) + rewriter.getAffineDimExpr(1), + ArrayRef<OpFoldResult>{copyBufferSizes[index], pad})); } FailureOr<Value> alloc = options.createAlloc( @@ -560,8 +633,25 @@ StartTensorCopyOp::bufferize(RewriterBase &rewriter, if (failed(alloc)) return failure(); + // Zero out the entire buffer prior to overwriting it with the copied values. + // TODO: This could be optimized to only zero regions that won't be filled + // with the copied values at the cost of 2^rank transfers instead of two. + if (hasPadding()) + rewriter.create<StartZeroMemTransferOp>(getLoc(), *alloc); + + // Subview into the original memory without any padding. + // As we only add padding at the end of the dimensions, the offsets are always + // zero. + Value destination = rewriter.create<memref::SubViewOp>( + getLoc(), *alloc, + /*offsets=*/ + SmallVector<OpFoldResult>(allocType->getRank(), rewriter.getIndexAttr(0)), + copyBufferSizes, + /*strides=*/ + SmallVector<OpFoldResult>(allocType->getRank(), + rewriter.getIndexAttr(1))); Value token = - rewriter.create<StartDMATransferOp>(getLoc(), *copyBuffer, *alloc); + rewriter.create<StartDMATransferOp>(getLoc(), *copyBuffer, destination); // Replace op. replaceOpWithBufferizedValues(rewriter, getOperation(), {*alloc, token}); diff --git a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.td b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.td index 333cb39..e634b45 100644 --- a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.td +++ b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.td @@ -175,24 +175,32 @@ def QuidditchSnitch_MicrokernelFenceOp : QuidditchSnitch_Op<"microkernel_fence", } def QuidditchSnitch_StartTensorCopyOp : QuidditchSnitch_Op<"start_tensor_copy", - [AllTypesMatch<["copy", "result"]>, Pure, + [Pure, AllRanksMatch<["copy", "result"]>, DeclareOpInterfaceMethods<BufferizableOpInterface, ["resultBufferizesToMemoryWrite", "bufferizesToMemoryRead", "bufferizesToMemoryWrite", "getAliasingValues", "getBufferType", "bufferize", "bufferizesToAllocation"]>]> { let description = [{ - Operation starting a copy of a tensor to L1 memory space returning it as - a new tensor. - The contained values of the tensor in an unspecified state. + Operation starting a copy of a tensor to L1 memory space, optionally adding + padding and returning it as a new tensor. + The contained values of the resulting tensor is in an unspecified state. See `wait_for_tensor_copy` to transform the tensor value into a state equal to `$copy`. - This operation is a noop if `$copy` and `$result` are already in L1 and - bufferization can elide the copy. + The operation may optionally add padding at the end of each dimension of + the tensor. Zero is used as the padding value. + The dimensions of the result tensor are computed using + `dims(copy)[i] + high_pad[i]`. + + This operation is a noop if `$copy` is already in L1, no padding is added, + and bufferization can elide the copy. }]; - let arguments = (ins AnyRankedTensor:$copy); + let arguments = (ins AnyRankedTensor:$copy, + Variadic<Index>:$high_pad, + OptionalAttr<DenseI64ArrayAttr>:$static_high_pad + ); let results = (outs AnyRankedTensor:$result, @@ -200,14 +208,41 @@ def QuidditchSnitch_StartTensorCopyOp : QuidditchSnitch_Op<"start_tensor_copy", ); let assemblyFormat = [{ - $copy `to` `L1` `:` type($copy) attr-dict + $copy `to` `L1` + ( `pad` `with` `zero` `to` + custom<DynamicIndexList>($high_pad, $static_high_pad)^)? + `:` type($copy) `->` type($result) attr-dict + }]; + + let builders = [ + OpBuilder<(ins "mlir::Value":$copy), [{ + build($_builder, $_state, copy.getType(), + $_builder.getType<DMATokenType>(), copy, + /*high_pad=*/mlir::ValueRange(), /*static_high_pad=*/nullptr); + }]> + ]; + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + private: + std::optional<bool> + elidesAllocation(const mlir::bufferization::BufferizationOptions &options = {}, + llvm::SmallVector<mlir::Value> *invocationStack = nullptr); + public: + + bool hasPadding() { + return static_cast<bool>(getStaticHighPadAttr()); + } + + llvm::SmallVector<mlir::OpFoldResult> getMixedHighPad(); }]; let hasFolder = 1; } def QuidditchSnitch_WaitForTensorCopyOp : QuidditchSnitch_Op<"wait_for_tensor_copy", - [AllTypesMatch<["transfer_tensor", "result", "copy"]>, Pure, + [AllTypesMatch<["transfer_tensor", "result"]>, Pure, DeclareOpInterfaceMethods<BufferizableOpInterface, ["bufferizesToMemoryRead", "bufferizesToMemoryWrite", "getAliasingValues", "bufferize", "mustBufferizeInPlace", "isNotConflicting"]>]> { @@ -240,7 +275,7 @@ def QuidditchSnitch_WaitForTensorCopyOp : QuidditchSnitch_Op<"wait_for_tensor_co ); let assemblyFormat = [{ - `of` $copy `to` $transfer_tensor `using` $token `:` type($transfer_tensor) attr-dict + `of` $copy `:` type($copy) `to` $transfer_tensor `using` $token `->` type($transfer_tensor) attr-dict }]; let hasFolder = 1; diff --git a/codegen/tests/Dialect/Snitch/IR/bufferization.mlir b/codegen/tests/Dialect/Snitch/IR/bufferization.mlir index e38d1ae..7aa9088 100644 --- a/codegen/tests/Dialect/Snitch/IR/bufferization.mlir +++ b/codegen/tests/Dialect/Snitch/IR/bufferization.mlir @@ -1,15 +1,19 @@ // RUN: quidditch-opt %s --one-shot-bufferize | FileCheck %s +// CHECK: #[[$MAP2:.*]] = affine_map<()[s0, s1] -> (s0 + s1)> + // CHECK: func @copy_l1_buffer( func.func @copy_l1_buffer(%arg0 : tensor<32xf32>) -> (tensor<32xf32>, !quidditch_snitch.dma_token) { // CHECK: %[[ARG0:.*]] = bufferization.to_memref // CHECK: %[[ALLOC:.*]] = memref.alloc() // CHECK-SAME: : memref<32xf32, #quidditch_snitch.l1_encoding> + // CHECK: %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]] + // CHECK-SAME: to memref<32xf32, strided<[1]>, #quidditch_snitch.l1_encoding> // CHECK: %[[TOKEN:.*]] = quidditch_snitch.start_dma_transfer from %[[ARG0]] - // CHECK-SAME: to %[[ALLOC]] + // CHECK-SAME: to %[[SUBVIEW]] // CHECK: %[[R:.*]] = bufferization.to_tensor %[[ALLOC]] - %r, %token = quidditch_snitch.start_tensor_copy %arg0 to L1 : tensor<32xf32> + %r, %token = quidditch_snitch.start_tensor_copy %arg0 to L1 : tensor<32xf32> -> tensor<32xf32> // CHECK: return %[[R]], %[[TOKEN]] return %r, %token : tensor<32xf32>, !quidditch_snitch.dma_token } @@ -18,10 +22,10 @@ func.func @copy_l1_buffer(%arg0 : tensor<32xf32>) -> (tensor<32xf32>, !quidditch func.func @copy_l1_buffer_elided(%arg0 : tensor<32xf32>) -> tensor<32xf32> { // CHECK: memref.alloc() // CHECK-NOT: memref.alloc() - %r:2 = quidditch_snitch.start_tensor_copy %arg0 to L1 : tensor<32xf32> - %r2 = quidditch_snitch.wait_for_tensor_copy of %arg0 to %r#0 using %r#1 : tensor<32xf32> - %r3:2 = quidditch_snitch.start_tensor_copy %r2 to L1 : tensor<32xf32> - %r4 = quidditch_snitch.wait_for_tensor_copy of %r2 to %r3#0 using %r3#1 : tensor<32xf32> + %r:2 = quidditch_snitch.start_tensor_copy %arg0 to L1 : tensor<32xf32> -> tensor<32xf32> + %r2 = quidditch_snitch.wait_for_tensor_copy of %arg0 : tensor<32xf32> to %r#0 using %r#1 -> tensor<32xf32> + %r3:2 = quidditch_snitch.start_tensor_copy %r2 to L1 : tensor<32xf32> -> tensor<32xf32> + %r4 = quidditch_snitch.wait_for_tensor_copy of %r2 : tensor<32xf32> to %r3#0 using %r3#1 -> tensor<32xf32> // CHECK: return return %r4 : tensor<32xf32> } @@ -31,7 +35,7 @@ func.func @copy_l1_buffer_alloca_elided() -> tensor<32xf32> { // CHECK: memref.alloc() // CHECK-NOT: memref.alloc() %r = bufferization.alloc_tensor() {memory_space = #quidditch_snitch.l1_encoding} : tensor<32xf32> - %r2:2 = quidditch_snitch.start_tensor_copy %r to L1 : tensor<32xf32> + %r2:2 = quidditch_snitch.start_tensor_copy %r to L1 : tensor<32xf32> -> tensor<32xf32> // CHECK: return return %r2#0 : tensor<32xf32> } @@ -42,7 +46,7 @@ func.func @scf_for_copy_l1_buffer() -> tensor<32xf32> { %c1 = arith.constant 1 : index // CHECK: %[[MEMREF:.*]] = memref.alloc %r = bufferization.alloc_tensor() {memory_space = #quidditch_snitch.l1_encoding} : tensor<32xf32> - %r2:2 = quidditch_snitch.start_tensor_copy %r to L1 : tensor<32xf32> + %r2:2 = quidditch_snitch.start_tensor_copy %r to L1 : tensor<32xf32> -> tensor<32xf32> // CHECK-NEXT: quidditch_snitch.completed_token // CHECK-NEXT: %[[R:.*]] = scf.for // CHECK-SAME: iter_args(%[[ITER:.*]] = %[[MEMREF]]) @@ -50,7 +54,7 @@ func.func @scf_for_copy_l1_buffer() -> tensor<32xf32> { // CHECK-NEXT: scf.yield %[[ITER]] // CHECK: bufferization.to_tensor %[[R]] %r3 = scf.for %i = %c0 to %c1 step %c1 iter_args(%iter = %r2#0) -> (tensor<32xf32>) { - %r4:2 = quidditch_snitch.start_tensor_copy %iter to L1 : tensor<32xf32> + %r4:2 = quidditch_snitch.start_tensor_copy %iter to L1 : tensor<32xf32> -> tensor<32xf32> scf.yield %r4#0 : tensor<32xf32> } return %r3 : tensor<32xf32> @@ -60,13 +64,16 @@ func.func @scf_for_copy_l1_buffer() -> tensor<32xf32> { func.func @copy_l1_buffer_dynamic_dims(%arg0 : tensor<?xf32>) -> tensor<?xf32> { // CHECK: %[[ARG0:.*]] = bufferization.to_memref // CHECK: %[[ZERO:.*]] = arith.constant 0 - // CHECK: %[[DIM:.*]] = memref.dim %[[ARG0]], %[[ZERO]] + // CHECK: %[[DIM_IN:.*]] = memref.dim %[[ARG0]], %[[ZERO]] + // CHECK: %[[DIM:.*]] = affine.apply #{{.*}}()[%[[DIM_IN]]] // CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM]]) // CHECK-SAME: : memref<?xf32, #quidditch_snitch.l1_encoding> + // CHECK: %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]] + // CHECK-SAME: to memref<?xf32, strided<[1]>, #quidditch_snitch.l1_encoding> // CHECK: quidditch_snitch.start_dma_transfer from %[[ARG0]] - // CHECK-SAME: to %[[ALLOC]] + // CHECK-SAME: to %[[SUBVIEW]] // CHECK: %[[R:.*]] = bufferization.to_tensor %[[ALLOC]] - %r:2 = quidditch_snitch.start_tensor_copy %arg0 to L1 : tensor<?xf32> + %r:2 = quidditch_snitch.start_tensor_copy %arg0 to L1 : tensor<?xf32> -> tensor<?xf32> // CHECK: return %[[R]] return %r#0 : tensor<?xf32> } @@ -135,3 +142,26 @@ func.func @sync_tensor() -> tensor<32xf32> { // CHECK: return %[[R]] return %r : tensor<32xf32> } + +// CHECK-LABEL: @tensor_copy_pad +// CHECK-SAME: %[[ARG0:[[:alnum:]]+]] +// CHECK-SAME: %[[PAD0:[[:alnum:]]+]] +// CHECK-SAME: %[[PAD1:[[:alnum:]]+]] +func.func @tensor_copy_pad(%arg0 : tensor<?x?xf32>, %pad0 : index, %pad1 : index) -> (tensor<?x?xf32>, !quidditch_snitch.dma_token) { + // CHECK: %[[COPY:.*]] = bufferization.to_memref %[[ARG0]] + // CHECK: %[[ZERO:.*]] = arith.constant 0 + // CHECK: %[[DIM0:.*]] = memref.dim %[[COPY]], %[[ZERO]] + // CHECK: %[[ONE:.*]] = arith.constant 1 + // CHECK: %[[DIM1:.*]] = memref.dim %[[COPY]], %[[ONE]] + // CHECK: %[[NEW_DIM0:.*]] = affine.apply #[[$MAP2]]()[%[[DIM0]], %[[PAD0]]] + // CHECK: %[[NEW_DIM1:.*]] = affine.apply #[[$MAP2]]()[%[[DIM1]], %[[PAD1]]] + // CHECK: %[[ALLOC:.*]] = memref.alloc(%[[NEW_DIM0]], %[[NEW_DIM1]]) + // CHECK: start_zero_mem_transfer %[[ALLOC]] + // CHECK: %[[UNPADDED:.*]] = memref.subview %[[ALLOC]][0, 0] [%[[DIM0]], %[[DIM1]]] [1, 1] + // CHECK: %[[TOKEN:.*]] = quidditch_snitch.start_dma_transfer from %[[COPY]] + // CHECK-SAME: to %[[UNPADDED]] + %r, %t = quidditch_snitch.start_tensor_copy %arg0 to L1 pad with zero to [%pad0, %pad1] : tensor<?x?xf32> -> tensor<?x?xf32> + // CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[ALLOC]] + // CHECK: return %[[TENSOR]], %[[TOKEN]] + return %r, %t : tensor<?x?xf32>, !quidditch_snitch.dma_token +} diff --git a/codegen/tests/Dialect/Snitch/IR/canonicalization.mlir b/codegen/tests/Dialect/Snitch/IR/canonicalization.mlir index 31911f6..d01e08b 100644 --- a/codegen/tests/Dialect/Snitch/IR/canonicalization.mlir +++ b/codegen/tests/Dialect/Snitch/IR/canonicalization.mlir @@ -98,7 +98,7 @@ func.func @pipeline_invariant(%tensor : tensor<?xf32>) { func.func @tensor_wait_gets_removed(%arg0 : tensor<?xf32>, %arg1 : tensor<?xf32>) -> tensor<?xf32> { // CHECK-NEXT: return %[[ARG1]] %t = quidditch_snitch.completed_token - %0 = quidditch_snitch.wait_for_tensor_copy of %arg0 to %arg1 using %t : tensor<?xf32> + %0 = quidditch_snitch.wait_for_tensor_copy of %arg0 : tensor<?xf32> to %arg1 using %t -> tensor<?xf32> return %0 : tensor<?xf32> } @@ -107,8 +107,27 @@ func.func @tensor_noop_transfer(%arg0 : tensor<?xf32>) -> (tensor<?xf32>, !quidd // CHECK: %[[T:.*]] = quidditch_snitch.completed_token // CHECK: %[[R:.*]] = quidditch_snitch.wait_for_tensor_copy // CHECK-NEXT: return %[[R]], %[[T]] - %r, %t = quidditch_snitch.start_tensor_copy %arg0 to L1 : tensor<?xf32> - %0 = quidditch_snitch.wait_for_tensor_copy of %arg0 to %r using %t : tensor<?xf32> - %r2, %t2 = quidditch_snitch.start_tensor_copy %0 to L1 : tensor<?xf32> + %r, %t = quidditch_snitch.start_tensor_copy %arg0 to L1 : tensor<?xf32> -> tensor<?xf32> + %0 = quidditch_snitch.wait_for_tensor_copy of %arg0 : tensor<?xf32> to %r using %t -> tensor<?xf32> + %r2, %t2 = quidditch_snitch.start_tensor_copy %0 to L1 : tensor<?xf32> -> tensor<?xf32> return %r2, %t2 : tensor<?xf32>, !quidditch_snitch.dma_token } + +// CHECK-LABEL: @tensor_noop_pad +func.func @tensor_noop_pad(%arg0 : tensor<?xf32>) -> (tensor<?xf32>, !quidditch_snitch.dma_token) { + // CHECK: %[[R:.*]], %[[T:.*]] = quidditch_snitch.start_tensor_copy + // CHECK-NOT: pad with + %r, %t = quidditch_snitch.start_tensor_copy %arg0 to L1 pad with zero to [0] : tensor<?xf32> -> tensor<?xf32> + // CHECK-NEXT: return %[[R]], %[[T]] + return %r, %t : tensor<?xf32>, !quidditch_snitch.dma_token +} + +// CHECK-LABEL: @tensor_pad_constant +func.func @tensor_pad_constant(%arg0 : tensor<?xf32>) -> (tensor<?xf32>, !quidditch_snitch.dma_token) { + %zero = arith.constant 0 : index + // CHECK: %[[R:.*]], %[[T:.*]] = quidditch_snitch.start_tensor_copy + // CHECK-NOT: pad with + %r, %t = quidditch_snitch.start_tensor_copy %arg0 to L1 pad with zero to [%zero] : tensor<?xf32> -> tensor<?xf32> + // CHECK-NEXT: return %[[R]], %[[T]] + return %r, %t : tensor<?xf32>, !quidditch_snitch.dma_token +} diff --git a/codegen/tests/Dialect/Snitch/IR/roundtrip.mlir b/codegen/tests/Dialect/Snitch/IR/roundtrip.mlir index 8ffcf68..0f22a4d 100644 --- a/codegen/tests/Dialect/Snitch/IR/roundtrip.mlir +++ b/codegen/tests/Dialect/Snitch/IR/roundtrip.mlir @@ -10,6 +10,6 @@ func.func @test(%arg0 : memref<f64>) { } func.func @test3(%arg0 : tensor<?x4xf64>) -> (tensor<?x4xf64>, !quidditch_snitch.dma_token) { - %0:2 = quidditch_snitch.start_tensor_copy %arg0 to L1 : tensor<?x4xf64> + %0:2 = quidditch_snitch.start_tensor_copy %arg0 to L1 : tensor<?x4xf64> -> tensor<?x4xf64> return %0#0, %0#1 : tensor<?x4xf64>, !quidditch_snitch.dma_token } diff --git a/codegen/tests/Dialect/Snitch/Transforms/pipeline-copy-compute.mlir b/codegen/tests/Dialect/Snitch/Transforms/pipeline-copy-compute.mlir index c17740d..b9c70aa 100644 --- a/codegen/tests/Dialect/Snitch/Transforms/pipeline-copy-compute.mlir +++ b/codegen/tests/Dialect/Snitch/Transforms/pipeline-copy-compute.mlir @@ -25,12 +25,12 @@ func.func @test(%arg0: index, %extracted_slice : tensor<1x100xf64>, %14 : tensor %extracted_slice_6 = tensor.extract_slice %14[%arg2, %arg0] [40, 100] [1, 1] : tensor<1200x400xf64> to tensor<40x100xf64> %extracted_slice_7 = tensor.extract_slice %arg3[0, %arg2] [1, 40] [1, 1] : tensor<1x1200xf64> to tensor<1x40xf64> - %result_8, %token_9 = quidditch_snitch.start_tensor_copy %extracted_slice to L1 : tensor<1x100xf64> - %25 = quidditch_snitch.wait_for_tensor_copy of %extracted_slice to %result_8 using %token_9 : tensor<1x100xf64> - %result_10, %token_11 = quidditch_snitch.start_tensor_copy %extracted_slice_6 to L1 : tensor<40x100xf64> - %26 = quidditch_snitch.wait_for_tensor_copy of %extracted_slice_6 to %result_10 using %token_11 : tensor<40x100xf64> - %result_12, %token_13 = quidditch_snitch.start_tensor_copy %extracted_slice_7 to L1 : tensor<1x40xf64> - %27 = quidditch_snitch.wait_for_tensor_copy of %extracted_slice_7 to %result_12 using %token_13 : tensor<1x40xf64> + %result_8, %token_9 = quidditch_snitch.start_tensor_copy %extracted_slice to L1 : tensor<1x100xf64> -> tensor<1x100xf64> + %25 = quidditch_snitch.wait_for_tensor_copy of %extracted_slice : tensor<1x100xf64> to %result_8 using %token_9 -> tensor<1x100xf64> + %result_10, %token_11 = quidditch_snitch.start_tensor_copy %extracted_slice_6 to L1 : tensor<40x100xf64> -> tensor<40x100xf64> + %26 = quidditch_snitch.wait_for_tensor_copy of %extracted_slice_6 : tensor<40x100xf64> to %result_10 using %token_11 -> tensor<40x100xf64> + %result_12, %token_13 = quidditch_snitch.start_tensor_copy %extracted_slice_7 to L1 : tensor<1x40xf64> -> tensor<1x40xf64> + %27 = quidditch_snitch.wait_for_tensor_copy of %extracted_slice_7 : tensor<1x40xf64> to %result_12 using %token_13 -> tensor<1x40xf64> // CHECK: ^{{.*}}( // CHECK-SAME: %[[IV:[[:alnum:]]+]] @@ -43,9 +43,15 @@ func.func @test(%arg0: index, %extracted_slice : tensor<1x100xf64>, %14 : tensor // CHECK-SAME: %[[SLICE2:[[:alnum:]]+]] // CHECK-SAME: %[[RESULT2:[[:alnum:]]+]] // CHECK-SAME: %[[TOKEN2:[[:alnum:]]+]] - // CHECK: %[[OPA:.*]] = quidditch_snitch.wait_for_tensor_copy of %[[ARG1]] to %[[RESULT0]] using %[[TOKEN0]] - // CHECK: %[[OPB:.*]] = quidditch_snitch.wait_for_tensor_copy of %[[SLICE1]] to %[[RESULT1]] using %[[TOKEN1]] - // CHECK: %[[OPC:.*]] = quidditch_snitch.wait_for_tensor_copy of %[[SLICE2]] to %[[RESULT2]] using %[[TOKEN2]] + // CHECK: %[[OPA:.*]] = quidditch_snitch.wait_for_tensor_copy of %[[ARG1]] + // CHECK-SAME: to %[[RESULT0]] + // CHECK-SAME: using %[[TOKEN0]] + // CHECK: %[[OPB:.*]] = quidditch_snitch.wait_for_tensor_copy of %[[SLICE1]] + // CHECK-SAME: to %[[RESULT1]] + // CHECK-SAME: using %[[TOKEN1]] + // CHECK: %[[OPC:.*]] = quidditch_snitch.wait_for_tensor_copy of %[[SLICE2]] + // CHECK-SAME: to %[[RESULT2]] + // CHECK-SAME: using %[[TOKEN2]] // CHECK: %[[RES:.*]] = linalg.matmul_transpose_b // CHECK-SAME: ins(%[[OPA]], %[[OPB]] : // CHECK-SAME: outs(%[[OPC]] : diff --git a/codegen/tests/Dialect/Snitch/Transforms/promote-operands-to-l1.mlir b/codegen/tests/Dialect/Snitch/Transforms/promote-operands-to-l1.mlir index e46c11d..9822084 100644 --- a/codegen/tests/Dialect/Snitch/Transforms/promote-operands-to-l1.mlir +++ b/codegen/tests/Dialect/Snitch/Transforms/promote-operands-to-l1.mlir @@ -7,11 +7,17 @@ func.func @test(%a : tensor<32x32xf32>, %b : tensor<32x32xf32>) -> tensor<32x32x // CHECK: %[[E:.*]] = bufferization.alloc_tensor %e = bufferization.alloc_tensor() : tensor<32x32xf32> // CHECK: %[[A1:.*]], %[[TOKEN:.*]] = quidditch_snitch.start_tensor_copy %[[A]] to L1 - // CHECK: %[[A2:.*]] = quidditch_snitch.wait_for_tensor_copy of %[[A]] to %[[A1]] using %[[TOKEN]] + // CHECK: %[[A2:.*]] = quidditch_snitch.wait_for_tensor_copy of %[[A]] + // CHECK-SAME: to %[[A1]] + // CEHCK-SAME: using %[[TOKEN]] // CHECK: %[[B1:.*]], %[[TOKEN:.*]] = quidditch_snitch.start_tensor_copy %[[B]] to L1 - // CHECK: %[[B2:.*]] = quidditch_snitch.wait_for_tensor_copy of %[[B]] to %[[B1]] using %[[TOKEN]] + // CHECK: %[[B2:.*]] = quidditch_snitch.wait_for_tensor_copy of %[[B]] + // CHECK-SAME: to %[[B1]] + // CHECK-SAME: using %[[TOKEN]] // CHECK: %[[E1:.*]], %[[TOKEN:.*]] = quidditch_snitch.start_tensor_copy %[[E]] to L1 - // CHECK: %[[E2:.*]] = quidditch_snitch.wait_for_tensor_copy of %[[E]] to %[[E1]] using %[[TOKEN]] + // CHECK: %[[E2:.*]] = quidditch_snitch.wait_for_tensor_copy of %[[E]] + // CHECK-SAME: to %[[E1]] + // CHECK-SAME: using %[[TOKEN]] // CHECK: linalg.matmul ins(%[[A2]], %[[B2]] : {{.*}}) outs(%[[E2]] : {{.*}}) %r = linalg.matmul ins(%a, %b : tensor<32x32xf32>, tensor<32x32xf32>) outs(%e : tensor<32x32xf32>) -> tensor<32x32xf32> return %r : tensor<32x32xf32>