Skip to content

Commit

Permalink
[LowerL1Allocations] Support non-identity strides in L1 `memref.alloc…
Browse files Browse the repository at this point in the history
…a` (#104)

The previous lowering would crash in this case because the `memref.view` op does
not support a result `memref` with a non-identity layout. This PR fixes the
lowering by first creating a `view` into a buffer with enough elements to
accommodate the given layout (i.e. including any padding implied by the strides)
and then `reinterpret_cast`ing it to the original layout of the `alloca`.
  • Loading branch information
zero9178 authored Jul 29, 2024
1 parent b8a674d commit 3ac4716
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,48 @@ void LowerL1Allocations::runOnOperation() {

auto byteShift =
builder.create<arith::ConstantIndexOp>(allocOp.getLoc(), offset);
// Get rid of the memory space at this point in the pipeline.
auto viewOp = builder.create<memref::ViewOp>(

// We do not support anything but a zero offset right now.
[[maybe_unused]] int64_t ignoredOffset;
SmallVector<int64_t> strides;
if (failed(getStridesAndOffset(memRefType, strides, ignoredOffset))) {
allocOp->emitOpError(
"Cannot lower MemRef in L1 memory with a non-strided layout");
signalPassFailure();
return;
}
if (ignoredOffset != 0) {
allocOp->emitOpError(
"Cannot lower MemRef in L1 memory with a non-zero offset");
signalPassFailure();
return;
}

// Compute how many elements we need to allocate to support the memory
// layout. This may contain padding elements due to the strides.
// Compute this via the linearized access of the last element + 1.
int64_t allocElements = 1;
for (auto [stride, shape] : llvm::zip_equal(strides, memRefType.getShape()))
allocElements += stride * (shape - 1);

// First, allocate one large contiguous element memref.
// Get rid of the memory space at this point as well.
Value view = builder.create<memref::ViewOp>(
allocOp.getLoc(),
MemRefType::get({allocElements}, memRefType.getElementType()), l1Memory,
byteShift,
/*sizes=*/ValueRange());

// Reinterpret cast the view with the actual shape and strides.
view = builder.create<memref::ReinterpretCastOp>(
allocOp.getLoc(),
MemRefType::get(memRefType.getShape(), memRefType.getElementType(),
memRefType.getLayout()),
l1Memory, byteShift,
/*sizes=*/ValueRange());
allocOp->replaceAllUsesWith(viewOp);
view, 0, memRefType.getShape(), strides);
allocOp.replaceAllUsesWith(view);

uint64_t memRefSize = llvm::divideCeil(bitWidth, 8);
for (uint64_t size : memRefType.getShape())
memRefSize *= size;
memRefSize *= allocElements;

offset += memRefSize;
if (offset >= l1MemoryBytes) {
Expand Down
20 changes: 14 additions & 6 deletions codegen/tests/Dialect/Snitch/Transforms/lower-l1-allocations.mlir
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
// RUN: quidditch-opt %s -p "builtin.module(func.func(quidditch-lower-l1-allocations))" | FileCheck %s

// CHECK-LABEL: @test()
// CHECK-SAME: -> (memref<32xf32>, memref<64xf64>)
// CHECK-SAME: -> (memref<32xf32>, memref<2x64xf64, strided<[65, 1]>>)
func.func @test() -> (memref<32xf32, #quidditch_snitch.l1_encoding>,
memref<64xf64, #quidditch_snitch.l1_encoding>) {
memref<2x64xf64, strided<[65, 1]>, #quidditch_snitch.l1_encoding>) {
// CHECK: %[[VIEW:.*]] = quidditch_snitch.l1_memory_view
// CHECK: %[[OFFSET:.*]] = arith.constant 0
// CHECK: %[[ALLOCA0:.*]] = memref.view %[[VIEW]][%[[OFFSET]]][] : memref<{{.*}}xi8> to memref<32xf32>
// CHECK: %[[VIEW0:.*]] = memref.view %[[VIEW]][%[[OFFSET]]][] : memref<{{.*}}xi8> to memref<32xf32>
// CHECK: %[[ALLOCA0:.*]] = memref.reinterpret_cast %[[VIEW0]]
// CHECK-SAME: offset: [0]
// CHECK-SAME: sizes: [32]
// CHECK-SAME: strides: [1]
%0 = memref.alloca() : memref<32xf32, #quidditch_snitch.l1_encoding>
// CHECK: %[[OFFSET:.*]] = arith.constant 128
// CHECK: %[[ALLOCA1:.*]] = memref.view %[[VIEW]][%[[OFFSET]]][] : memref<{{.*}}xi8> to memref<64xf64>
%1 = memref.alloca() : memref<64xf64, #quidditch_snitch.l1_encoding>
// CHECK: %[[VIEW1:.*]] = memref.view %[[VIEW]][%[[OFFSET]]][] : memref<{{.*}}xi8> to memref<129xf64>
// CHECK: %[[ALLOCA1:.*]] = memref.reinterpret_cast %[[VIEW1]]
// CHECK-SAME: offset: [0]
// CHECK-SAME: sizes: [2, 64]
// CHECK-SAME: strides: [65, 1]
%1 = memref.alloca() : memref<2x64xf64, strided<[65, 1]>, #quidditch_snitch.l1_encoding>
// CHECK: return %[[ALLOCA0]], %[[ALLOCA1]]
return %0, %1 : memref<32xf32, #quidditch_snitch.l1_encoding>, memref<64xf64, #quidditch_snitch.l1_encoding>
return %0, %1 : memref<32xf32, #quidditch_snitch.l1_encoding>, memref<2x64xf64, strided<[65, 1]>, #quidditch_snitch.l1_encoding>
}

0 comments on commit 3ac4716

Please sign in to comment.