Skip to content

Commit

Permalink
[aievec] Port mlir-aie #1626, #1627, #1630 (#575)
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental authored Jul 19, 2024
1 parent d8d73c6 commit c1f4984
Show file tree
Hide file tree
Showing 6 changed files with 576 additions and 31 deletions.
185 changes: 185 additions & 0 deletions compiler/plugins/target/AMD-AIE/aievec/AIEVecOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,191 @@ LogicalResult ShuffleOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// MulElemOp and FMAElemOp
//===----------------------------------------------------------------------===//

// MulElemOp and FMAElemOp are structurally similar, except that FMAElemOp
// has a few extra fields (accumulator, bool flag to indicate if it is fmsub,
// etc.). We create some specializations to print those fields specifically for
// FMAElemOp and MulElemOp.

// Print the accumulator
template <typename T>
void printAccumulator(OpAsmPrinter &p, T op);
template <>
inline void printAccumulator(OpAsmPrinter &p, aievec::FMAElemOp op) {
p << ", " << op.getAcc();
}
template <>
inline void printAccumulator(OpAsmPrinter &p, aievec::MulElemOp op) {}

// Mark fmsub indicator as elided if the FMAElem op is not fmsub
template <typename T>
void elideFMSubAttr(T op, SmallVector<StringRef, 4> &elidedAttrs);
template <>
inline void elideFMSubAttr(aievec::FMAElemOp op,
SmallVector<StringRef, 4> &elidedAttrs) {
if (!op.getFmsub()) elidedAttrs.push_back(op.getSubAttrName());
}

template <>
inline void elideFMSubAttr(aievec::MulElemOp op,
SmallVector<StringRef, 4> &elidedAttrs) {}

// Print out MulElem and FMAElem op.
//
// Emitted form:
//   <lhs>, <rhs>[, <acc>] [attr-dict] : <lhs type>, <rhs type>, <result type>
// The accumulator and the fmsub elision are delegated to the per-op
// specializations of printAccumulator / elideFMSubAttr.
template <typename T>
static void printMulFMAElemOp(OpAsmPrinter &p, T op) {
  // Print the left and right operands.
  p << " " << op.getLhs();
  p << ", " << op.getRhs();
  // For fma op, print the accumulator.
  printAccumulator(p, op);

  // Print the attributes, leaving out whatever the op-specific helper elides.
  // The helper is called once: calling it repeatedly only appended the same
  // attribute name to the elided list again.
  SmallVector<StringRef, 4> elidedAttrs;
  elideFMSubAttr(op, elidedAttrs);
  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);

  // And now print the types.
  p << " : " << op.getLhs().getType() << ", " << op.getRhs().getType();
  p << ", " << op.getResult().getType();
}

// Custom assembly printer; forwards to the printer shared with FMAElemOp.
void MulElemOp::print(OpAsmPrinter &p) {
  // The template argument is deduced from *this (aievec::MulElemOp).
  printMulFMAElemOp(p, *this);
}

// Custom assembly printer; forwards to the printer shared with MulElemOp.
void aievec::FMAElemOp::print(OpAsmPrinter &p) {
  // The template argument is deduced from *this (aievec::FMAElemOp).
  printMulFMAElemOp(p, *this);
}

