Skip to content

Commit

Permalink
Implement scalarization of StreamRead/WriteOp and StreamExpand/Collap…
Browse files Browse the repository at this point in the history
…seShapeOp
  • Loading branch information
hanchenye committed Feb 19, 2024
1 parent 18d8423 commit 3f9391d
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 13 deletions.
12 changes: 6 additions & 6 deletions lib/Dialect/HLS/Transforms/MaterializeStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,16 @@ getSliceInfo(ArrayRef<Value> ivs, ArrayRef<AffineExpr> indexExprs,
/// shape of the packed tensor is (1, 1, 1, d0, d1, d2), which means the
/// reassociation indices list is [[0, 1, 2, 3], [4], [5]].
static SmallVector<ReassociationIndices> getPackingReassociation(int64_t rank) {
SmallVector<ReassociationIndices> reassociations;
SmallVector<ReassociationIndices> reassociation;
for (int64_t i = 0; i < rank; i++) {
ReassociationIndices reassociation;
ReassociationIndices reassociationIndices;
if (i == 0)
reassociation =
reassociationIndices =
llvm::map_to_vector(llvm::seq(rank), [&](int64_t j) { return j; });
reassociation.push_back(rank + i);
reassociations.push_back(reassociation);
reassociationIndices.push_back(rank + i);
reassociation.push_back(reassociationIndices);
}
return reassociations;
return reassociation;
}

/// Extract a slice from the tensor and write to the stream channel. If
Expand Down
102 changes: 95 additions & 7 deletions lib/Dialect/HLS/Transforms/ScalarizeStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,10 @@ struct ScalarizeStreamOp : public OpRewritePattern<hls::StreamOp> {
LogicalResult matchAndRewrite(hls::StreamOp channel,
PatternRewriter &rewriter) const override {
auto streamType = channel.getType();
auto scalarStreamType = getScalarStreamType(streamType);
if (streamType == scalarStreamType)
if (!streamType.hasShapedElementType())
return failure();

channel.getResult().setType(scalarStreamType);
channel.getResult().setType(getScalarStreamType(streamType));
rewriter.setInsertionPointAfter(channel);
auto cast = rewriter.create<hls::StreamCastOp>(channel.getLoc(), streamType,
channel);
Expand All @@ -64,7 +63,30 @@ struct ScalarizeStreamReadOp : public OpRewritePattern<hls::StreamReadOp> {

LogicalResult matchAndRewrite(hls::StreamReadOp read,
PatternRewriter &rewriter) const override {
return failure();
auto streamType = read.getChannel().getType();
if (!streamType.hasShapedElementType())
return failure();

auto loc = read.getLoc();
rewriter.setInsertionPointAfterValue(read.getChannel());
auto cast = rewriter.create<hls::StreamCastOp>(
loc, getScalarStreamType(streamType), read.getChannel());

rewriter.setInsertionPoint(read);
auto elementType = streamType.getShapedElementType();
auto init = rewriter.create<hls::TensorInitOp>(loc, elementType);
auto [ivs, result, iterArg] = constructLoops(
elementType.getShape(), SmallVector<int64_t>(elementType.getRank(), 1),
loc, rewriter, init);

auto scalarRead = rewriter.create<hls::StreamReadOp>(
loc, elementType.getElementType(), cast);
auto insert = rewriter.create<tensor::InsertOp>(loc, scalarRead.getResult(),
iterArg, ivs);
rewriter.create<scf::YieldOp>(loc, insert.getResult());

rewriter.replaceOp(read, result);
return success();
}
};
} // namespace
Expand All @@ -75,19 +97,70 @@ struct ScalarizeStreamWriteOp : public OpRewritePattern<hls::StreamWriteOp> {

LogicalResult matchAndRewrite(hls::StreamWriteOp write,
PatternRewriter &rewriter) const override {
return failure();
auto streamType = write.getChannel().getType();
if (!streamType.hasShapedElementType())
return failure();

auto loc = write.getLoc();
rewriter.setInsertionPointAfterValue(write.getChannel());
auto cast = rewriter.create<hls::StreamCastOp>(
loc, getScalarStreamType(streamType), write.getChannel());

rewriter.setInsertionPoint(write);
auto elementType = streamType.getShapedElementType();
auto [ivs, result, iterArg] = constructLoops(
elementType.getShape(), SmallVector<int64_t>(elementType.getRank(), 1),
loc, rewriter);

auto extract =
rewriter.create<tensor::ExtractOp>(loc, write.getValue(), ivs);
rewriter.create<hls::StreamWriteOp>(loc, cast, extract.getResult());

rewriter.eraseOp(write);
return success();
}
};
} // namespace

static SmallVector<Attribute>
getScalarReassociation(ArrayRef<ReassociationIndices> reassociation,
PatternRewriter &rewriter) {
SmallVector<ReassociationIndices> scalarReassociation(reassociation);
auto rank = reassociation.back().back() + 1;
for (auto indices : reassociation) {
ReassociationIndices scalarReassociationIndices;
for (auto index : indices)
scalarReassociationIndices.push_back(index + rank);
scalarReassociation.push_back(scalarReassociationIndices);
}
return llvm::map_to_vector(scalarReassociation, [&](auto indices) {
return Attribute(rewriter.getI64ArrayAttr(indices));
});
}

namespace {
struct ScalarizeStreamExpandShapeOp
: public OpRewritePattern<hls::StreamExpandShapeOp> {
using OpRewritePattern<hls::StreamExpandShapeOp>::OpRewritePattern;

LogicalResult matchAndRewrite(hls::StreamExpandShapeOp expandShape,
PatternRewriter &rewriter) const override {
return failure();
auto inputType = expandShape.getInput().getType();
auto outputType = expandShape.getOutput().getType();
if (!inputType.hasShapedElementType() || !outputType.hasShapedElementType())
return failure();

auto loc = expandShape.getLoc();
auto inputCast = rewriter.create<hls::StreamCastOp>(
loc, getScalarStreamType(inputType), expandShape.getInput());

auto scalarReassociation = rewriter.getArrayAttr(getScalarReassociation(
expandShape.getReassociationIndices(), rewriter));
auto scalarExpandShape = rewriter.create<hls::StreamExpandShapeOp>(
loc, getScalarStreamType(outputType), inputCast, scalarReassociation);
rewriter.replaceOpWithNewOp<hls::StreamCastOp>(expandShape, outputType,
scalarExpandShape);
return success();
}
};
} // namespace
Expand All @@ -99,7 +172,22 @@ struct ScalarizeStreamCollapseShapeOp

LogicalResult matchAndRewrite(hls::StreamCollapseShapeOp collapseShape,
PatternRewriter &rewriter) const override {
return failure();
auto inputType = collapseShape.getInput().getType();
auto outputType = collapseShape.getOutput().getType();
if (!inputType.hasShapedElementType() || !outputType.hasShapedElementType())
return failure();

auto loc = collapseShape.getLoc();
auto inputCast = rewriter.create<hls::StreamCastOp>(
loc, getScalarStreamType(inputType), collapseShape.getInput());

auto scalarReassociation = rewriter.getArrayAttr(getScalarReassociation(
collapseShape.getReassociationIndices(), rewriter));
auto scalarCollapseShape = rewriter.create<hls::StreamCollapseShapeOp>(
loc, getScalarStreamType(outputType), inputCast, scalarReassociation);
rewriter.replaceOpWithNewOp<hls::StreamCastOp>(collapseShape, outputType,
scalarCollapseShape);
return success();
}
};
} // namespace
Expand Down

0 comments on commit 3f9391d

Please sign in to comment.