From 8783e27e9c78b80d969d58d49242e9dd8b65967e Mon Sep 17 00:00:00 2001
From: Hanchen Ye
Date: Mon, 19 Feb 2024 22:09:31 -0600
Subject: [PATCH] Update op tablegen definition structure

---
 .../scalehls/Dialect/HLS/IR/HLSInterfaces.td  |  12 +-
 include/scalehls/Dialect/HLS/IR/HLSOps.td     |  93 ++++++++-------
 lib/Dialect/HLS/IR/HLSOps.cpp                 | 111 ++++++++++++------
 .../HLS/Transforms/MaterializeStream.cpp      |  18 +--
 4 files changed, 144 insertions(+), 90 deletions(-)

diff --git a/include/scalehls/Dialect/HLS/IR/HLSInterfaces.td b/include/scalehls/Dialect/HLS/IR/HLSInterfaces.td
index 2ab8869b..72a56be9 100644
--- a/include/scalehls/Dialect/HLS/IR/HLSInterfaces.td
+++ b/include/scalehls/Dialect/HLS/IR/HLSInterfaces.td
@@ -16,13 +16,21 @@ def StreamViewLikeInterface : OpInterface<"StreamViewLikeInterface"> {
   string cppNamespace = "mlir::scalehls::hls";
 
   let methods = [
-    InterfaceMethod<"Return the source of the stream view",
+    InterfaceMethod<"Return the input stream",
      "TypedValue<StreamType>", "getInput", (ins), [{
        return $_op.getInput();
      }]>,
-    InterfaceMethod<"Return the result of the stream view",
+    InterfaceMethod<"Return the output stream",
      "TypedValue<StreamType>", "getOutput", (ins), [{
        return $_op.getOutput();
+     }]>,
+    InterfaceMethod<"Return the input stream type",
+     "StreamType", "getInputType", (ins), [{
+       return $_op.getInput().getType();
+     }]>,
+    InterfaceMethod<"Return the output stream type",
+     "StreamType", "getOutputType", (ins), [{
+       return $_op.getOutput().getType();
      }]>
   ];
 }
diff --git a/include/scalehls/Dialect/HLS/IR/HLSOps.td b/include/scalehls/Dialect/HLS/IR/HLSOps.td
index 00146fe9..bd1b1eca 100644
--- a/include/scalehls/Dialect/HLS/IR/HLSOps.td
+++ b/include/scalehls/Dialect/HLS/IR/HLSOps.td
@@ -118,7 +118,7 @@ def YieldOp : HLSOp<"yield", [NoMemoryEffect, ReturnLike, Terminator,
 // Stream Operations
 //===----------------------------------------------------------------------===//
 
-def TensorInitOp : HLSOp<"tensor_init", [NoMemoryEffect]> {
+def TensorInitOp : HLSOp<"tensor_init", [Pure]> {
   let summary = "Initiate a tensor with an optional initialization value";
 
   let arguments = (ins Optional<AnyType>:$init_value);
@@ -134,7 +134,7 @@ def TensorInitOp : HLSOp<"tensor_init", [NoMemoryEffect]> {
   ];
 }
 
-def TensorToStreamOp : HLSOp<"tensor_to_stream",[NoMemoryEffect]> {
+def TensorToStreamOp : HLSOp<"tensor_to_stream",[Pure]> {
   let summary = "Convert a tensor to a stream channel";
 
   let arguments = (ins AnyRankedTensor:$tensor);
@@ -147,7 +147,7 @@ def TensorToStreamOp : HLSOp<"tensor_to_stream",[NoMemoryEffect]> {
   let hasFolder = 1;
 }
 
-def StreamToTensorOp : HLSOp<"stream_to_tensor", [NoMemoryEffect]> {
+def StreamToTensorOp : HLSOp<"stream_to_tensor", [Pure]> {
   let summary = "Convert a stream channel to a tensor";
 
   let arguments = (ins AnyStream:$stream);
@@ -196,26 +196,19 @@ def StreamWriteOp : HLSOp<"stream_write", [
   let arguments = (ins AnyStream:$channel, AnyType:$value);
 
   let assemblyFormat = [{
-    $channel `,` $value attr-dict `:` functional-type($value, $channel)
+    $value `to` $channel attr-dict `:` functional-type($value, $channel)
   }];
 
   let hasVerifier = 1;
 }
 
-def StreamExpandShapeOp : HLSOp<"stream_expand_shape", [NoMemoryEffect,
-    DeclareOpInterfaceMethods<StreamViewLikeInterface>]> {
-  let summary = "Expand the shape of the iteration space";
+class StreamReassociativeOp<string mnemonic, list<Trait> traits = []> :
+    HLSOp<mnemonic, !listconcat(traits, [Pure,
+      DeclareOpInterfaceMethods<StreamViewLikeInterface>])>,
+    Arguments<(ins AnyStream:$input, IndexListArrayAttr:$reassociation)>,
+    Results<(outs AnyStream:$output)> {
 
-  let arguments = (ins AnyStream:$input, IndexListArrayAttr:$reassociation);
-  let results = (outs AnyStream:$output);
-  let assemblyFormat = [{
-    $input `,` $reassociation attr-dict `:` functional-type($input, $output)
-  }];
-
-  let hasVerifier = 1;
-  let hasFolder = 1;
-
-  let extraClassDeclaration = [{
+  code commonExtraClassDeclaration = [{
     SmallVector<ReassociationIndices> getReassociationIndices() {
       SmallVector<ReassociationIndices> reassociationIndices;
       for (auto attr : getReassociation())
@@ -225,53 +218,62 @@ def StreamExpandShapeOp : HLSOp<"stream_expand_shape", [NoMemoryEffect,
         })));
       return reassociationIndices;
     }
-  }];
-}
-
-def StreamCollapseShapeOp : HLSOp<"stream_collapse_shape", [NoMemoryEffect,
-    DeclareOpInterfaceMethods<StreamViewLikeInterface>]> {
-  let summary = "Collapse the shape of the iteration space";
+
+    StreamType getInputType() { return getInput().getType(); }
+    StreamType getOutputType() { return getOutput().getType(); }
+  }];
 
-  let arguments = (ins AnyStream:$input, IndexListArrayAttr:$reassociation);
-  let results = (outs AnyStream:$output);
   let assemblyFormat = [{
-    $input `,` $reassociation attr-dict `:` functional-type($input, $output)
+    $input $reassociation attr-dict `:` functional-type($input, $output)
   }];
 
   let hasVerifier = 1;
   let hasFolder = 1;
+}
 
-  let extraClassDeclaration = [{
-    SmallVector<ReassociationIndices> getReassociationIndices() {
-      SmallVector<ReassociationIndices> reassociationIndices;
-      for (auto attr : getReassociation())
-        reassociationIndices.push_back(llvm::to_vector<2>(
-            llvm::map_range(::llvm::cast<ArrayAttr>(attr), [&](Attribute indexAttr) {
-              return ::llvm::cast<IntegerAttr>(indexAttr).getInt();
-            })));
-      return reassociationIndices;
-    }
-  }];
+def StreamSplitIterationOp : StreamReassociativeOp<"stream_split_iteration"> {
+  let summary = "Split the iteration space of a stream channel";
+  let extraClassDeclaration = commonExtraClassDeclaration;
+}
+
+def StreamMergeIterationOp : StreamReassociativeOp<"stream_merge_iteration"> {
+  let summary = "Merge the iteration space of a stream channel";
+  let extraClassDeclaration = commonExtraClassDeclaration;
+}
+
+def StreamExpandShapeOp : StreamReassociativeOp<"stream_expand_shape"> {
+  let summary = "Expand the shape of the stream element";
+  let extraClassDeclaration = commonExtraClassDeclaration;
+}
+
+def StreamCollapseShapeOp : StreamReassociativeOp<"stream_collapse_shape"> {
+  let summary = "Collapse the shape of the stream element";
+  let extraClassDeclaration = commonExtraClassDeclaration;
 }
 
-def StreamBufferOp : HLSOp<"stream_buffer", [NoMemoryEffect,
-    DeclareOpInterfaceMethods<StreamViewLikeInterface>]> {
+def StreamBufferOp : HLSOp<"stream_buffer", [Pure,
+    DeclareOpInterfaceMethods<StreamViewLikeInterface>]> {
   let summary = "Buffer a stream channel at a specific position";
 
   let arguments = (ins AnyStream:$input, TypeAttr:$bufferElementType,
-    DenseI64ArrayAttr:$bufferShape, I64Attr:$beforeLoop, I64Attr:$beforeDim);
+    DenseI64ArrayAttr:$bufferShape, I64Attr:$loopIndex, I64Attr:$dimIndex);
   let results = (outs AnyStream:$output);
 
   let assemblyFormat = [{
-    $input `,` $bufferElementType $bufferShape `loop` $beforeLoop `dim`
-    $beforeDim attr-dict `:` functional-type($input, $output)
+    $input `,` $bufferElementType $bufferShape `before` `loop` $loopIndex `dim`
+    $dimIndex attr-dict `:` functional-type($input, $output)
   }];
 
   let hasVerifier = 1;
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    StreamType getInputType() { return getInput().getType(); }
+    StreamType getOutputType() { return getOutput().getType(); }
+  }];
 }
 
-def StreamCastOp : HLSOp<"stream_cast", [NoMemoryEffect,
-    DeclareOpInterfaceMethods<StreamViewLikeInterface>]> {
+def StreamCastOp : HLSOp<"stream_cast", [Pure,
+    DeclareOpInterfaceMethods<StreamViewLikeInterface>]> {
   let summary = "Cast a stream channel to a different type";
 
   let arguments = (ins AnyStream:$input);
@@ -282,6 +284,11 @@ def StreamCastOp : HLSOp<"stream_cast", [NoMemoryEffect,
 
   let hasVerifier = 1;
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    StreamType getInputType() { return getInput().getType(); }
+    StreamType getOutputType() { return getOutput().getType(); }
+  }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/HLS/IR/HLSOps.cpp b/lib/Dialect/HLS/IR/HLSOps.cpp
index ac66dcce..317d5161 100644
--- a/lib/Dialect/HLS/IR/HLSOps.cpp
+++ b/lib/Dialect/HLS/IR/HLSOps.cpp
@@ -316,12 +316,13 @@ void StreamWriteOp::getEffects(
 }
 
 //===----------------------------------------------------------------------===//
-// StreamExpandShapeOp
+// StreamSplitIterationOp
 //===----------------------------------------------------------------------===//
 
 static LogicalResult
-verifyReassociation(SmallVectorImpl<ReassociationIndices> &reassociation,
-                    StreamType lowType, StreamType highType, Operation *op) {
+verifyIterationReassociation(ArrayRef<ReassociationIndices> reassociation,
+                             StreamType lowType, StreamType highType,
+                             Operation *op) {
   if (reassociation.size() != lowType.getIterTripCounts().size())
     return op->emitOpError("reassociation size doesn't align with input type");
 
@@ -338,43 +339,83 @@ verifyReassociation(SmallVectorImpl<ReassociationIndices> &reassociation,
       return op->emitOpError("reassociation doesn't align with input/output "
                              "iteration trip counts or steps");
   }
-
-  unsigned highIndex = 0;
-  for (auto lowExpr : lowType.getIterMap().getResults()) {
-    if (auto lowDimExpr = dyn_cast<AffineDimExpr>(lowExpr)) {
-      auto indices = reassociation[lowDimExpr.getPosition()];
-      for (auto index : indices) {
-        auto highDimExpr =
-            dyn_cast<AffineDimExpr>(highType.getIterMap().getResult(highIndex));
-        if (!highDimExpr || highDimExpr.getPosition() != index)
-          return op->emitOpError(
-              "reassociation doesn't align with input/output iteration maps");
-        highIndex++;
-      }
-    } else
-      highIndex++;
-  }
   return success();
 }
 
-LogicalResult StreamExpandShapeOp::verify() {
-  auto inputType = getInput().getType();
-  auto outputType = getOutput().getType();
-  if (inputType.getDataType() != outputType.getDataType())
-    return emitOpError("input and output data type doesn't match");
-  auto reassociation = getReassociationIndices();
-  return verifyReassociation(reassociation, inputType, outputType, *this);
+LogicalResult StreamSplitIterationOp::verify() {
+  if (!getInputType().isCastableWith(getOutputType()))
+    return emitOpError("input and output are not castable");
+  return verifyIterationReassociation(getReassociationIndices(), getInputType(),
+                                      getOutputType(), *this);
 }
 
-OpFoldResult foldStreamViewLikeInterface(StreamViewLikeInterface op) {
+static OpFoldResult foldStreamViewLikeInterface(StreamViewLikeInterface op) {
   if (op.getInput().getType() == op.getOutput().getType())
     return op.getInput();
   if (auto prevView = op.getInput().getDefiningOp<StreamViewLikeInterface>())
-    if (prevView.getInput().getType() == op.getOutput().getType())
+    if (prevView.getInputType() == op.getOutput().getType())
       return prevView.getInput();
   return {};
 }
 
+OpFoldResult StreamSplitIterationOp::fold(FoldAdaptor adaptor) {
+  return foldStreamViewLikeInterface(*this);
+}
+
+//===----------------------------------------------------------------------===//
+// StreamMergeIterationOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult StreamMergeIterationOp::verify() {
+  if (!getInputType().isCastableWith(getOutputType()))
+    return emitOpError("input and output are not castable");
+  return verifyIterationReassociation(getReassociationIndices(), getInputType(),
+                                      getOutputType(), *this);
+}
+
+OpFoldResult StreamMergeIterationOp::fold(FoldAdaptor adaptor) {
+  return foldStreamViewLikeInterface(*this);
+}
+
+//===----------------------------------------------------------------------===//
+// StreamExpandShapeOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult
+verifyShapeReassociation(ArrayRef<ReassociationIndices> reassociation,
+                         StreamType lowType, StreamType highType,
+                         Operation *op) {
+  // if (lowType.getIterTripCounts() != highType.getIterTripCounts() ||
+  //     lowType.getIterSteps() != highType.getIterSteps())
+  //   return op->emitOpError("input and output iteration trip counts or steps "
+  //                          "doesn't match");
+
+  // unsigned highIndex = 0;
+  // for (auto lowExpr : lowType.getIterMap().getResults()) {
+  //   if (auto lowDimExpr = dyn_cast<AffineDimExpr>(lowExpr)) {
+  //     auto indices = reassociation[lowDimExpr.getPosition()];
+  //     for (auto index : indices) {
+  //       auto highDimExpr =
+  //           dyn_cast<AffineDimExpr>(highType.getIterMap().getResult(highIndex));
+  //       if (!highDimExpr || highDimExpr.getPosition() != index)
+  //         return op->emitOpError(
+  //             "reassociation doesn't align with input/output iteration
+  //             maps");
+  //       highIndex++;
+  //     }
+  //   } else
+  //     highIndex++;
+  // }
+  return success();
+}
+
+LogicalResult StreamExpandShapeOp::verify() {
+  if (getInputType().getDataType() != getOutputType().getDataType())
+    return emitOpError("input and output data type doesn't match");
+  return verifyShapeReassociation(getReassociationIndices(), getInputType(),
+                                  getOutputType(), *this);
+}
+
 OpFoldResult StreamExpandShapeOp::fold(FoldAdaptor adaptor) {
   return foldStreamViewLikeInterface(*this);
 }
@@ -384,12 +425,10 @@ OpFoldResult StreamExpandShapeOp::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult StreamCollapseShapeOp::verify() {
-  auto inputType = getInput().getType();
-  auto outputType = getOutput().getType();
-  if (inputType.getDataType() != outputType.getDataType())
+  if (getInputType().getDataType() != getOutputType().getDataType())
     return emitOpError("input and output data type doesn't match");
-  auto reassociation = getReassociationIndices();
-  return verifyReassociation(reassociation, outputType, inputType, *this);
+  return verifyShapeReassociation(getReassociationIndices(), getInputType(),
+                                  getOutputType(), *this);
 }
 
 OpFoldResult StreamCollapseShapeOp::fold(FoldAdaptor adaptor) {
@@ -406,14 +445,14 @@ LogicalResult StreamBufferOp::verify() {
   if (!inputType.isCastableWith(outputType))
     return emitOpError("input and output are not castable");
 
-  if (getBeforeLoop() > inputType.getIterTripCounts().size())
-    return emitOpError("buffer position is out of loop range");
+  if (getLoopIndex() > inputType.getIterTripCounts().size())
+    return emitOpError("buffer loop index is out of loop range");
 
   auto inputShape = inputType.getShape();
   for (auto [dim, bufferSize, dimSize, inputTileSize, outputTileSize] :
        llvm::zip(llvm::seq(inputShape.size()), getBufferShape(), inputShape,
                  inputType.getElementShape(), outputType.getElementShape())) {
-    if (dim < getBeforeDim()) {
+    if (dim < getDimIndex()) {
       if (inputTileSize != outputTileSize || bufferSize < inputTileSize)
         return emitOpError(
             "buffer size is smaller than input/output tile size");
diff --git a/lib/Dialect/HLS/Transforms/MaterializeStream.cpp b/lib/Dialect/HLS/Transforms/MaterializeStream.cpp
index 5d9b7338..b102bfe5 100644
--- a/lib/Dialect/HLS/Transforms/MaterializeStream.cpp
+++ b/lib/Dialect/HLS/Transforms/MaterializeStream.cpp
@@ -254,7 +254,7 @@ struct LowerStreamBufferOp : public OpRewritePattern<StreamBufferOp> {
                                 PatternRewriter &rewriter) const override {
     auto inputType = streamBuffer.getInput().getType();
     auto outputType = streamBuffer.getOutput().getType();
-    auto beforeLoop = streamBuffer.getBeforeLoop();
+    auto loopIndex = streamBuffer.getLoopIndex();
     auto loc = streamBuffer.getLoc();
 
     // Construct the output stream channel.
@@ -263,8 +263,8 @@ struct LowerStreamBufferOp : public OpRewritePattern<StreamBufferOp> {
     // Construct loops to iterate over the dimensions shared by input stream and
     // output stream.
     for (auto [tripCount, step] :
-         llvm::zip(inputType.getIterTripCounts().take_front(beforeLoop),
-                   inputType.getIterSteps().take_front(beforeLoop))) {
+         llvm::zip(inputType.getIterTripCounts().take_front(loopIndex),
+                   inputType.getIterSteps().take_front(loopIndex))) {
       auto [lbCst, ubCst, stepCst] =
           getLoopBoundsAndStep(tripCount, step, loc, rewriter);
       auto loop = rewriter.create<scf::ForOp>(loc, lbCst, ubCst, stepCst);
@@ -293,10 +293,10 @@ struct LowerStreamBufferOp : public OpRewritePattern<StreamBufferOp> {
     // element to the buffer tensor.
     auto zeroCst = rewriter.create<arith::ConstantIndexOp>(loc, 0);
     auto [inputIvs, inputResult, inputIterArg] =
-        constructLoops(inputType.getIterTripCounts().drop_front(beforeLoop),
-                       inputType.getIterSteps().drop_front(beforeLoop), loc,
+        constructLoops(inputType.getIterTripCounts().drop_front(loopIndex),
+                       inputType.getIterSteps().drop_front(loopIndex), loc,
                        rewriter, init.getResult());
-    SmallVector<Value> bufferInputIvs(beforeLoop, zeroCst);
+    SmallVector<Value> bufferInputIvs(loopIndex, zeroCst);
     bufferInputIvs.append(inputIvs);
     readStreamAndInsertSlice(bufferInputIvs, streamBuffer.getInput(),
                              inputIterArg, packing, loc, rewriter,
@@ -306,9 +306,9 @@ struct LowerStreamBufferOp : public OpRewritePattern<StreamBufferOp> {
     // write to the output stream channel.
     rewriter.setInsertionPointAfterValue(inputResult);
     auto [outputIvs, outputResult, outputIterArg] = constructLoops(
-        outputType.getIterTripCounts().drop_front(beforeLoop),
-        outputType.getIterSteps().drop_front(beforeLoop), loc, rewriter);
-    SmallVector<Value> bufferOutputIvs(beforeLoop, zeroCst);
+        outputType.getIterTripCounts().drop_front(loopIndex),
+        outputType.getIterSteps().drop_front(loopIndex), loc, rewriter);
+    SmallVector<Value> bufferOutputIvs(loopIndex, zeroCst);
     bufferOutputIvs.append(outputIvs);
     extactSliceAndWriteStream(bufferOutputIvs, channel, inputResult, packing,
                               loc, rewriter);
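
Note on the folding behavior shared by the view-like ops above: foldStreamViewLikeInterface folds an identity view (input and output types equal) to its input, and folds a view of a view away when the producer's input type already matches the outer view's output type. The sketch below is illustrative only, written in plain C++ rather than against the MLIR API; the View struct and foldView helper are hypothetical stand-ins for StreamViewLikeInterface and the folder, with stream types modeled as strings.

// Illustrative sketch only (plain C++, not the MLIR API). "View" and
// "foldView" are hypothetical stand-ins for StreamViewLikeInterface and
// foldStreamViewLikeInterface; stream types are modeled as plain strings.
#include <optional>
#include <string>

struct View {
  std::string inputType;          // type of the view's input stream
  std::string outputType;         // type of the view's output stream
  const View *producer = nullptr; // view-like op defining the input, if any
};

// Fold an identity view to its input, and fold view(view(x)) to x when the
// inner view's input type already matches the outer view's output type.
std::optional<std::string> foldView(const View &op) {
  if (op.inputType == op.outputType)
    return op.inputType;
  if (op.producer && op.producer->inputType == op.outputType)
    return op.producer->inputType;
  return std::nullopt; // nothing to fold
}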