// Verify MulElem and FMAElem op.
//
// Checks, in order:
//   * lhs, rhs and result are all vector types;
//   * lhs and rhs have the same number of lanes;
//   * lhs and rhs have the same element type;
//   * an integer result requires integer operands that are strictly narrower
//     than the result element type;
//   * a floating point result requires floating point operands.
// Returns success() when every check passes, otherwise the emitted error.
template <typename T>
LogicalResult verifyMulFMAElemOp(T op) {
  // Verify the types
  auto lhsType = llvm::dyn_cast<VectorType>(op.getLhs().getType());
  auto rhsType = llvm::dyn_cast<VectorType>(op.getRhs().getType());

  if (!lhsType || !rhsType) return op.emitError("requires vector type");

  auto resultType = llvm::dyn_cast<VectorType>(op.getResult().getType());

  if (!resultType) return op.emitError("requires vector type");

  // Underlying scalar types of all the vectors
  Type ltype = lhsType.getElementType();
  Type rtype = rhsType.getElementType();
  Type atype = resultType.getElementType();

  // Checks on the number of lanes
  unsigned rhsLanes = getVectorLaneSize(rhsType);
  unsigned lhsLanes = getVectorLaneSize(lhsType);

  // lane size must match
  if (lhsLanes != rhsLanes) {
    return op.emitError(
        "The number of lanes in lhs operand "
        "must be the same as rhs operand");
  }

  // lhs and rhs vector's element type must match
  if (ltype != rtype)
    return op.emitError(
        "The element type of lhs and rhs "
        "operand vectors must match");

  // The integer datatype of accumulator must always be greater width
  if (isa<IntegerType>(atype)) {
    if (!isa<IntegerType>(ltype))
      return op.emitError("Integer result must have integer operands");

    // Bit widths are queried only here, where both element types are known to
    // be integers: getIntOrFloatBitWidth() must not be called on element
    // types that are neither integer nor float (e.g. index). rtype equals
    // ltype at this point, so one comparison covers both operands.
    unsigned operandWidth = ltype.getIntOrFloatBitWidth();
    unsigned accWidth = atype.getIntOrFloatBitWidth();
    if (operandWidth >= accWidth)
      return op.emitError(
          "the element type of accumulator must have "
          "wider width than that of the operand vectors");
  } else if (isa<FloatType>(atype)) {
    if (!isa<FloatType>(ltype))
      return op.emitError(
          "Floating point result must have "
          "floating point operands");
  }

  return success();
}

// Op verifier; forwards to the verification shared with FMAElemOp.
LogicalResult aievec::MulElemOp::verify() {
  // The template argument is deduced from *this (aievec::MulElemOp).
  return verifyMulFMAElemOp(*this);
}

// Op verifier; forwards to the verification shared with MulElemOp.
LogicalResult aievec::FMAElemOp::verify() {
  // The template argument is deduced from *this (aievec::FMAElemOp).
  return verifyMulFMAElemOp(*this);
}

// Parse MulElem and FMAElem op.
//
// Accepted form:
//   <lhs>, <rhs>[, <acc>] [attr-dict] : <type>, <type>, <type>
// The accumulator operand is parsed only when `isFMAElemOp` is true. The three
// types are, in order, the lhs type, the rhs type and a third type that is
// used for both the accumulator operand (when present) and the op result.
ParseResult parseMulFMAElemOp(OpAsmParser &parser, OperationState &result,
                              bool isFMAElemOp = true) {
  OpAsmParser::UnresolvedOperand lhsOperand, rhsOperand, accOperand;

  // Operands: lhs and rhs always, acc only for the FMA form.
  if (parser.parseOperand(lhsOperand)) return failure();
  if (parser.parseComma()) return failure();
  if (parser.parseOperand(rhsOperand)) return failure();
  if (isFMAElemOp &&
      (parser.parseComma() || parser.parseOperand(accOperand)))
    return failure();

  // Optional attribute dictionary followed by the colon-prefixed type list.
  // The location just before the type list is remembered for diagnostics.
  llvm::SMLoc typeListLoc;
  SmallVector<Type, 3> parsedTypes;
  if (parser.parseOptionalAttrDict(result.attributes)) return failure();
  if (parser.getCurrentLocation(&typeListLoc)) return failure();
  if (parser.parseColonTypeList(parsedTypes)) return failure();

  // Exactly three types are expected.
  if (parsedTypes.size() != 3)
    return parser.emitError(typeListLoc, "requires three types");

  // Each of them must be a vector type; they are checked in order.
  VectorType vectorTypes[3];
  for (unsigned i = 0; i < 3; ++i) {
    vectorTypes[i] = llvm::dyn_cast<VectorType>(parsedTypes[i]);
    if (!vectorTypes[i])
      return parser.emitError(typeListLoc, "requires vector type");
  }
  VectorType lhsVecType = vectorTypes[0];
  VectorType rhsVecType = vectorTypes[1];
  VectorType accVecType = vectorTypes[2];

  // Resolve operands in declaration order: lhs, rhs, then acc for FMA.
  if (parser.resolveOperand(lhsOperand, lhsVecType, result.operands))
    return failure();
  if (parser.resolveOperand(rhsOperand, rhsVecType, result.operands))
    return failure();
  if (isFMAElemOp &&
      parser.resolveOperand(accOperand, accVecType, result.operands))
    return failure();

  // The third type is also the result type.
  return parser.addTypeToList(accVecType, result.types);
}

