diff --git a/include/CodeGen/MLIRCodeGen.h b/include/CodeGen/MLIRCodeGen.h index b6c6e50..397bd8d 100644 --- a/include/CodeGen/MLIRCodeGen.h +++ b/include/CodeGen/MLIRCodeGen.h @@ -14,6 +14,7 @@ #include "llvm/ADT/ScopedHashTable.h" #include #include +#include using namespace mlir; @@ -35,7 +36,7 @@ class MLIRCodeGen : public ASTVisitor { public: MLIRCodeGen(); void initModuleOp(); - void dump(); + void print(); bool verify(); void visit(TranslationUnit *) override; void visit(BinaryExpression *) override; @@ -80,7 +81,7 @@ class MLIRCodeGen : public ASTVisitor { bool inGlobalScope = true; llvm::StringMap functionMap; llvm::StringMap structDeclarations; - std::vector expressionStack; + std::vector> expressionStack; StructDeclaration* currentBaseComposite = nullptr; llvm::ScopedHashTable @@ -95,7 +96,7 @@ class MLIRCodeGen : public ASTVisitor { void createVariable(shaderpulse::Type *, VariableDeclaration *); void insertEntryPoint(); - mlir::Value popExpressionStack(); + std::pair popExpressionStack(); mlir::Value currentBasePointer; Type* typeContext; }; diff --git a/lib/CodeGen/MLIRCodeGen.cpp b/lib/CodeGen/MLIRCodeGen.cpp index 987cf8a..9671792 100644 --- a/lib/CodeGen/MLIRCodeGen.cpp +++ b/lib/CodeGen/MLIRCodeGen.cpp @@ -31,8 +31,8 @@ void MLIRCodeGen::initModuleOp() { spirvModule = cast(Operation::create(state)); } -void MLIRCodeGen::dump() { - spirvModule.dump(); +void MLIRCodeGen::print() { + spirvModule.print(llvm::outs()); } bool MLIRCodeGen::verify() { return !failed(mlir::verify(spirvModule)); } @@ -54,8 +54,8 @@ void MLIRCodeGen::visit(TranslationUnit *unit) { insertEntryPoint(); } -Value MLIRCodeGen::popExpressionStack() { - Value val = expressionStack.back(); +std::pair MLIRCodeGen::popExpressionStack() { + auto val = expressionStack.back(); expressionStack.pop_back(); return val; } @@ -64,12 +64,12 @@ void MLIRCodeGen::visit(BinaryExpression *binExp) { binExp->getLhs()->accept(this); binExp->getRhs()->accept(this); - Value rhs = popExpressionStack(); - Value lhs = popExpressionStack(); + std::pair rhsPair = popExpressionStack(); + std::pair lhsPair = popExpressionStack(); - // TODO: handle other types than float. Need to figure out the expression this - // BinaryExpression is part of to know what kind of spirv op to use. (float, - // int?) + Value rhs = rhsPair.second; + Value lhs = lhsPair.second; + Type* typeContext = lhsPair.first; // TODO: implement source location auto loc = builder.getUnknownLoc(); @@ -82,7 +82,7 @@ void MLIRCodeGen::visit(BinaryExpression *binExp) { } else { val = builder.create(loc, lhs, rhs); } - expressionStack.push_back(val); + expressionStack.push_back(std::make_pair(typeContext, val)); break; case BinaryOperator::Sub: if (typeContext->isIntLike()) { @@ -91,7 +91,7 @@ void MLIRCodeGen::visit(BinaryExpression *binExp) { val = builder.create(loc, lhs, rhs); } - expressionStack.push_back(val); + expressionStack.push_back(std::make_pair(typeContext, val)); break; case BinaryOperator::Mul: if (typeContext->isIntLike()) { @@ -100,7 +100,7 @@ void MLIRCodeGen::visit(BinaryExpression *binExp) { val = builder.create(loc, lhs, rhs); } - expressionStack.push_back(val); + expressionStack.push_back(std::make_pair(typeContext, val)); break; case BinaryOperator::Div: if (typeContext->isUintLike()) { @@ -111,7 +111,7 @@ void MLIRCodeGen::visit(BinaryExpression *binExp) { val = builder.create(loc, lhs, rhs); } - expressionStack.push_back(val); + expressionStack.push_back(std::make_pair(typeContext, val)); break; case BinaryOperator::Mod: if (typeContext->isIntLike()) { @@ -120,62 +120,94 @@ void MLIRCodeGen::visit(BinaryExpression *binExp) { val = builder.create(loc, lhs, rhs); } - expressionStack.push_back(val); + expressionStack.push_back(std::make_pair(typeContext, val)); break; case BinaryOperator::ShiftLeft: val = builder.create(loc, lhs, rhs); - expressionStack.push_back(val); + expressionStack.push_back(std::make_pair(typeContext, val)); break; case BinaryOperator::ShiftRight: val = builder.create(loc, lhs, rhs); - expressionStack.push_back(val); + expressionStack.push_back(std::make_pair(typeContext, val)); break; case BinaryOperator::Lt: - val = builder.create(loc, lhs, rhs); - expressionStack.push_back(val); + if (typeContext->isFloatLike()) { + val = builder.create(loc, lhs, rhs); + } else if (typeContext->isUintLike()) { + val = builder.create(loc, lhs, rhs); + } else { + val = builder.create(loc, lhs, rhs); + } + expressionStack.push_back(std::make_pair(typeContext, val)); break; case BinaryOperator::Gt: - val = builder.create(loc, lhs, rhs); - expressionStack.push_back(val); + if (typeContext->isFloatLike()) { + val = builder.create(loc, lhs, rhs); + } else if (typeContext->isUintLike()) { + val = builder.create(loc, lhs, rhs); + } else { + val = builder.create(loc, lhs, rhs); + } + expressionStack.push_back(std::make_pair(typeContext, val)); break; case BinaryOperator::LtEq: - val = builder.create(loc, lhs, rhs); - expressionStack.push_back(val); + if (typeContext->isFloatLike()) { + val = builder.create(loc, lhs, rhs); + } else if (typeContext->isUintLike()) { + val = builder.create(loc, lhs, rhs); + } else { + val = builder.create(loc, lhs, rhs); + } + expressionStack.push_back(std::make_pair(typeContext, val)); break; case BinaryOperator::GtEq: - val = builder.create(loc, lhs, rhs); - expressionStack.push_back(val); + if (typeContext->isFloatLike()) { + val = builder.create(loc, lhs, rhs); + } else if (typeContext->isUintLike()) { + val = builder.create(loc, lhs, rhs); + } else { + val = builder.create(loc, lhs, rhs); + } + expressionStack.push_back(std::make_pair(typeContext, val)); break; case BinaryOperator::Eq: - val = builder.create(loc, lhs, rhs); - expressionStack.push_back(val); + if (typeContext->isFloatLike()) { + val = builder.create(loc, lhs, rhs); + } else { + val = builder.create(loc, lhs, rhs); + } + expressionStack.push_back(std::make_pair(typeContext, val)); break; case BinaryOperator::Neq: - val = builder.create(loc, lhs, rhs); - expressionStack.push_back(val); + if (typeContext->isFloatLike()) { + val = builder.create(loc, lhs, rhs); + } else { + val = builder.create(loc, lhs, rhs); + } + expressionStack.push_back(std::make_pair(typeContext, val)); break; case BinaryOperator::BitAnd: val = builder.create(loc, lhs, rhs); - expressionStack.push_back(val); + expressionStack.push_back(std::make_pair(typeContext, val)); break; case BinaryOperator::BitXor: val = builder.create(loc, lhs, rhs); - expressionStack.push_back(val); + expressionStack.push_back(std::make_pair(typeContext, val)); break; case BinaryOperator::BitIor: val = builder.create(loc, lhs, rhs); - expressionStack.push_back(val); + expressionStack.push_back(std::make_pair(typeContext, val)); break; case BinaryOperator::LogAnd: val = builder.create(loc, lhs, rhs); - expressionStack.push_back(val); + expressionStack.push_back(std::make_pair(typeContext, val)); break; case BinaryOperator::LogXor: // TODO: not implemented in current spirv dialect break; case BinaryOperator::LogOr: val = builder.create(loc, lhs, rhs); - expressionStack.push_back(val); + expressionStack.push_back(std::make_pair(typeContext, val)); break; } } @@ -185,18 +217,18 @@ void MLIRCodeGen::visit(ConditionalExpression *condExp) { condExp->getTruePart()->accept(this); condExp->getCondition()->accept(this); - Value condition = popExpressionStack(); - Value truePart = popExpressionStack(); - Value falsePart = popExpressionStack(); + std::pair condition = popExpressionStack(); + std::pair truePart = popExpressionStack(); + std::pair falsePart = popExpressionStack(); Value res = builder.create( builder.getUnknownLoc(), - /* Harcoded, fix me */ mlir::FloatType::getF32(&context), - condition, - truePart, - falsePart); + convertShaderPulseType(&context, truePart.first, structDeclarations), + condition.second, + truePart.second, + falsePart.second); - expressionStack.push_back(res); + expressionStack.push_back(std::make_pair(truePart.first, res)); } void MLIRCodeGen::visit(ForStatement *forStmt) { @@ -213,12 +245,12 @@ void MLIRCodeGen::visit(InitializerExpression *initExp) { builder.getUnknownLoc(), type, IntegerAttr::get(type, APInt(32, 0, true))); - expressionStack.push_back(val); + expressionStack.push_back(std::make_pair(nullptr, val)); } void MLIRCodeGen::visit(UnaryExpression *unExp) { unExp->getExpression()->accept(this); - Value rhs = popExpressionStack(); + std::pair rhs = popExpressionStack(); auto loc = builder.getUnknownLoc(); Value val; @@ -234,16 +266,20 @@ void MLIRCodeGen::visit(UnaryExpression *unExp) { expressionStack.push_back(rhs); break; case UnaryOperator::Dash: - val = builder.create(loc, rhs); - expressionStack.push_back(val); + if (rhs.first->isFloatLike()) { + val = builder.create(loc, rhs.second); + } else { + val = builder.create(loc, rhs.second); + } + expressionStack.push_back(std::make_pair(rhs.first, val)); break; case UnaryOperator::Bang: - val = builder.create(loc, rhs); - expressionStack.push_back(val); + val = builder.create(loc, rhs.second); + expressionStack.push_back(std::make_pair(rhs.first, val)); break; case UnaryOperator::Tilde: - val = builder.create(loc, rhs); - expressionStack.push_back(val); + val = builder.create(loc, rhs.second); + expressionStack.push_back(std::make_pair(rhs.first, val)); break; } } @@ -252,8 +288,6 @@ void MLIRCodeGen::declare(SymbolTableEntry entry) { if (symbolTable.count(entry.variable->getIdentifierName())) return; - - std::cout << "Declaring " << entry.variable->getIdentifierName() << std::endl; symbolTable.insert(entry.variable->getIdentifierName(), entry); } @@ -286,7 +320,7 @@ void MLIRCodeGen::createVariable(shaderpulse::Type *type, Operation *initializerOp = nullptr; if (expressionStack.size() > 0) { - Value val = popExpressionStack(); + Value val = popExpressionStack().second; initializerOp = val.getDefiningOp(); } @@ -309,14 +343,13 @@ void MLIRCodeGen::createVariable(shaderpulse::Type *type, // builder.getUnitAttr()); } else { if (varDecl->getInitialzerExpression()) { - std::cout << "Accept init" << std::endl; varDecl->getInitialzerExpression()->accept(this); } Value val; if (expressionStack.size() > 0) { - val = popExpressionStack(); + val = popExpressionStack().second; } spirv::PointerType ptrType = spirv::PointerType::get( @@ -346,7 +379,7 @@ void MLIRCodeGen::visit(WhileStatement *whileStmt) { whileStmt->getCondition()->accept(this); - auto conditionOp = popExpressionStack(); + auto conditionOp = popExpressionStack().second; auto loc = builder.getUnknownLoc(); @@ -392,7 +425,7 @@ void MLIRCodeGen::visit(ConstructorExpression *constructorExp) { if (constructorExp->getArguments().size() > 0) { for (auto &arg : constructorExp->getArguments()) { arg->accept(this); - operands.push_back(popExpressionStack()); + operands.push_back(popExpressionStack().second); } } @@ -403,7 +436,7 @@ void MLIRCodeGen::visit(ConstructorExpression *constructorExp) { if (structDeclarations.find(structName) != structDeclarations.end()) { mlir::Value val = builder.create( builder.getUnknownLoc(), convertShaderPulseType(&context, constructorType, structDeclarations), operands); - expressionStack.push_back(val); + expressionStack.push_back(std::make_pair(constructorType, val)); } break; @@ -413,7 +446,7 @@ void MLIRCodeGen::visit(ConstructorExpression *constructorExp) { case TypeKind::Array: { mlir::Value val = builder.create( builder.getUnknownLoc(), convertShaderPulseType(&context, constructorType, structDeclarations), operands); - expressionStack.push_back(val); + expressionStack.push_back(std::make_pair(constructorType, val)); break; } @@ -438,7 +471,7 @@ void MLIRCodeGen::visit(ConstructorExpression *constructorExp) { mlir::Value val = builder.create( builder.getUnknownLoc(), convertShaderPulseType(&context, constructorType, structDeclarations), columnVectors); - expressionStack.push_back(val); + expressionStack.push_back(std::make_pair(constructorType, val)); break; } @@ -451,29 +484,31 @@ void MLIRCodeGen::visit(ConstructorExpression *constructorExp) { void MLIRCodeGen::visit(ArrayAccessExpression *arrayAccess) { auto array = arrayAccess->getArray(); array->accept(this); - Value mlirArray = popExpressionStack(); + std::pair mlirArray = popExpressionStack(); + Type* elementType = dynamic_cast(mlirArray.first)->getElementType(); std::vector indices; for (auto &access : arrayAccess->getAccessChain()) { access->accept(this); - indices.push_back(popExpressionStack()); + indices.push_back(popExpressionStack().second); } - Value accessChain = builder.create(builder.getUnknownLoc(), mlirArray, indices); + Value accessChain = builder.create(builder.getUnknownLoc(), mlirArray.second, indices); if (arrayAccess->isLhs()) { - expressionStack.push_back(accessChain); + expressionStack.push_back(std::make_pair(elementType, accessChain)); } else { - expressionStack.push_back(builder.create(builder.getUnknownLoc(), accessChain)->getResult(0)); + expressionStack.push_back(std::make_pair(elementType, builder.create(builder.getUnknownLoc(), accessChain)->getResult(0))); } } void MLIRCodeGen::visit(MemberAccessExpression *memberAccess) { auto baseComposite = memberAccess->getBaseComposite(); baseComposite->accept(this); - Value baseCompositeValue = popExpressionStack(); + Value baseCompositeValue = popExpressionStack().second; std::vector memberIndices; std::vector memberIndicesAcc; + Type* memberType; if (currentBaseComposite) { for (auto &member : memberAccess->getMembers()) { @@ -489,15 +524,17 @@ void MLIRCodeGen::visit(MemberAccessExpression *memberAccess) { currentBaseComposite = structDeclarations[structName]; } } + + memberType = memberIndexPair.second->getType(); } } if (memberAccess->isLhs()) { Value accessChain = builder.create(builder.getUnknownLoc(), baseCompositeValue, memberIndicesAcc); - expressionStack.push_back(accessChain); + expressionStack.push_back(std::make_pair(memberType, accessChain)); } else { Value compositeElement = builder.create(builder.getUnknownLoc(), baseCompositeValue, memberIndices); - expressionStack.push_back(compositeElement); + expressionStack.push_back(std::make_pair(memberType, compositeElement)); } } } @@ -518,7 +555,7 @@ void MLIRCodeGen::visit(IfStatement *ifStmt) { auto loc = builder.getUnknownLoc(); ifStmt->getCondition()->accept(this); - Value condition = popExpressionStack(); + Value condition = popExpressionStack().second; auto selectionOp = builder.create(loc, spirv::SelectionControl::None); selectionOp.addMergeBlock(); @@ -563,8 +600,8 @@ void MLIRCodeGen::visit(AssignmentExpression *assignmentExp) { assignmentExp->getUnaryExpression()->accept(this); assignmentExp->getExpression()->accept(this); - Value val = popExpressionStack(); - Value ptr = popExpressionStack(); + Value val = popExpressionStack().second; + Value ptr = popExpressionStack().second; builder.create(builder.getUnknownLoc(), ptr, val); } @@ -584,7 +621,7 @@ void MLIRCodeGen::visit(CallExpression *callExp) { if (callExp->getArguments().size() > 0) { for (auto &arg : callExp->getArguments()) { arg->accept(this); - operands.push_back(popExpressionStack()); + operands.push_back(popExpressionStack().second); } } @@ -594,7 +631,8 @@ void MLIRCodeGen::visit(CallExpression *callExp) { builder.getUnknownLoc(), calledFunc.getFunctionType().getResults(), SymbolRefAttr::get(&context, calledFunc.getSymName()), operands); - expressionStack.push_back(funcCall.getResult(0)); + // TODO: get return type of callee + expressionStack.push_back(std::make_pair(nullptr, funcCall.getResult(0))); } else { std::cout << "Function not found." << callExp->getFunctionName() << std::endl; @@ -602,11 +640,9 @@ void MLIRCodeGen::visit(CallExpression *callExp) { } void MLIRCodeGen::visit(VariableExpression *varExp) { - std::cout << "Looking up " << varExp->getName() << std::endl; auto entry = symbolTable.lookup(varExp->getName()); if (entry.variable) { - std::cout << "Looked up and found " << varExp->getName() << std::endl; Value val; if (entry.isGlobal) { @@ -629,7 +665,7 @@ void MLIRCodeGen::visit(VariableExpression *varExp) { } } - expressionStack.push_back(val); + expressionStack.push_back(std::make_pair(entry.variable->getType(), val)); } else { std::cout << "Unable to find variable: " << varExp->getName() << std::endl; } @@ -641,7 +677,7 @@ void MLIRCodeGen::visit(IntegerConstantExpression *intConstExp) { builder.getUnknownLoc(), type, IntegerAttr::get(type, APInt(32, intConstExp->getVal(), true))); - expressionStack.push_back(val); + expressionStack.push_back(std::make_pair(intConstExp->getType(), val)); } void MLIRCodeGen::visit(UnsignedIntegerConstantExpression *uintConstExp) { @@ -650,7 +686,7 @@ void MLIRCodeGen::visit(UnsignedIntegerConstantExpression *uintConstExp) { builder.getUnknownLoc(), type, IntegerAttr::get(type, APInt(32, uintConstExp->getVal(), false))); - expressionStack.push_back(val); + expressionStack.push_back(std::make_pair(uintConstExp->getType(), val)); } void MLIRCodeGen::visit(FloatConstantExpression *floatConstExp) { @@ -659,7 +695,7 @@ void MLIRCodeGen::visit(FloatConstantExpression *floatConstExp) { builder.getUnknownLoc(), type, FloatAttr::get(type, APFloat(floatConstExp->getVal()))); - expressionStack.push_back(val); + expressionStack.push_back(std::make_pair(floatConstExp->getType(), val)); } void MLIRCodeGen::visit(DoubleConstantExpression *doubleConstExp) { @@ -668,7 +704,7 @@ void MLIRCodeGen::visit(DoubleConstantExpression *doubleConstExp) { builder.getUnknownLoc(), type, FloatAttr::get(type, APFloat(doubleConstExp->getVal()))); - expressionStack.push_back(val); + expressionStack.push_back(std::make_pair(doubleConstExp->getType(), val)); } void MLIRCodeGen::visit(BoolConstantExpression *boolConstExp) { @@ -677,7 +713,7 @@ void MLIRCodeGen::visit(BoolConstantExpression *boolConstExp) { builder.getUnknownLoc(), type, IntegerAttr::get(type, APInt(1, boolConstExp->getVal()))); - expressionStack.push_back(val); + expressionStack.push_back(std::make_pair(boolConstExp->getType(), val)); } void MLIRCodeGen::visit(ReturnStatement *returnStmt) { @@ -687,7 +723,7 @@ void MLIRCodeGen::visit(ReturnStatement *returnStmt) { if (expressionStack.empty()) { builder.create(builder.getUnknownLoc()); } else { - Value val = popExpressionStack(); + Value val = popExpressionStack().second; builder.create(builder.getUnknownLoc(), val); } } diff --git a/standalone/shaderpulse.cpp b/standalone/shaderpulse.cpp index ea7d9aa..26ed621 100644 --- a/standalone/shaderpulse.cpp +++ b/standalone/shaderpulse.cpp @@ -7,27 +7,6 @@ #include #include -static std::string functionDeclarationTestString = -R"( - uniform highp float a; - uniform int b; - uint c; - vec3 d; - mat2x2 e; - - float foo() { - return 1.0; - } - - float myFunc(vec2 arg1, bool arg2) { - float f; - float g; - f = 1.0; - g = f + 2.0; - return g + foo(); - } -)"; - using namespace shaderpulse; using namespace shaderpulse::ast; using namespace shaderpulse::lexer; @@ -39,9 +18,10 @@ int main(int argc, char** argv) { std::cout << "Missing input file." << std::endl; return -1; } - + bool printAST = false; bool codeGen = true; + bool analyze = true; for (size_t i = 2; i < argc; i++) { std::string arg = argv[i]; @@ -50,13 +30,14 @@ int main(int argc, char** argv) { printAST = true; } else if (arg == "--no-codegen") { codeGen = false; + } else if (arg == "--no-analyze") { + analyze = false; } else { std::cout << "Unrecognized argument: '" << arg << "'." << std::endl; return -1; } } - std::ifstream glslIn(argv[1]); std::stringstream shaderCodeBuffer; shaderCodeBuffer << glslIn.rdbuf(); @@ -65,7 +46,6 @@ int main(int argc, char** argv) { auto preprocessor = preprocessor::Preprocessor(sourceCode); preprocessor.process(); auto processedCode = preprocessor.getProcessedSource(); - std::cout << processedCode; auto lexer = Lexer(processedCode); auto resp = lexer.lexCharacterStream(); if (!resp.has_value()) { @@ -83,14 +63,17 @@ int main(int argc, char** argv) { } if (codeGen) { - auto analyzer = SemanticAnalyzer(); - translationUnit->accept(&analyzer); + if (analyze) { + auto analyzer = SemanticAnalyzer(); + translationUnit->accept(&analyzer); + } + auto mlirCodeGen = codegen::MLIRCodeGen(); translationUnit->accept(&mlirCodeGen); - mlirCodeGen.dump(); + mlirCodeGen.print(); - if (mlirCodeGen.verify()) { - std::cout << "SPIR-V module verified" << std::endl; + if (!mlirCodeGen.verify()) { + std::cout << "Error verifying the SPIR-V module" << std::endl; } } diff --git a/test/CodeGen/binary_expressions.glsl b/test/CodeGen/binary_expressions.glsl new file mode 100644 index 0000000..1dea6a9 --- /dev/null +++ b/test/CodeGen/binary_expressions.glsl @@ -0,0 +1,15 @@ +void main() { + // CHECK: %0 = spirv.IAdd %cst1_si32, %cst2_si32 : si3 + int a = 1 + 2; + + // CHECK: %2 = spirv.IAdd %cst1_ui32, %cst2_ui32 : ui32 + uint c = 1u + 2u; + + // CHECK: %4 = spirv.FAdd %cst_f32, %cst_f32_0 : f32 + float b = 1.0f + 2.0f; + + // CHECK: %6 = spirv.FAdd %cst_f64, %cst_f64_1 : f64 + double d = 1.0lf + 2.0lf; + + return; +} diff --git a/test/CodeGen/run_test.sh b/test/CodeGen/run_test.sh new file mode 100755 index 0000000..e60b8f6 --- /dev/null +++ b/test/CodeGen/run_test.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +set -e + +SHADERPULSE="../../build/shaderpulse-standalone" +FILECHECK="../../llvm-project/build/bin/FileCheck" + +if [ ! -x "$SHADERPULSE" ]; then + echo "Error: shaderpulse binary not found at $SHADERPULSE" + exit 1 +fi + +if [ ! -x "$FILECHECK" ]; then + echo "Error: FileCheck binary not found at $FILECHECK" + exit 1 +fi + +for TEST_FILE in *.glsl; do + if [ ! -f "$TEST_FILE" ]; then + echo "No .glsl files found in the current directory." + exit 1 + fi + + echo "Running test on $TEST_FILE" + $SHADERPULSE "$TEST_FILE" --no-analyze | $FILECHECK "$TEST_FILE" + + if [ $? -eq 0 ]; then + echo "Test passed for $TEST_FILE" + else + echo "Test failed for $TEST_FILE" + exit 1 + fi +done