diff --git a/include/CodeGen/MLIRCodeGen.h b/include/CodeGen/MLIRCodeGen.h index 6f2d6b3..45a1f42 100644 --- a/include/CodeGen/MLIRCodeGen.h +++ b/include/CodeGen/MLIRCodeGen.h @@ -90,8 +90,10 @@ class MLIRCodeGen : public ASTVisitor { std::vector expressionStack; StructDeclaration* currentBaseComposite = nullptr; mlir::Operation *execModeOp = nullptr; - mlir::spirv::VariableOp breakGate; + std::vector breakStack; + std::vector continueStack; bool breakDetected = false; + bool continueDetected = false; llvm::ScopedHashTable symbolTable; @@ -109,6 +111,7 @@ class MLIRCodeGen : public ASTVisitor { bool callBuiltIn(CallExpression* exp); void createBuiltinComputeVar(const std::string &varName, const std::string &mlirName); void generateLoop(Statement* initStmt, Expression* conditionExpr, Expression* inductionExpr, Statement* bodyStmt); + void setBoolVar(mlir::spirv::VariableOp var, bool val); mlir::Value load(mlir::Value); mlir::Value popExpressionStack(); mlir::Value currentBasePointer; diff --git a/lib/CodeGen/MLIRCodeGen.cpp b/lib/CodeGen/MLIRCodeGen.cpp index 3ee0855..b45b983 100644 --- a/lib/CodeGen/MLIRCodeGen.cpp +++ b/lib/CodeGen/MLIRCodeGen.cpp @@ -1120,14 +1120,19 @@ void MLIRCodeGen::visit(ReturnStatement *returnStmt) { } void MLIRCodeGen::visit(BreakStatement *breakStmt) { - auto type = builder.getIntegerType(1); - mlir::Value constTrue = builder.create(builder.getUnknownLoc(), type, IntegerAttr::get(type, APInt(1, true))); - builder.create(builder.getUnknownLoc(), breakGate, constTrue); + setBoolVar(breakStack.back(), true); breakDetected = true; } void MLIRCodeGen::visit(ContinueStatement *continueStmt) { + setBoolVar(continueStack.back(), true); + continueDetected = true; +} +void MLIRCodeGen::setBoolVar(mlir::spirv::VariableOp var, bool val) { + auto type = builder.getIntegerType(1); + mlir::Value constTrue = builder.create(builder.getUnknownLoc(), type, IntegerAttr::get(type, APInt(1, val))); + builder.create(builder.getUnknownLoc(), var, constTrue); } void MLIRCodeGen::visit(DiscardStatement *discardStmt) { @@ -1211,8 +1216,14 @@ void MLIRCodeGen::generateLoop(Statement* initStmt, Expression* conditionExpr, E mlir::Type boolType = mlir::IntegerType::get(&context, 1, mlir::IntegerType::Signless); spirv::PointerType ptrType = spirv::PointerType::get(boolType, mlir::spirv::StorageClass::Function); - breakGate = builder.create( - builder.getUnknownLoc(), ptrType, spirv::StorageClass::Function, nullptr); + breakStack.push_back( + builder.create( + builder.getUnknownLoc(), ptrType, spirv::StorageClass::Function, nullptr) + ); + continueStack.push_back( + builder.create( + builder.getUnknownLoc(), ptrType, spirv::StorageClass::Function, nullptr) + ); if (initStmt) { initStmt->accept(this); @@ -1229,7 +1240,11 @@ void MLIRCodeGen::generateLoop(Statement* initStmt, Expression* conditionExpr, E // Insert the body. Block *body = new Block(); loopOp.getBody().getBlocks().insert(std::next(loopOp.getBody().begin(), 2), body); - int i = 2; + + // Insert the continue block. + Block *continueBlock = new Block(); + loopOp.getBody().getBlocks().insert(std::next(loopOp.getBody().begin(), 3), continueBlock); + // Emit the entry code. Block *entry = loopOp.getEntryBlock(); builder.setInsertionPointToEnd(entry); @@ -1245,31 +1260,61 @@ void MLIRCodeGen::generateLoop(Statement* initStmt, Expression* conditionExpr, E builder.setInsertionPointToStart(body); - // Detect break flag + // Detect break/continue flag + int postGateBlockInsertionPoint = 2; + if (auto body = dynamic_cast(bodyStmt)) { for (auto &stmt : body->getStatements()) { stmt->accept(this); - if (breakDetected) { - Block *postBreakBlock = new Block(); - loopOp.getBody().getBlocks().insert(std::next(loopOp.getBody().begin(), ++i), postBreakBlock); - builder.create(loc, load(breakGate), merge, ArrayRef(), postBreakBlock, ArrayRef()); - builder.setInsertionPointToStart(postBreakBlock); + if (breakDetected || continueDetected) { + Block *postGateBlock = new Block(); + loopOp.getBody().getBlocks().insert(std::next(loopOp.getBody().begin(), ++postGateBlockInsertionPoint), postGateBlock); + + if (breakDetected && continueDetected) { + Block *breakCheckBlock = new Block(); + auto continueGate = continueStack.back(); + auto breakGate = breakStack.back(); + loopOp.getBody().getBlocks().insert(std::next(loopOp.getBody().begin(), postGateBlockInsertionPoint), breakCheckBlock); + builder.create(loc, load(continueGate), loopOp.getContinueBlock(), ArrayRef(), breakCheckBlock, ArrayRef()); + builder.setInsertionPointToStart(breakCheckBlock); + builder.create(loc, load(breakGate), merge, ArrayRef(), postGateBlock, ArrayRef()); + } else if (continueDetected) { + auto continueGate = continueStack.back(); + builder.create(loc, load(continueGate), loopOp.getContinueBlock(), ArrayRef(), postGateBlock, ArrayRef()); + } else if (breakDetected) { + auto breakGate = breakStack.back(); + builder.create(loc, load(breakGate), merge, ArrayRef(), postGateBlock, ArrayRef()); + } + + builder.setInsertionPointToStart(postGateBlock); + + if (breakDetected) { + setBoolVar(breakStack.back(), false); + } + + if (continueDetected) { + setBoolVar(continueStack.back(), false); + } breakDetected = false; + continueDetected = false; } } } else { bodyStmt->accept(this); } + builder.create(loc, loopOp.getContinueBlock()); + builder.setInsertionPointToEnd(loopOp.getContinueBlock()); + if (inductionExpr) { inductionExpr->accept(this); } - Block *continueBlock = loopOp.getContinueBlock(); - builder.setInsertionPointToEnd(continueBlock); builder.create(loc, header); builder.setInsertionPointToEnd(restoreInsertionBlock); + breakStack.pop_back(); + continueStack.pop_back(); } mlir::Value MLIRCodeGen::load(mlir::Value val) { diff --git a/test/CodeGen/cf_loops_for.glsl b/test/CodeGen/cf_loops_for.glsl index ffdf981..9e573a0 100644 --- a/test/CodeGen/cf_loops_for.glsl +++ b/test/CodeGen/cf_loops_for.glsl @@ -1,36 +1,29 @@ void main() { - // CHECK: %cst0_si32 = spirv.Constant 0 : si32 - // CHECK-NEXT: %0 = spirv.Variable : !spirv.ptr - // CHECK-NEXT: spirv.Store "Function" %0, %cst0_si32 : si32 - // CHECK-NEXT: spirv.mlir.loop { - // CHECK-NEXT: spirv.Branch ^bb1 - // CHECK-NEXT: ^bb1: // 2 preds: ^bb0, ^bb2 + // CHECK: %2 = spirv.Variable : !spirv.ptr + // CHECK: spirv.mlir.loop { + // CHECK-NEXT: spirv.Branch ^bb1 + // CHECK-NEXT: ^bb1: // 2 preds: ^bb0, ^bb3 // CHECK-NEXT: %cst10_si32 = spirv.Constant 10 : si32 - // CHECK-NEXT: %2 = spirv.Load "Function" %0 : si32 - // CHECK-NEXT: %3 = spirv.SLessThan %2, %cst10_si32 : si32 - // CHECK-NEXT: spirv.BranchConditional %3, ^bb2, ^bb3 - // CHECK-NEXT: ^bb2: // pred: ^bb1 + // CHECK-NEXT: %3 = spirv.Load "Function" %2 : si32 + // CHECK-NEXT: %4 = spirv.SLessThan %3, %cst10_si32 : si32 + // CHECK-NEXT: spirv.BranchConditional %4, ^bb2, ^bb4 + // CHECK-NEXT:^bb2: // pred: ^bb1 // CHECK-NEXT: %cst1_si32 = spirv.Constant 1 : si32 - // CHECK-NEXT: %4 = spirv.Load "Function" %0 : si32 - // CHECK-NEXT: %5 = spirv.IAdd %4, %cst1_si32 : si32 - // CHECK-NEXT: %6 = spirv.Variable : !spirv.ptr - // CHECK-NEXT: spirv.Store "Function" %6, %5 : si32 - // CHECK-NEXT: %7 = spirv.Load "Function" %0 : si32 - // CHECK-NEXT: %cst1_si32_1 = spirv.Constant 1 : si32 - // CHECK-NEXT: %8 = spirv.IAdd %7, %cst1_si32_1 : si32 - // CHECK-NEXT: spirv.Store "Function" %0, %8 : si32 + // CHECK-NEXT: %5 = spirv.Load "Function" %2 : si32 + // CHECK-NEXT: %6 = spirv.IAdd %5, %cst1_si32 : si32 + // CHECK-NEXT: %7 = spirv.Variable : !spirv.ptr + // CHECK-NEXT: spirv.Store "Function" %7, %6 : si32 + // CHECK-NEXT: spirv.Branch ^bb3 + // CHECK-NEXT:^bb3: // pred: ^bb2 + // CHECK-NEXT: %8 = spirv.Load "Function" %2 : si32 + // CHECK-NEXT: %cst1_si32_0 = spirv.Constant 1 : si32 + // CHECK-NEXT: %9 = spirv.IAdd %8, %cst1_si32_0 : si32 + // CHECK-NEXT: spirv.Store "Function" %2, %9 : si32 // CHECK-NEXT: spirv.Branch ^bb1 - // CHECK-NEXT: ^bb3: // pred: ^bb1 + // CHECK-NEXT:^bb4: // pred: ^bb1 // CHECK-NEXT: spirv.mlir.merge - // CHECK-NEXT: } + // CHECK-NEXT:} for (int i = 0; i < 10; ++i) { int a = i + 1; } - - // TODO: file check embedded loops - for (int i = 0; i < 10; ++i) { - for (int j = 0; j < 20; ++j) { - int a = i + j; - } - } } \ No newline at end of file diff --git a/test/CodeGen/cf_loops_while.glsl b/test/CodeGen/cf_loops_while.glsl index e065e61..567429f 100644 --- a/test/CodeGen/cf_loops_while.glsl +++ b/test/CodeGen/cf_loops_while.glsl @@ -4,37 +4,22 @@ void main() { // CHECK: spirv.mlir.loop { // CHECK-NEXT: spirv.Branch ^bb1 - // CHECK-NEXT: ^bb1: // 2 preds: ^bb0, ^bb2 - // CHECK-NEXT: %2 = spirv.Load "Function" %0 : i1 - // CHECK-NEXT: spirv.BranchConditional %2, ^bb2, ^bb3 + // CHECK-NEXT: ^bb1: // 2 preds: ^bb0, ^bb3 + // CHECK-NEXT: %4 = spirv.Load "Function" %0 : i1 + // CHECK-NEXT: spirv.BranchConditional %4, ^bb2, ^bb4 // CHECK-NEXT: ^bb2: // pred: ^bb1 while (a) { int c = 2; int d = 3; - if (!a) { - return; - } - - if (a) { - int test = 2; - - if (!a) { - break; - } - } - - int blabla = 4; - - if (!a) { - break; - } - // CHECK: spirv.Store "Function" %1, %7 : si32 - // CHECK-NEXT: spirv.Branch ^bb1 + // CHECK: spirv.Store "Function" %1, %9 : si32 + // CHECK-NEXT: spirv.Branch ^bb3 b = c + d; } - // CHECK: ^bb3: // pred: ^bb1 - // CHECK-NEXT: spirv.mlir.merge + // CHECK-NEXT: ^bb3: // pred: ^bb2 + // CHECK-NEXT: spirv.Branch ^bb1 + // CHECK-NEXT: ^bb4: // pred: ^bb1 + // CHECK-NEXT: spirv.mlir.merge // CHECK-NEXT: } } diff --git a/test/CodeGen/cf_loops_while_break.glsl b/test/CodeGen/cf_loops_while_break.glsl new file mode 100644 index 0000000..d3e7e03 --- /dev/null +++ b/test/CodeGen/cf_loops_while_break.glsl @@ -0,0 +1,38 @@ +void main() { + // Hidden break/continue control vars + + // CHECK: %0 = spirv.Variable : !spirv.ptr + // CHECK-NEXT: %1 = spirv.Variable : !spirv.ptr + while (true) { + // CHECK: %cst1_si32 = spirv.Constant 1 : si32 + // CHECK-NEXT: %2 = spirv.Variable : !spirv.ptr + // CHECK-NEXT: spirv.Store "Function" %2, %cst1_si32 : si32 + int someVarBefore = 1; + + // CHECK: ^bb1: // pred: ^bb0 + // CHECK-NEXT: %true_2 = spirv.Constant true + // CHECK-NEXT: spirv.Store "Function" %0, %true_2 : i1 + if (true) { + break; + } + + // CHECK: spirv.mlir.merge + // CHECK-NEXT: } + // CHECK-NEXT: %3 = spirv.Load "Function" %0 : i1 + // CHECK-NEXT: spirv.BranchConditional %3, ^bb5, ^bb3 + + // CHECK: ^bb3: // pred: ^bb2 + // CHECK-NEXT: %false = spirv.Constant false + + // Reset break control var + + // CHECK-NEXT: spirv.Store "Function" %0, %false : i1 + // CHECK-NEXT: %cst1_si32_1 = spirv.Constant 1 : si32 + // CHECK-NEXT: %4 = spirv.Variable : !spirv.ptr + int someVarAfter = 1; + } + + // CHECK: ^bb5: // 2 preds: ^bb1, ^bb2 + // CHECK-NEXT: spirv.mlir.merge + // CHECK-NEXT: } +} diff --git a/test/CodeGen/cf_loops_while_continue.glsl b/test/CodeGen/cf_loops_while_continue.glsl new file mode 100644 index 0000000..5b8f376 --- /dev/null +++ b/test/CodeGen/cf_loops_while_continue.glsl @@ -0,0 +1,45 @@ +void main() { + // Hidden break/continue control vars + + // CHECK: %0 = spirv.Variable : !spirv.ptr + // CHECK-NEXT: %1 = spirv.Variable : !spirv.ptr + while (true) { + // CHECK: %cst1_si32 = spirv.Constant 1 : si32 + // CHECK-NEXT: %2 = spirv.Variable : !spirv.ptr + // CHECK-NEXT: spirv.Store "Function" %2, %cst1_si32 : si32 + int someVarBefore = 1; + + // CHECK: ^bb1: // pred: ^bb0 + // CHECK-NEXT: %true_3 = spirv.Constant true + // CHECK-NEXT: spirv.Store "Function" %1, %true_3 : i1 + if (true) { + continue; + // CHECK: ^bb2: // pred: ^bb0 + // CHECK-NEXT: %true_4 = spirv.Constant true + // CHECK-NEXT: spirv.Store "Function" %0, %true_4 : i1 + } else { + break; + } + + // CHECK: spirv.mlir.merge + // CHECK-NEXT: } + // CHECK-NEXT: %3 = spirv.Load "Function" %1 : i1 + // CHECK-NEXT: spirv.BranchConditional %3, ^bb5, ^bb4 + + // Reset continue/break control vars + // CHECK: ^bb3: // pred: ^bb4 + // CHECK-NEXT: %false = spirv.Constant false + // CHECK-NEXT: spirv.Store "Function" %0, %false : i1 + // CHECK-NEXT: %false_1 = spirv.Constant false + // CHECK-NEXT: spirv.Store "Function" %1, %false_1 : i1 + int someVarAfter = 1; + + // CHECK: ^bb4: // pred: ^bb2 + // CHECK-NEXT: %5 = spirv.Load "Function" %0 : i1 + // CHECK-NEXT: spirv.BranchConditional %5, ^bb6, ^bb3 + } + + // CHECK: ^bb6: // 2 preds: ^bb1, ^bb4 + // CHECK-NEXT: spirv.mlir.merge + // CHECK-NEXT: } +} diff --git a/test/CodeGen/scopes.glsl b/test/CodeGen/scopes.glsl index 7f2013b..e7c6b15 100644 --- a/test/CodeGen/scopes.glsl +++ b/test/CodeGen/scopes.glsl @@ -11,21 +11,21 @@ void main() { // CHECK: %1 = spirv.Load "Function" %0 : si32 // CHECK-NEXT: %2 = spirv.IEqual %1, %cst1_si32 : si32 if (a == 1) { - // CHECK: %5 = spirv.Variable : !spirv.ptr + // CHECK: %7 = spirv.Variable : !spirv.ptr int a; // CHECK: %cst2_si32_1 = spirv.Constant 2 : si32 - // CHECK-NEXT: spirv.Store "Function" %5, %cst2_si32_1 : si32 + // CHECK-NEXT: spirv.Store "Function" %7, %cst2_si32_1 : si32 a = 2; } else { - // CHECK: %6 = spirv.Variable : !spirv.ptr + // CHECK: %8 = spirv.Variable : !spirv.ptr int a; - + // CHECK: %cst3_si32 = spirv.Constant 3 : si32 - // CHECK-NEXT: spirv.Store "Function" %6, %cst3_si32 : si32 + // CHECK-NEXT: spirv.Store "Function" %8, %cst3_si32 : si32 a = 3; } - + // CHECK: %cst2_si32 = spirv.Constant 2 : si32 // CHECK-NEXT: spirv.Store "Function" %0, %cst2_si32 : si32 a = 2; @@ -36,14 +36,15 @@ void main() { * */ - // CHECK: %5 = spirv.Load "Function" %0 : si32 - // CHECK-NEXT: %6 = spirv.IEqual %5, %cst1_si32_1 : si32 + // CHECK: %7 = spirv.Load "Function" %0 : si32 + // CHECK-NEXT: %8 = spirv.IEqual %7, %cst1_si32_1 : si32 while (a == 1) { - // CHECK: %7 = spirv.Variable : !spirv.ptr + // CHECK: %9 = spirv.Variable : !spirv.ptr int a; + // CHECK: %cst5_si32 = spirv.Constant 5 : si32 - // CHECK-NEXT: spirv.Store "Function" %7, %cst5_si32 : si32 + // CHECK-NEXT: spirv.Store "Function" %9, %cst5_si32 : si32 a = 5; } @@ -57,22 +58,21 @@ void main() { * */ - // CHECK: %3 = spirv.Load "Function" %0 : si32 - // CHECK-NEXT: %4 = spirv.IEqual %3, %cst1_si32_0 : si32 + // CHECK: %5 = spirv.Load "Function" %0 : si32 + // CHECK-NEXT: %6 = spirv.IEqual %5, %cst1_si32_0 : si32 if (a == 1) { - // CHECK: %5 = spirv.Variable : !spirv.ptr + // CHECK: %7 = spirv.Variable : !spirv.ptr // CHECK-NEXT: %cst1_si32_1 = spirv.Constant 1 : si32 - // CHECK-NEXT: spirv.Store "Function" %5, %cst1_si32_1 : si32 + // CHECK-NEXT: spirv.Store "Function" %7, %cst1_si32_1 : si32 int a; a = 1; - - // CHECK: %6 = spirv.Load "Function" %5 : si32 - // CHECK-NEXT: %7 = spirv.IEqual %6, %cst2_si32_2 : si32 + // CHECK: %8 = spirv.Load "Function" %7 : si32 + // CHECK-NEXT: %9 = spirv.IEqual %8, %cst2_si32_2 : si32 if (a == 2) { - // CHECK: %8 = spirv.Variable : !spirv.ptr + // CHECK: %10 = spirv.Variable : !spirv.ptr // CHECK-NEXT: %cst2_si32_3 = spirv.Constant 2 : si32 - // CHECK-NEXT: spirv.Store "Function" %8, %cst2_si32_3 : si32 + // CHECK-NEXT: spirv.Store "Function" %10, %cst2_si32_3 : si32 int a; a = 2; }