diff --git a/lib/CodeGen/MLIRCodeGen.cpp b/lib/CodeGen/MLIRCodeGen.cpp index 5476170..fda91f4 100644 --- a/lib/CodeGen/MLIRCodeGen.cpp +++ b/lib/CodeGen/MLIRCodeGen.cpp @@ -462,11 +462,7 @@ void MLIRCodeGen::visit(UnaryExpression *unExp) { case UnaryOperator::Inc: case UnaryOperator::Dec: { if (isIntLike(rhsType)) { - auto one = builder.create( - builder.getUnknownLoc(), - mlir::IntegerType::get(&context, 32, isUIntLike(rhsType) ? mlir::IntegerType::Unsigned : mlir::IntegerType::Signed), - isUIntLike(rhsType) ? builder.getUI32IntegerAttr(1) :builder.getSI32IntegerAttr(1) - ); + auto one = buildIntConst(1, isSIntLike(rhsType)); if (op == UnaryOperator::Inc) { result = builder.create(loc, rhs, one); @@ -474,7 +470,7 @@ void MLIRCodeGen::visit(UnaryExpression *unExp) { result = builder.create(loc, rhs, one); } } else { - auto one = builder.create(builder.getUnknownLoc(), mlir::FloatType::getF32(&context), builder.getF32FloatAttr(1.0f)); + auto one = buildFloatConst(1.0, false); if (op == UnaryOperator::Inc) { result = builder.create(loc, rhs, one); @@ -1033,48 +1029,28 @@ void MLIRCodeGen::visit(VariableExpression *varExp) { } void MLIRCodeGen::visit(IntegerConstantExpression *intConstExp) { - auto type = builder.getIntegerType(32, true); - mlir::Value val = builder.create( - builder.getUnknownLoc(), type, - IntegerAttr::get(type, APInt(32, intConstExp->getVal(), true))); - - expressionStack.push_back(val); + auto constVal = buildIntConst(intConstExp->getVal(), true); + expressionStack.push_back(constVal); } void MLIRCodeGen::visit(UnsignedIntegerConstantExpression *uintConstExp) { - auto type = builder.getIntegerType(32, false); - mlir::Value val = builder.create( - builder.getUnknownLoc(), type, - IntegerAttr::get(type, APInt(32, uintConstExp->getVal(), false))); - - expressionStack.push_back(val); + auto constVal = buildIntConst(uintConstExp->getVal(), false); + expressionStack.push_back(constVal); } void MLIRCodeGen::visit(FloatConstantExpression *floatConstExp) { - auto type = builder.getF32Type(); - mlir::Value val = builder.create( - builder.getUnknownLoc(), type, - FloatAttr::get(type, APFloat(floatConstExp->getVal()))); - - expressionStack.push_back(val); + auto constVal = buildFloatConst(floatConstExp->getVal(), false); + expressionStack.push_back(constVal); } void MLIRCodeGen::visit(DoubleConstantExpression *doubleConstExp) { - auto type = builder.getF64Type(); - mlir::Value val = builder.create( - builder.getUnknownLoc(), type, - FloatAttr::get(type, APFloat(doubleConstExp->getVal()))); - - expressionStack.push_back(val); + auto constVal = buildFloatConst(doubleConstExp->getVal(), true); + expressionStack.push_back(constVal); } void MLIRCodeGen::visit(BoolConstantExpression *boolConstExp) { - auto type = builder.getIntegerType(1); - mlir::Value val = builder.create( - builder.getUnknownLoc(), type, - IntegerAttr::get(type, APInt(1, boolConstExp->getVal()))); - - expressionStack.push_back(val); + auto constVal = buildBoolConst(boolConstExp->getVal()); + expressionStack.push_back(constVal); } void MLIRCodeGen::visit(ReturnStatement *returnStmt) { @@ -1100,13 +1076,12 @@ void MLIRCodeGen::visit(ContinueStatement *continueStmt) { } void MLIRCodeGen::setBoolVar(mlir::spirv::VariableOp var, bool val) { - auto type = builder.getIntegerType(1); - mlir::Value constTrue = builder.create(builder.getUnknownLoc(), type, IntegerAttr::get(type, APInt(1, val))); - builder.create(builder.getUnknownLoc(), var, constTrue); + auto constBool = buildBoolConst(val); + builder.create(builder.getUnknownLoc(), var, constBool); } void MLIRCodeGen::visit(DiscardStatement *discardStmt) { - + // TODO: implement me } void MLIRCodeGen::visit(FunctionDeclaration *funcDecl) {