From 3f9391d9d4475efe2517604964dcc835f8112ca9 Mon Sep 17 00:00:00 2001 From: Hanchen Ye Date: Mon, 19 Feb 2024 16:07:04 -0600 Subject: [PATCH] Implement scalarization of StreamRead/WriteOp and StreamExpand/CollapseShapeOp --- .../HLS/Transforms/MaterializeStream.cpp | 12 +-- .../HLS/Transforms/ScalarizeStream.cpp | 102 ++++++++++++++++-- 2 files changed, 101 insertions(+), 13 deletions(-) diff --git a/lib/Dialect/HLS/Transforms/MaterializeStream.cpp b/lib/Dialect/HLS/Transforms/MaterializeStream.cpp index 7637cf3f..5d9b7338 100644 --- a/lib/Dialect/HLS/Transforms/MaterializeStream.cpp +++ b/lib/Dialect/HLS/Transforms/MaterializeStream.cpp @@ -59,16 +59,16 @@ getSliceInfo(ArrayRef ivs, ArrayRef indexExprs, /// shape of the packed tensor is (1, 1, 1, d0, d1, d2), which means the /// reassociation indices list is [[0, 1, 2, 3], [4], [5]]. static SmallVector getPackingReassociation(int64_t rank) { - SmallVector reassociations; + SmallVector reassociation; for (int64_t i = 0; i < rank; i++) { - ReassociationIndices reassociation; + ReassociationIndices reassociationIndices; if (i == 0) - reassociation = + reassociationIndices = llvm::map_to_vector(llvm::seq(rank), [&](int64_t j) { return j; }); - reassociation.push_back(rank + i); - reassociations.push_back(reassociation); + reassociationIndices.push_back(rank + i); + reassociation.push_back(reassociationIndices); } - return reassociations; + return reassociation; } /// Extract a slice from the tensor and write to the stream channel. If diff --git a/lib/Dialect/HLS/Transforms/ScalarizeStream.cpp b/lib/Dialect/HLS/Transforms/ScalarizeStream.cpp index c9037d7e..fcd56e08 100644 --- a/lib/Dialect/HLS/Transforms/ScalarizeStream.cpp +++ b/lib/Dialect/HLS/Transforms/ScalarizeStream.cpp @@ -44,11 +44,10 @@ struct ScalarizeStreamOp : public OpRewritePattern { LogicalResult matchAndRewrite(hls::StreamOp channel, PatternRewriter &rewriter) const override { auto streamType = channel.getType(); - auto scalarStreamType = getScalarStreamType(streamType); - if (streamType == scalarStreamType) + if (!streamType.hasShapedElementType()) return failure(); - channel.getResult().setType(scalarStreamType); + channel.getResult().setType(getScalarStreamType(streamType)); rewriter.setInsertionPointAfter(channel); auto cast = rewriter.create(channel.getLoc(), streamType, channel); @@ -64,7 +63,30 @@ struct ScalarizeStreamReadOp : public OpRewritePattern { LogicalResult matchAndRewrite(hls::StreamReadOp read, PatternRewriter &rewriter) const override { - return failure(); + auto streamType = read.getChannel().getType(); + if (!streamType.hasShapedElementType()) + return failure(); + + auto loc = read.getLoc(); + rewriter.setInsertionPointAfterValue(read.getChannel()); + auto cast = rewriter.create( + loc, getScalarStreamType(streamType), read.getChannel()); + + rewriter.setInsertionPoint(read); + auto elementType = streamType.getShapedElementType(); + auto init = rewriter.create(loc, elementType); + auto [ivs, result, iterArg] = constructLoops( + elementType.getShape(), SmallVector(elementType.getRank(), 1), + loc, rewriter, init); + + auto scalarRead = rewriter.create( + loc, elementType.getElementType(), cast); + auto insert = rewriter.create(loc, scalarRead.getResult(), + iterArg, ivs); + rewriter.create(loc, insert.getResult()); + + rewriter.replaceOp(read, result); + return success(); } }; } // namespace @@ -75,11 +97,47 @@ struct ScalarizeStreamWriteOp : public OpRewritePattern { LogicalResult matchAndRewrite(hls::StreamWriteOp write, PatternRewriter &rewriter) const override { - return failure(); + auto streamType = write.getChannel().getType(); + if (!streamType.hasShapedElementType()) + return failure(); + + auto loc = write.getLoc(); + rewriter.setInsertionPointAfterValue(write.getChannel()); + auto cast = rewriter.create( + loc, getScalarStreamType(streamType), write.getChannel()); + + rewriter.setInsertionPoint(write); + auto elementType = streamType.getShapedElementType(); + auto [ivs, result, iterArg] = constructLoops( + elementType.getShape(), SmallVector(elementType.getRank(), 1), + loc, rewriter); + + auto extract = + rewriter.create(loc, write.getValue(), ivs); + rewriter.create(loc, cast, extract.getResult()); + + rewriter.eraseOp(write); + return success(); } }; } // namespace +static SmallVector +getScalarReassociation(ArrayRef reassociation, + PatternRewriter &rewriter) { + SmallVector scalarReassociation(reassociation); + auto rank = reassociation.back().back() + 1; + for (auto indices : reassociation) { + ReassociationIndices scalarReassociationIndices; + for (auto index : indices) + scalarReassociationIndices.push_back(index + rank); + scalarReassociation.push_back(scalarReassociationIndices); + } + return llvm::map_to_vector(scalarReassociation, [&](auto indices) { + return Attribute(rewriter.getI64ArrayAttr(indices)); + }); +} + namespace { struct ScalarizeStreamExpandShapeOp : public OpRewritePattern { @@ -87,7 +145,22 @@ struct ScalarizeStreamExpandShapeOp LogicalResult matchAndRewrite(hls::StreamExpandShapeOp expandShape, PatternRewriter &rewriter) const override { - return failure(); + auto inputType = expandShape.getInput().getType(); + auto outputType = expandShape.getOutput().getType(); + if (!inputType.hasShapedElementType() || !outputType.hasShapedElementType()) + return failure(); + + auto loc = expandShape.getLoc(); + auto inputCast = rewriter.create( + loc, getScalarStreamType(inputType), expandShape.getInput()); + + auto scalarReassociation = rewriter.getArrayAttr(getScalarReassociation( + expandShape.getReassociationIndices(), rewriter)); + auto scalarExpandShape = rewriter.create( + loc, getScalarStreamType(outputType), inputCast, scalarReassociation); + rewriter.replaceOpWithNewOp(expandShape, outputType, + scalarExpandShape); + return success(); } }; } // namespace @@ -99,7 +172,22 @@ struct ScalarizeStreamCollapseShapeOp LogicalResult matchAndRewrite(hls::StreamCollapseShapeOp collapseShape, PatternRewriter &rewriter) const override { - return failure(); + auto inputType = collapseShape.getInput().getType(); + auto outputType = collapseShape.getOutput().getType(); + if (!inputType.hasShapedElementType() || !outputType.hasShapedElementType()) + return failure(); + + auto loc = collapseShape.getLoc(); + auto inputCast = rewriter.create( + loc, getScalarStreamType(inputType), collapseShape.getInput()); + + auto scalarReassociation = rewriter.getArrayAttr(getScalarReassociation( + collapseShape.getReassociationIndices(), rewriter)); + auto scalarCollapseShape = rewriter.create( + loc, getScalarStreamType(outputType), inputCast, scalarReassociation); + rewriter.replaceOpWithNewOp(collapseShape, outputType, + scalarCollapseShape); + return success(); } }; } // namespace