From 3ac471649acbe49822ccb9071220bffd3e15c519 Mon Sep 17 00:00:00 2001
From: Markus Böck
Date: Mon, 29 Jul 2024 16:28:39 +0200
Subject: [PATCH] [LowerL1Allocations] Support non-identity strides in L1
 `memref.alloca` (#104)

The previous lowering crashed in this case because the `memref.view` op
does not support a non-identity layout in its result `memref` type.

This PR fixes the lowering by first creating a `view` into a buffer with
enough elements for the given layout (i.e. including any padding implied
by the strides) and then `reinterpret_cast`ing that buffer to the
original layout of the `alloca`.
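As a worked example (a sketch only; the SSA names, the `2048`-byte buffer
size, and the constant offset below are schematic, not literal pass
output): an `alloca` of `memref<2x64xf64, strided<[65, 1]>>` needs
`1 + 65 * (2 - 1) + 1 * (64 - 1) = 129` elements, so the pass now emits
IR along these lines:

```mlir
// Schematic lowering of one strided L1 alloca. %l1 stands in for the
// result of quidditch_snitch.l1_memory_view; 2048 is a placeholder size.
%offset = arith.constant 0 : index
// Contiguous view with enough f64 elements to cover the strided layout,
// i.e. the linearized index of the last element plus one: 129.
%view = memref.view %l1[%offset][] : memref<2048xi8> to memref<129xf64>
// Re-impose the requested shape and strides on the contiguous buffer.
%cast = memref.reinterpret_cast %view to
          offset: [0], sizes: [2, 64], strides: [65, 1]
        : memref<129xf64> to memref<2x64xf64, strided<[65, 1]>>
```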
---
 .../Snitch/Transforms/LowerL1Allocations.cpp | 44 ++++++++++++++++---
 .../Transforms/lower-l1-allocations.mlir     | 20 ++++++---
 2 files changed, 51 insertions(+), 13 deletions(-)

diff --git a/codegen/compiler/src/Quidditch/Dialect/Snitch/Transforms/LowerL1Allocations.cpp b/codegen/compiler/src/Quidditch/Dialect/Snitch/Transforms/LowerL1Allocations.cpp
index e799844..61f61f3 100644
--- a/codegen/compiler/src/Quidditch/Dialect/Snitch/Transforms/LowerL1Allocations.cpp
+++ b/codegen/compiler/src/Quidditch/Dialect/Snitch/Transforms/LowerL1Allocations.cpp
@@ -63,18 +63,48 @@ void LowerL1Allocations::runOnOperation() {
 
     auto byteShift =
         builder.create<arith::ConstantIndexOp>(allocOp.getLoc(), offset);
-    // Get rid of the memory space at this point in the pipeline.
-    auto viewOp = builder.create<memref::ViewOp>(
+
+    // We only support strided layouts with a zero offset for now.
+    int64_t layoutOffset;
+    SmallVector<int64_t> strides;
+    if (failed(getStridesAndOffset(memRefType, strides, layoutOffset))) {
+      allocOp->emitOpError(
+          "Cannot lower MemRef in L1 memory with a non-strided layout");
+      signalPassFailure();
+      return;
+    }
+    if (layoutOffset != 0) {
+      allocOp->emitOpError(
+          "Cannot lower MemRef in L1 memory with a non-zero offset");
+      signalPassFailure();
+      return;
+    }
+
+    // Compute how many elements we need to allocate to support the memory
+    // layout. This may include padding elements due to the strides.
+    // Compute this as the linearized index of the last element plus one.
+    int64_t allocElements = 1;
+    for (auto [stride, shape] :
+         llvm::zip_equal(strides, memRefType.getShape()))
+      allocElements += stride * (shape - 1);
+
+    // First, allocate one large contiguous memref of elements.
+    // Get rid of the memory space at this point as well.
+    Value view = builder.create<memref::ViewOp>(
+        allocOp.getLoc(),
+        MemRefType::get({allocElements}, memRefType.getElementType()),
+        l1Memory, byteShift, /*sizes=*/ValueRange());
+
+    // Reinterpret-cast the view to the actual shape and strides.
+    view = builder.create<memref::ReinterpretCastOp>(
         allocOp.getLoc(),
         MemRefType::get(memRefType.getShape(), memRefType.getElementType(),
                         memRefType.getLayout()),
-        l1Memory, byteShift,
-        /*sizes=*/ValueRange());
-    allocOp->replaceAllUsesWith(viewOp);
+        view, 0, memRefType.getShape(), strides);
+    allocOp->replaceAllUsesWith(view);
 
     uint64_t memRefSize = llvm::divideCeil(bitWidth, 8);
-    for (uint64_t size : memRefType.getShape())
-      memRefSize *= size;
+    memRefSize *= allocElements;
     offset += memRefSize;
 
     if (offset >= l1MemoryBytes) {
diff --git a/codegen/tests/Dialect/Snitch/Transforms/lower-l1-allocations.mlir b/codegen/tests/Dialect/Snitch/Transforms/lower-l1-allocations.mlir
index abadb01..7d4097d 100644
--- a/codegen/tests/Dialect/Snitch/Transforms/lower-l1-allocations.mlir
+++ b/codegen/tests/Dialect/Snitch/Transforms/lower-l1-allocations.mlir
@@ -1,16 +1,24 @@
 // RUN: quidditch-opt %s -p "builtin.module(func.func(quidditch-lower-l1-allocations))" | FileCheck %s
 
 // CHECK-LABEL: @test()
-// CHECK-SAME: -> (memref<32xf32>, memref<64xf64>)
+// CHECK-SAME: -> (memref<32xf32>, memref<2x64xf64, strided<[65, 1]>>)
 func.func @test() -> (memref<32xf32, #quidditch_snitch.l1_encoding>,
-                      memref<64xf64, #quidditch_snitch.l1_encoding>) {
+                      memref<2x64xf64, strided<[65, 1]>, #quidditch_snitch.l1_encoding>) {
   // CHECK: %[[VIEW:.*]] = quidditch_snitch.l1_memory_view
   // CHECK: %[[OFFSET:.*]] = arith.constant 0
-  // CHECK: %[[ALLOCA0:.*]] = memref.view %[[VIEW]][%[[OFFSET]]][] : memref<{{.*}}xi8> to memref<32xf32>
+  // CHECK: %[[VIEW0:.*]] = memref.view %[[VIEW]][%[[OFFSET]]][] : memref<{{.*}}xi8> to memref<32xf32>
+  // CHECK: %[[ALLOCA0:.*]] = memref.reinterpret_cast %[[VIEW0]]
+  // CHECK-SAME: offset: [0]
+  // CHECK-SAME: sizes: [32]
+  // CHECK-SAME: strides: [1]
   %0 = memref.alloca() : memref<32xf32, #quidditch_snitch.l1_encoding>
   // CHECK: %[[OFFSET:.*]] = arith.constant 128
-  // CHECK: %[[ALLOCA1:.*]] = memref.view %[[VIEW]][%[[OFFSET]]][] : memref<{{.*}}xi8> to memref<64xf64>
-  %1 = memref.alloca() : memref<64xf64, #quidditch_snitch.l1_encoding>
+  // CHECK: %[[VIEW1:.*]] = memref.view %[[VIEW]][%[[OFFSET]]][] : memref<{{.*}}xi8> to memref<129xf64>
+  // CHECK: %[[ALLOCA1:.*]] = memref.reinterpret_cast %[[VIEW1]]
+  // CHECK-SAME: offset: [0]
+  // CHECK-SAME: sizes: [2, 64]
+  // CHECK-SAME: strides: [65, 1]
+  %1 = memref.alloca() : memref<2x64xf64, strided<[65, 1]>, #quidditch_snitch.l1_encoding>
   // CHECK: return %[[ALLOCA0]], %[[ALLOCA1]]
-  return %0, %1 : memref<32xf32, #quidditch_snitch.l1_encoding>, memref<64xf64, #quidditch_snitch.l1_encoding>
+  return %0, %1 : memref<32xf32, #quidditch_snitch.l1_encoding>, memref<2x64xf64, strided<[65, 1]>, #quidditch_snitch.l1_encoding>
 }