diff --git a/include/CodeGen/MLIRCodeGen.h b/include/CodeGen/MLIRCodeGen.h index 3b69b35..1c88b37 100644 --- a/include/CodeGen/MLIRCodeGen.h +++ b/include/CodeGen/MLIRCodeGen.h @@ -111,7 +111,7 @@ class MLIRCodeGen : public ASTVisitor { void initBuiltinFuncMap(); bool callBuiltIn(CallExpression* exp); void createBuiltinComputeVar(const std::string &varName, const std::string &mlirName); - void generateLoop(Statement* initStmt, Expression* conditionExpr, Expression* inductionExpr, Statement* bodyStmt); + void generateLoop(Statement* initStmt, Expression* conditionExpr, Expression* inductionExpr, Statement* bodyStmt, bool isDoWhile = false); void setBoolVar(mlir::spirv::VariableOp var, bool val); mlir::Value load(mlir::Value); mlir::Value popExpressionStack(); diff --git a/lib/AST/PrinterASTVisitor.cpp b/lib/AST/PrinterASTVisitor.cpp index 976c96d..dc4b953 100644 --- a/lib/AST/PrinterASTVisitor.cpp +++ b/lib/AST/PrinterASTVisitor.cpp @@ -83,10 +83,12 @@ void PrinterASTVisitor::visit(WhileStatement *whileStmt) { void PrinterASTVisitor::visit(DoStatement *doStmt) { print("|-DoStatement " + loc(doStmt->getSourceLocation())); - doStmt->getCondition()->accept(this); indent(); doStmt->getBody()->accept(this); resetIndent(); + indent(); + doStmt->getCondition()->accept(this); + resetIndent(); } void PrinterASTVisitor::visit(IfStatement *ifStmt) { diff --git a/lib/CodeGen/MLIRCodeGen.cpp b/lib/CodeGen/MLIRCodeGen.cpp index ca569f6..262c1a2 100644 --- a/lib/CodeGen/MLIRCodeGen.cpp +++ b/lib/CodeGen/MLIRCodeGen.cpp @@ -943,7 +943,7 @@ void MLIRCodeGen::visit(InterfaceBlock *interfaceBlock) { } void MLIRCodeGen::visit(DoStatement *doStmt) { - // TODO: implement me + generateLoop(nullptr, doStmt->getCondition(), nullptr, doStmt->getBody(), /*isDoWhile*/ true); } void MLIRCodeGen::visit(IfStatement *ifStmt) { @@ -1230,7 +1230,7 @@ void MLIRCodeGen::visit(DefaultLabel *defaultLabel) {} void MLIRCodeGen::visit(CaseLabel *defaultLabel) {} -void MLIRCodeGen::generateLoop(Statement* initStmt, Expression* conditionExpr, Expression* inductionExpr, Statement* bodyStmt) { +void MLIRCodeGen::generateLoop(Statement* initStmt, Expression* conditionExpr, Expression* inductionExpr, Statement* bodyStmt, bool isDoWhile) { Block *restoreInsertionBlock = builder.getInsertionBlock(); SymbolTableScopeT varScope(symbolTable); @@ -1277,9 +1277,14 @@ void MLIRCodeGen::generateLoop(Statement* initStmt, Expression* conditionExpr, E builder.setInsertionPointToEnd(header); Block *merge = loopOp.getMergeBlock(); - conditionExpr->accept(this); - auto conditionOp = load(popExpressionStack()); - builder.create(loc, conditionOp, body, ArrayRef(), merge, ArrayRef()); + if (isDoWhile) { + builder.create(loc, &*std::next(loopOp.getBody().begin(), 2)); + } else { + conditionExpr->accept(this); + auto conditionOp = load(popExpressionStack()); + builder.create(loc, conditionOp, body, ArrayRef(), merge, ArrayRef()); + } + builder.setInsertionPointToStart(body); @@ -1332,7 +1337,14 @@ void MLIRCodeGen::generateLoop(Statement* initStmt, Expression* conditionExpr, E inductionExpr->accept(this); } - builder.create(loc, header); + if (isDoWhile) { + conditionExpr->accept(this); + auto conditionOp = load(popExpressionStack()); + builder.create(loc, conditionOp, header, ArrayRef(), merge, ArrayRef()); + } else { + builder.create(loc, header); + } + builder.setInsertionPointToEnd(restoreInsertionBlock); breakStack.pop_back(); continueStack.pop_back(); diff --git a/test/CodeGen/cf_do_while.glsl b/test/CodeGen/cf_do_while.glsl new file mode 100644 index 0000000..fd788a5 --- /dev/null +++ b/test/CodeGen/cf_do_while.glsl @@ -0,0 +1,29 @@ +void main() { + bool a = true; + int b; + + // Header jumps unconditionally to loop body + + // CHECK: spirv.mlir.loop { + // CHECK-NEXT: spirv.Branch ^bb1 + // CHECK-NEXT: ^bb1: // 2 preds: ^bb0, ^bb3 + // CHECK-NEXT: spirv.Branch ^bb2 + do { + int c = 2; + int d = 3; + b = c + d; + + // The condition check happens in the continue block + + // CHECK: ^bb3: // pred: ^bb2 + // CHECK-NEXT: %false_1 = spirv.Constant false + // CHECK-NEXT: spirv.Store "Function" %3, %false_1 : i1 + // CHECK-NEXT: %10 = spirv.Load "Function" %0 : i1 + // CHECK-NEXT: spirv.BranchConditional %10, ^bb1, ^bb4 + // CHECK-NEXT: ^bb4: // pred: ^bb3 + // CHECK-NEXT: spirv.mlir.merge + // CHECK-NEXT: } + } while(a); + + int someVarAfter = 12; +}