// Custom assembly parser; uses the shared parser without an accumulator.
ParseResult MulElemOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseMulFMAElemOp(parser, result, /*isFMAElemOp=*/false);
}

// Custom assembly parser; uses the shared parser with an accumulator operand.
ParseResult FMAElemOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseMulFMAElemOp(parser, result, /*isFMAElemOp=*/true);
}

#define GET_ATTRDEF_CLASSES
#include "aievec/AIEVecAttributes.cpp.inc"

Expand Down
56 changes: 56 additions & 0 deletions compiler/plugins/target/AMD-AIE/aievec/AIEVecOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -274,4 +274,60 @@ def AIEVec_ShuffleOp : AIEVec_Op<"shuffle",
let hasVerifier = 1;
}

// Element-wise multiply of two vectors. Operand and result element types and
// lengths are restricted by the constraints listed below.
def AIEVec_MulElemOp:
  AIEVec_Op<"mul_elem", [
    Pure,
    SameTypeOperands,
    SameOperandsShape,
    SameOperandsAndResultShape,
    isOperandResultTypePairValidForAIE2MulElem<"lhs", "rhs", "result">
  ]>,
  Arguments<(ins
    VectorOfLengthAndType<[16, 32], [I8, I16, I32, BF16, F32]>:$lhs,
    VectorOfLengthAndType<[16, 32], [I8, I16, I32, BF16, F32]>:$rhs)>,
  Results<(outs
    VectorOfLengthAndType<[16, 32], [I32, I64, F32]>:$result)> {
  let summary = "AIE2 vector element-wise multiply";
  let description = [{
    AMD-specific multiply operation that multiplies two 1-D vectors in the same channel.
    The vector sizes are at least 512 bits.
    `$result = $lhs * $rhs`.
    Currently, the following are the supported type combinations:
    lhs | rhs | Accumulator
    :------------------:|:------------------:|:-----------------:
    `vector<32xi8>` | `vector<32xi8>` | `vector<32xi32>`
    `vector<32xi16>` | `vector<32xi16>` | `vector<32xi32>`
    `vector<16xi32>` | `vector<16xi32>` | `vector<16xi64>`
    `vector<16xbf16>` | `vector<16xbf16>` | `vector<16xf32>`
    `vector<16xf32>` | `vector<16xf32>` | `vector<16xf32>`
  }];
}

// Element-wise fused multiply-add of two vectors into an accumulator. The
// `fmsub` attribute defaults to false; the custom builder below uses the
// accumulator's type as the result type.
def AIEVec_FMAElemOp :
  AIEVec_Op<"mac_elem", [
    Pure
  ]>,
  Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc,
             DefaultValuedAttr<BoolAttr, "false">:$fmsub)>,
  Results<(outs AnyVector:$result)> {
  let summary = "AIE2 element-wise vector fused multiply-add";
  let description = [{
    AMD-specific multiply-add operation. It multiplies two 1-D vectors in the same channel,
    and adds the result to an accumulator.
    `$result = $lhs * $rhs + $acc`.
    Note: the same operator can be used as fmsub operator by setting the
    'fmsub' bool to true.
  }];
  let builders = [
    OpBuilder<(ins "mlir::Value":$lhs, "mlir::Value":$rhs, "mlir::Value":$acc,
            "bool":$fmsub),
    [{build($_builder, $_state, acc.getType(), lhs, rhs, acc,
            fmsub);}]>
  ];
  let extraClassDeclaration = [{
    // Get the attribute names
    llvm::StringRef getSubAttrName() { return "fmsub"; }
  }];
}

#endif // AIEVEC_OPS
Loading

0 comments on commit c1f4984

Please sign in to comment.