Skip to content

Commit

Permalink
Const creation simplification when generating conversion ops
Browse files Browse the repository at this point in the history
  • Loading branch information
wpmed92 committed Oct 12, 2024
1 parent a7f0832 commit e3d11d9
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 65 deletions.
5 changes: 5 additions & 0 deletions include/CodeGen/MLIRCodeGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <functional>
#include <unordered_map>
#include <filesystem>
#include <stdint.h>

using namespace mlir;

Expand Down Expand Up @@ -117,6 +118,10 @@ class MLIRCodeGen : public ASTVisitor {
mlir::Value popExpressionStack();
mlir::Value currentBasePointer;
mlir::Value convertOp(ConstructorExpression* constructorExp, mlir::Value val);
mlir::Value buildBoolConst(bool val);
mlir::Value buildIntConst(uint32_t val, bool isUnsigned);
mlir::Value buildFloatConst(double val, bool isDouble);
mlir::Value buildVecConst(mlir::Value constant, mlir::Type type);
};

}; // namespace codegen
Expand Down
111 changes: 46 additions & 65 deletions lib/CodeGen/MLIRCodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -710,30 +710,14 @@ mlir::Value MLIRCodeGen::convertOp(ConstructorExpression* constructorExp, mlir::
} else if ((isSIntLike(fromType) && isUIntLike(toType)) || (isUIntLike(fromType) && isSIntLike(toType))) {
expressionStack.push_back(builder.create<spirv::BitcastOp>(builder.getUnknownLoc(), toType, val));
} else if (isBoolLike(fromType) && isIntLike(toType)) {
mlir::Value one;
auto constOne = builder.create<spirv::ConstantOp>(
builder.getUnknownLoc(),
mlir::IntegerType::get(&context, 32, isUIntLike(toType) ? mlir::IntegerType::Unsigned : mlir::IntegerType::Signed),
isUIntLike(toType) ? builder.getUI32IntegerAttr(1) : builder.getSI32IntegerAttr(1)
);

mlir::Value zero;
auto constZero = builder.create<spirv::ConstantOp>(
builder.getUnknownLoc(),
mlir::IntegerType::get(&context, 32, isUIntLike(toType) ? mlir::IntegerType::Unsigned : mlir::IntegerType::Signed),
isUIntLike(toType) ? builder.getUI32IntegerAttr(0) : builder.getSI32IntegerAttr(0)
);
mlir::Value one {};
mlir::Value zero {};
auto constOne = buildIntConst(1, isSIntLike(toType));
auto constZero = buildIntConst(0, isSIntLike(toType));

if (fromType.isa<mlir::VectorType>()) {
std::vector<mlir::Value> operandsZero;
std::vector<mlir::Value> operandsOne;

for (int i = 0; i < fromType.dyn_cast<mlir::VectorType>().getShape()[0]; i++) {
operandsZero.push_back(constZero);
operandsOne.push_back(constOne);
}
zero = builder.create<spirv::CompositeConstructOp>(builder.getUnknownLoc(), toType, operandsZero);
one = builder.create<spirv::CompositeConstructOp>(builder.getUnknownLoc(), toType, operandsOne);
zero = buildVecConst(constZero, toType);
one = buildVecConst(constOne, toType);
} else {
one = constOne;
zero = constZero;
Expand All @@ -749,29 +733,13 @@ mlir::Value MLIRCodeGen::convertOp(ConstructorExpression* constructorExp, mlir::
expressionStack.push_back(res);
} else if (isBoolLike(fromType) && isFloatLike(toType)) {
mlir::Value one;
auto constOne = builder.create<spirv::ConstantOp>(
builder.getUnknownLoc(),
isF32Like(toType) ? mlir::FloatType::getF32(&context) : mlir::FloatType::getF64(&context),
isF32Like(toType) ? builder.getF32FloatAttr(1.0f) : builder.getF64FloatAttr(1.0)
);

mlir::Value zero;
auto constZero = builder.create<spirv::ConstantOp>(
builder.getUnknownLoc(),
isF32Like(toType) ? mlir::FloatType::getF32(&context) : mlir::FloatType::getF64(&context),
isF32Like(toType) ? builder.getF32FloatAttr(0.0f) : builder.getF64FloatAttr(0.0)
);
auto constOne = buildFloatConst(1.0, isF64Like(toType));
auto constZero = buildFloatConst(0.0, isF64Like(toType));

if (fromType.isa<mlir::VectorType>()) {
std::vector<mlir::Value> operandsZero;
std::vector<mlir::Value> operandsOne;

for (int i = 0; i < fromType.dyn_cast<mlir::VectorType>().getShape()[0]; i++) {
operandsZero.push_back(constZero);
operandsOne.push_back(constOne);
}
zero = builder.create<spirv::CompositeConstructOp>(builder.getUnknownLoc(), toType, operandsZero);
one = builder.create<spirv::CompositeConstructOp>(builder.getUnknownLoc(), toType, operandsOne);
zero = buildVecConst(constZero, toType);
one = buildVecConst(constOne, toType);
} else {
one = constOne;
zero = constZero;
Expand All @@ -792,36 +760,18 @@ mlir::Value MLIRCodeGen::convertOp(ConstructorExpression* constructorExp, mlir::
mlir::Value zero;

if (isIntLike(fromType)) {
zero = builder.create<spirv::ConstantOp>(
builder.getUnknownLoc(),
mlir::IntegerType::get(&context, 32, isUIntLike(fromType) ? mlir::IntegerType::Unsigned : mlir::IntegerType::Signed),
isUIntLike(fromType) ? builder.getUI32IntegerAttr(0) : builder.getSI32IntegerAttr(0)
);

if (fromType.isa<mlir::VectorType>()) {
std::vector<mlir::Value> operandsZero;
zero = buildIntConst(0, isSIntLike(fromType));

for (int i = 0; i < fromType.dyn_cast<mlir::VectorType>().getShape()[0]; i++) {
operandsZero.push_back(zero);
}
zero = builder.create<spirv::CompositeConstructOp>(builder.getUnknownLoc(), fromType, operandsZero);
if (fromType.isa<mlir::VectorType>()) {
zero = buildVecConst(zero, fromType);
}

expressionStack.push_back(builder.create<spirv::INotEqualOp>(builder.getUnknownLoc(), val, zero));
} else if (isFloatLike(fromType)) {
zero = builder.create<spirv::ConstantOp>(
builder.getUnknownLoc(),
isF32Like(fromType) ? mlir::FloatType::getF32(&context) : mlir::FloatType::getF64(&context),
isF32Like(fromType) ? builder.getF32FloatAttr(0.0f) : builder.getF64FloatAttr(0.0)
);
zero = buildFloatConst(0.0, isF64Like(fromType));

if (fromType.isa<mlir::VectorType>()) {
std::vector<mlir::Value> operandsZero;

for (int i = 0; i < fromType.dyn_cast<mlir::VectorType>().getShape()[0]; i++) {
operandsZero.push_back(zero);
}
zero = builder.create<spirv::CompositeConstructOp>(builder.getUnknownLoc(), fromType, operandsZero);
zero = buildVecConst(zero, fromType);
}

expressionStack.push_back(builder.create<spirv::FOrdNotEqualOp>(builder.getUnknownLoc(), val, zero));
Expand Down Expand Up @@ -1358,6 +1308,37 @@ mlir::Value MLIRCodeGen::load(mlir::Value val) {
return val;
}

mlir::Value MLIRCodeGen::buildBoolConst(bool val) {
auto type = builder.getIntegerType(1);
return builder.create<spirv::ConstantOp>(builder.getUnknownLoc(), type, IntegerAttr::get(type, APInt(1, val)));
}

mlir::Value MLIRCodeGen::buildIntConst(uint32_t val, bool isSigned) {
return builder.create<spirv::ConstantOp>(
builder.getUnknownLoc(),
mlir::IntegerType::get(&context, 32, isSigned ? mlir::IntegerType::Signed : mlir::IntegerType::Unsigned),
isSigned ? builder.getSI32IntegerAttr(static_cast<int32_t>(val)) : builder.getUI32IntegerAttr(val)
);
}

mlir::Value MLIRCodeGen::buildFloatConst(double val, bool isDouble) {
return builder.create<spirv::ConstantOp>(
builder.getUnknownLoc(),
isDouble ? mlir::FloatType::getF64(&context) : mlir::FloatType::getF32(&context),
isDouble ? builder.getF64FloatAttr(val) : builder.getF32FloatAttr(static_cast<float>(val))
);
}

mlir::Value MLIRCodeGen::buildVecConst(mlir::Value constant, mlir::Type type) {
std::vector<mlir::Value> operands;

for (int i = 0; i < type.dyn_cast<mlir::VectorType>().getShape()[0]; i++) {
operands.push_back(constant);
}

return builder.create<spirv::CompositeConstructOp>(builder.getUnknownLoc(), type, operands);
}

bool MLIRCodeGen::callBuiltIn(CallExpression* exp) {
auto builtinFuncIt = builtInFuncMap.find(exp->getFunctionName());

Expand Down

0 comments on commit e3d11d9

Please sign in to comment.