Skip to content

Commit

Permalink
Update op tablegen definition structure
Browse files Browse the repository at this point in the history
  • Loading branch information
hanchenye committed Feb 20, 2024
1 parent efd7f0e commit 8783e27
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 90 deletions.
12 changes: 10 additions & 2 deletions include/scalehls/Dialect/HLS/IR/HLSInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,21 @@ def StreamViewLikeInterface : OpInterface<"StreamViewLikeInterface"> {
string cppNamespace = "mlir::scalehls::hls";

let methods = [
InterfaceMethod<"Return the source of the stream view",
InterfaceMethod<"Return the input stream",
"TypedValue<StreamType>", "getInput", (ins), [{
return $_op.getInput();
}]>,
InterfaceMethod<"Return the result of the stream view",
InterfaceMethod<"Return the output stream",
"TypedValue<StreamType>", "getOutput", (ins), [{
return $_op.getOutput();
}]>,
InterfaceMethod<"Return the input stream type",
"StreamType", "getInputType", (ins), [{
return $_op.getInput().getType();
}]>,
InterfaceMethod<"Return the output stream type",
"StreamType", "getOutputType", (ins), [{
return $_op.getOutput().getType();
}]>
];
}
Expand Down
93 changes: 50 additions & 43 deletions include/scalehls/Dialect/HLS/IR/HLSOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def YieldOp : HLSOp<"yield", [NoMemoryEffect, ReturnLike, Terminator,
// Stream Operations
//===----------------------------------------------------------------------===//

def TensorInitOp : HLSOp<"tensor_init", [NoMemoryEffect]> {
def TensorInitOp : HLSOp<"tensor_init", [Pure]> {
let summary = "Initiate a tensor with an optional initialization value";

let arguments = (ins Optional<AnyType>:$init_value);
Expand All @@ -134,7 +134,7 @@ def TensorInitOp : HLSOp<"tensor_init", [NoMemoryEffect]> {
];
}

def TensorToStreamOp : HLSOp<"tensor_to_stream",[NoMemoryEffect]> {
def TensorToStreamOp : HLSOp<"tensor_to_stream",[Pure]> {
let summary = "Convert a tensor to a stream channel";

let arguments = (ins AnyRankedTensor:$tensor);
Expand All @@ -147,7 +147,7 @@ def TensorToStreamOp : HLSOp<"tensor_to_stream",[NoMemoryEffect]> {
let hasFolder = 1;
}

def StreamToTensorOp : HLSOp<"stream_to_tensor", [NoMemoryEffect]> {
def StreamToTensorOp : HLSOp<"stream_to_tensor", [Pure]> {
let summary = "Convert a stream channel to a tensor";

let arguments = (ins AnyStream:$stream);
Expand Down Expand Up @@ -196,26 +196,19 @@ def StreamWriteOp : HLSOp<"stream_write", [

let arguments = (ins AnyStream:$channel, AnyType:$value);
let assemblyFormat = [{
$channel `,` $value attr-dict `:` functional-type($value, $channel)
$value `to` $channel attr-dict `:` functional-type($value, $channel)
}];

let hasVerifier = 1;
}

def StreamExpandShapeOp : HLSOp<"stream_expand_shape", [NoMemoryEffect,
DeclareOpInterfaceMethods<StreamViewLikeInterface>]> {
let summary = "Expand the shape of the iteration space";
class StreamReassociativeOp<string mnemonic, list<Trait> traits = []> :
HLSOp<mnemonic, !listconcat(traits, [Pure,
DeclareOpInterfaceMethods<StreamViewLikeInterface>])>,
Arguments<(ins AnyStream:$input, IndexListArrayAttr:$reassociation)>,
Results<(outs AnyStream:$output)> {

let arguments = (ins AnyStream:$input, IndexListArrayAttr:$reassociation);
let results = (outs AnyStream:$output);
let assemblyFormat = [{
$input `,` $reassociation attr-dict `:` functional-type($input, $output)
}];

let hasVerifier = 1;
let hasFolder = 1;

let extraClassDeclaration = [{
code commonExtraClassDeclaration = [{
SmallVector<ReassociationIndices, 4> getReassociationIndices() {
SmallVector<ReassociationIndices, 4> reassociationIndices;
for (auto attr : getReassociation())
Expand All @@ -225,53 +218,62 @@ def StreamExpandShapeOp : HLSOp<"stream_expand_shape", [NoMemoryEffect,
})));
return reassociationIndices;
}
}];
}

def StreamCollapseShapeOp : HLSOp<"stream_collapse_shape", [NoMemoryEffect,
DeclareOpInterfaceMethods<StreamViewLikeInterface>]> {
let summary = "Collapse the shape of the iteration space";
StreamType getInputType() { return getInput().getType(); }
StreamType getOutputType() { return getOutput().getType(); }
}];

let arguments = (ins AnyStream:$input, IndexListArrayAttr:$reassociation);
let results = (outs AnyStream:$output);
let assemblyFormat = [{
$input `,` $reassociation attr-dict `:` functional-type($input, $output)
$input $reassociation attr-dict `:` functional-type($input, $output)
}];

let hasVerifier = 1;
let hasFolder = 1;
}

let extraClassDeclaration = [{
SmallVector<ReassociationIndices, 4> getReassociationIndices() {
SmallVector<ReassociationIndices, 4> reassociationIndices;
for (auto attr : getReassociation())
reassociationIndices.push_back(llvm::to_vector<2>(
llvm::map_range(::llvm::cast<ArrayAttr>(attr), [&](Attribute indexAttr) {
return ::llvm::cast<IntegerAttr>(indexAttr).getInt();
})));
return reassociationIndices;
}
}];
// Splits dimensions of a stream channel's iteration space. Inherits the
// `$input`/`$output` streams and the `$reassociation` attribute from
// StreamReassociativeOp; the reassociation-index helpers come from the
// shared `commonExtraClassDeclaration` code block.
def StreamSplitIterationOp : StreamReassociativeOp<"stream_split_iteration"> {
let summary = "Split the iteration space of a stream channel";
let extraClassDeclaration = commonExtraClassDeclaration;
}

// Merges dimensions of a stream channel's iteration space — the inverse
// view of stream_split_iteration. Streams and the `$reassociation`
// attribute come from the StreamReassociativeOp base class.
def StreamMergeIterationOp : StreamReassociativeOp<"stream_merge_iteration"> {
let summary = "Merge the iteration space of a stream channel";
let extraClassDeclaration = commonExtraClassDeclaration;
}

// Expands the shape of the stream's element according to the inherited
// `$reassociation` attribute (see StreamReassociativeOp base class).
def StreamExpandShapeOp : StreamReassociativeOp<"stream_expand_shape"> {
let summary = "Expand the shape of the stream element";
let extraClassDeclaration = commonExtraClassDeclaration;
}

// Collapses the shape of the stream's element — the inverse view of
// stream_expand_shape. Streams and the `$reassociation` attribute come
// from the StreamReassociativeOp base class.
def StreamCollapseShapeOp : StreamReassociativeOp<"stream_collapse_shape"> {
let summary = "Collapse the shape of the stream element";
let extraClassDeclaration = commonExtraClassDeclaration;
}

// Buffers a stream channel at a given loop/dimension position. The scraped
// diff interleaved the pre-change definition (NoMemoryEffect, $beforeLoop/
// $beforeDim, old assembly format) with the post-change one; this keeps the
// single coherent post-change definition.
def StreamBufferOp : HLSOp<"stream_buffer", [Pure,
    DeclareOpInterfaceMethods<StreamViewLikeInterface>]> {
  let summary = "Buffer a stream channel at a specific position";

  // `$loopIndex`/`$dimIndex` name the loop level and element dimension
  // before which the buffer of `$bufferShape` x `$bufferElementType` sits.
  let arguments = (ins AnyStream:$input, TypeAttr:$bufferElementType,
    DenseI64ArrayAttr:$bufferShape, I64Attr:$loopIndex, I64Attr:$dimIndex);
  let results = (outs AnyStream:$output);
  let assemblyFormat = [{
    $input `,` $bufferElementType $bufferShape `before` `loop` $loopIndex `dim`
    $dimIndex attr-dict `:` functional-type($input, $output)
  }];

  let hasVerifier = 1;
  let hasFolder = 1;

  // Convenience accessors used by StreamViewLikeInterface clients.
  let extraClassDeclaration = [{
    StreamType getInputType() { return getInput().getType(); }
    StreamType getOutputType() { return getOutput().getType(); }
  }];
}

def StreamCastOp : HLSOp<"stream_cast", [NoMemoryEffect,
DeclareOpInterfaceMethods<StreamViewLikeInterface>]> {
def StreamCastOp : HLSOp<"stream_cast", [Pure,
DeclareOpInterfaceMethods<StreamViewLikeInterface>]> {
let summary = "Cast a stream channel to a different type";

let arguments = (ins AnyStream:$input);
Expand All @@ -282,6 +284,11 @@ def StreamCastOp : HLSOp<"stream_cast", [NoMemoryEffect,

let hasVerifier = 1;
let hasFolder = 1;

let extraClassDeclaration = [{
StreamType getInputType() { return getInput().getType(); }
StreamType getOutputType() { return getOutput().getType(); }
}];
}

//===----------------------------------------------------------------------===//
Expand Down
111 changes: 75 additions & 36 deletions lib/Dialect/HLS/IR/HLSOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,12 +316,13 @@ void StreamWriteOp::getEffects(
}

//===----------------------------------------------------------------------===//
// StreamExpandShapeOp
// StreamSplitIterationOp
//===----------------------------------------------------------------------===//

static LogicalResult
verifyReassociation(SmallVectorImpl<ReassociationIndices> &reassociation,
StreamType lowType, StreamType highType, Operation *op) {
verifyIterationReassociation(ArrayRef<ReassociationIndices> reassociation,
StreamType lowType, StreamType highType,
Operation *op) {
if (reassociation.size() != lowType.getIterTripCounts().size())
return op->emitOpError("reassociation size doesn't align with input type");

Expand All @@ -338,43 +339,83 @@ verifyReassociation(SmallVectorImpl<ReassociationIndices> &reassociation,
return op->emitOpError("reassociation doesn't align with input/output "
"iteration trip counts or steps");
}

unsigned highIndex = 0;
for (auto lowExpr : lowType.getIterMap().getResults()) {
if (auto lowDimExpr = dyn_cast<AffineDimExpr>(lowExpr)) {
auto indices = reassociation[lowDimExpr.getPosition()];
for (auto index : indices) {
auto highDimExpr =
dyn_cast<AffineDimExpr>(highType.getIterMap().getResult(highIndex));
if (!highDimExpr || highDimExpr.getPosition() != index)
return op->emitOpError(
"reassociation doesn't align with input/output iteration maps");
highIndex++;
}
} else
highIndex++;
}
return success();
}

LogicalResult StreamExpandShapeOp::verify() {
auto inputType = getInput().getType();
auto outputType = getOutput().getType();
if (inputType.getDataType() != outputType.getDataType())
return emitOpError("input and output data type doesn't match");
auto reassociation = getReassociationIndices();
return verifyReassociation(reassociation, inputType, outputType, *this);
/// Verifies a stream_split_iteration op: the input and output stream types
/// must be castable with each other, and the reassociation attribute must
/// align with their iteration spaces.
LogicalResult StreamSplitIterationOp::verify() {
  // Bug fix: the condition was inverted — the error was emitted when the
  // types WERE castable. Matches the `!isCastableWith` pattern used by
  // StreamBufferOp::verify.
  if (!getInputType().isCastableWith(getOutputType()))
    return emitOpError("input and output are not castable");
  return verifyIterationReassociation(getReassociationIndices(), getInputType(),
                                      getOutputType(), *this);
}

/// Shared folder for stream view-like ops: an identity view (input and
/// output types equal) folds to its input, and a view of a view whose
/// original stream already carries the result type folds to that original
/// stream. The scraped diff duplicated the old/new signature and condition
/// lines; this keeps the single post-change (static, getInputType) version.
static OpFoldResult foldStreamViewLikeInterface(StreamViewLikeInterface op) {
  if (op.getInput().getType() == op.getOutput().getType())
    return op.getInput();
  if (auto prevView = op.getInput().getDefiningOp<StreamViewLikeInterface>())
    if (prevView.getInputType() == op.getOutput().getType())
      return prevView.getInput();
  return {};
}

// Folds an identity split (input type == output type) to the input stream,
// or a split-of-view chain back to the original stream; the shared logic
// lives in foldStreamViewLikeInterface above.
OpFoldResult StreamSplitIterationOp::fold(FoldAdaptor adaptor) {
return foldStreamViewLikeInterface(*this);
}

//===----------------------------------------------------------------------===//
// StreamMergeIterationOp
//===----------------------------------------------------------------------===//

/// Verifies a stream_merge_iteration op: the input and output stream types
/// must be castable with each other, and the reassociation attribute must
/// align with their iteration spaces.
LogicalResult StreamMergeIterationOp::verify() {
  // Bug fix: the condition was inverted — the error was emitted when the
  // types WERE castable. Matches the `!isCastableWith` pattern used by
  // StreamBufferOp::verify.
  if (!getInputType().isCastableWith(getOutputType()))
    return emitOpError("input and output are not castable");
  return verifyIterationReassociation(getReassociationIndices(), getInputType(),
                                      getOutputType(), *this);
}

// Folds an identity merge (input type == output type) to the input stream,
// or a merge-of-view chain back to the original stream; the shared logic
// lives in foldStreamViewLikeInterface above.
OpFoldResult StreamMergeIterationOp::fold(FoldAdaptor adaptor) {
return foldStreamViewLikeInterface(*this);
}

//===----------------------------------------------------------------------===//
// StreamExpandShapeOp
//===----------------------------------------------------------------------===//

// Intended to verify that `reassociation` correctly relates the element
// shapes of the low-rank and high-rank stream types for the expand/collapse
// shape ops.
//
// NOTE(review): every structural check below is commented out, so this
// currently succeeds unconditionally and all three parameters are unused —
// TODO confirm whether the disabled checks should be restored or deleted.
// NOTE(review): the first disabled condition compares lowType against
// itself (`lowType.getIterTripCounts() != lowType.getIterTripCounts()`);
// presumably one side should be highType — verify before re-enabling.
static LogicalResult
verifyShapeReassociation(ArrayRef<ReassociationIndices> reassociation,
StreamType lowType, StreamType highType,
Operation *op) {
// if (lowType.getIterTripCounts() != lowType.getIterTripCounts() ||
// lowType.getIterSteps() != lowType.getIterSteps())
// return op->emitOpError("input and output iteration trip counts or steps "
// "doesn't match");

// unsigned highIndex = 0;
// for (auto lowExpr : lowType.getIterMap().getResults()) {
// if (auto lowDimExpr = dyn_cast<AffineDimExpr>(lowExpr)) {
// auto indices = reassociation[lowDimExpr.getPosition()];
// for (auto index : indices) {
// auto highDimExpr =
// dyn_cast<AffineDimExpr>(highType.getIterMap().getResult(highIndex));
// if (!highDimExpr || highDimExpr.getPosition() != index)
// return op->emitOpError(
// "reassociation doesn't align with input/output iteration
// maps");
// highIndex++;
// }
// } else
// highIndex++;
// }
return success();
}

/// Verifies a stream_expand_shape op: expanding the element shape must
/// preserve the stored data type, after which the reassociation attribute
/// is checked against both stream types.
LogicalResult StreamExpandShapeOp::verify() {
  auto srcType = getInputType();
  auto dstType = getOutputType();
  if (srcType.getDataType() != dstType.getDataType())
    return emitOpError("input and output data type doesn't match");
  return verifyShapeReassociation(getReassociationIndices(), srcType, dstType,
                                  *this);
}

// Folds an identity expand (input type == output type) to the input stream,
// or an expand-of-view chain back to the original stream; the shared logic
// lives in foldStreamViewLikeInterface above.
OpFoldResult StreamExpandShapeOp::fold(FoldAdaptor adaptor) {
return foldStreamViewLikeInterface(*this);
}
Expand All @@ -384,12 +425,10 @@ OpFoldResult StreamExpandShapeOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//

/// Verifies a stream_collapse_shape op: collapsing the element shape must
/// preserve the stored data type, after which the reassociation attribute
/// is checked against both stream types. The scraped diff interleaved the
/// pre-change body (locals + old verifyReassociation call) with the
/// post-change one; this keeps the single coherent post-change version.
LogicalResult StreamCollapseShapeOp::verify() {
  if (getInputType().getDataType() != getOutputType().getDataType())
    return emitOpError("input and output data type doesn't match");
  return verifyShapeReassociation(getReassociationIndices(), getInputType(),
                                  getOutputType(), *this);
}

OpFoldResult StreamCollapseShapeOp::fold(FoldAdaptor adaptor) {
Expand All @@ -406,14 +445,14 @@ LogicalResult StreamBufferOp::verify() {
if (!inputType.isCastableWith(outputType))
return emitOpError("input and output are not castable");

if (getBeforeLoop() > inputType.getIterTripCounts().size())
return emitOpError("buffer position is out of loop range");
if (getLoopIndex() > inputType.getIterTripCounts().size())
return emitOpError("buffer loop index is out of loop range");

auto inputShape = inputType.getShape();
for (auto [dim, bufferSize, dimSize, inputTileSize, outputTileSize] :
llvm::zip(llvm::seq(inputShape.size()), getBufferShape(), inputShape,
inputType.getElementShape(), outputType.getElementShape())) {
if (dim < getBeforeDim()) {
if (dim < getDimIndex()) {
if (inputTileSize != outputTileSize || bufferSize < inputTileSize)
return emitOpError(
"buffer size is smaller than input/output tile size");
Expand Down
18 changes: 9 additions & 9 deletions lib/Dialect/HLS/Transforms/MaterializeStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ struct LowerStreamBufferOp : public OpRewritePattern<hls::StreamBufferOp> {
PatternRewriter &rewriter) const override {
auto inputType = streamBuffer.getInput().getType();
auto outputType = streamBuffer.getOutput().getType();
auto beforeLoop = streamBuffer.getBeforeLoop();
auto loopIndex = streamBuffer.getLoopIndex();
auto loc = streamBuffer.getLoc();

// Construct the output stream channel.
Expand All @@ -263,8 +263,8 @@ struct LowerStreamBufferOp : public OpRewritePattern<hls::StreamBufferOp> {
// Construct loops to iterate over the dimensions shared by input stream and
// output stream.
for (auto [tripCount, step] :
llvm::zip(inputType.getIterTripCounts().take_front(beforeLoop),
inputType.getIterSteps().take_front(beforeLoop))) {
llvm::zip(inputType.getIterTripCounts().take_front(loopIndex),
inputType.getIterSteps().take_front(loopIndex))) {
auto [lbCst, ubCst, stepCst] =
getLoopBoundsAndStep(tripCount, step, loc, rewriter);
auto loop = rewriter.create<scf::ForOp>(loc, lbCst, ubCst, stepCst);
Expand Down Expand Up @@ -293,10 +293,10 @@ struct LowerStreamBufferOp : public OpRewritePattern<hls::StreamBufferOp> {
// element to the buffer tensor.
auto zeroCst = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto [inputIvs, inputResult, inputIterArg] =
constructLoops(inputType.getIterTripCounts().drop_front(beforeLoop),
inputType.getIterSteps().drop_front(beforeLoop), loc,
constructLoops(inputType.getIterTripCounts().drop_front(loopIndex),
inputType.getIterSteps().drop_front(loopIndex), loc,
rewriter, init.getResult());
SmallVector<Value> bufferInputIvs(beforeLoop, zeroCst);
SmallVector<Value> bufferInputIvs(loopIndex, zeroCst);
bufferInputIvs.append(inputIvs);
readStreamAndInsertSlice(bufferInputIvs, streamBuffer.getInput(),
inputIterArg, packing, loc, rewriter,
Expand All @@ -306,9 +306,9 @@ struct LowerStreamBufferOp : public OpRewritePattern<hls::StreamBufferOp> {
// write to the output stream channel.
rewriter.setInsertionPointAfterValue(inputResult);
auto [outputIvs, outputResult, outputIterArg] = constructLoops(
outputType.getIterTripCounts().drop_front(beforeLoop),
outputType.getIterSteps().drop_front(beforeLoop), loc, rewriter);
SmallVector<Value> bufferOutputIvs(beforeLoop, zeroCst);
outputType.getIterTripCounts().drop_front(loopIndex),
outputType.getIterSteps().drop_front(loopIndex), loc, rewriter);
SmallVector<Value> bufferOutputIvs(loopIndex, zeroCst);
bufferOutputIvs.append(outputIvs);
extactSliceAndWriteStream(bufferOutputIvs, channel, inputResult, packing,
loc, rewriter);
Expand Down

0 comments on commit 8783e27

Please sign in to comment.