diff --git a/lib/CodeGen/MLIRCodeGen.cpp b/lib/CodeGen/MLIRCodeGen.cpp index b45b983..fa3dc57 100644 --- a/lib/CodeGen/MLIRCodeGen.cpp +++ b/lib/CodeGen/MLIRCodeGen.cpp @@ -1268,27 +1268,31 @@ void MLIRCodeGen::generateLoop(Statement* initStmt, Expression* conditionExpr, E stmt->accept(this); if (breakDetected || continueDetected) { - Block *postGateBlock = new Block(); - loopOp.getBody().getBlocks().insert(std::next(loopOp.getBody().begin(), ++postGateBlockInsertionPoint), postGateBlock); - if (breakDetected && continueDetected) { - Block *breakCheckBlock = new Block(); auto continueGate = continueStack.back(); auto breakGate = breakStack.back(); - loopOp.getBody().getBlocks().insert(std::next(loopOp.getBody().begin(), postGateBlockInsertionPoint), breakCheckBlock); + Block *breakCheckBlock = new Block(); + loopOp.getBody().getBlocks().insert(std::next(loopOp.getBody().begin(), ++postGateBlockInsertionPoint), breakCheckBlock); builder.create(loc, load(continueGate), loopOp.getContinueBlock(), ArrayRef(), breakCheckBlock, ArrayRef()); + Block *postGateBlock = new Block(); + loopOp.getBody().getBlocks().insert(std::next(loopOp.getBody().begin(), ++postGateBlockInsertionPoint), postGateBlock); builder.setInsertionPointToStart(breakCheckBlock); builder.create(loc, load(breakGate), merge, ArrayRef(), postGateBlock, ArrayRef()); + builder.setInsertionPointToStart(postGateBlock); } else if (continueDetected) { auto continueGate = continueStack.back(); + Block *postGateBlock = new Block(); + loopOp.getBody().getBlocks().insert(std::next(loopOp.getBody().begin(), ++postGateBlockInsertionPoint), postGateBlock); builder.create(loc, load(continueGate), loopOp.getContinueBlock(), ArrayRef(), postGateBlock, ArrayRef()); + builder.setInsertionPointToStart(postGateBlock); } else if (breakDetected) { auto breakGate = breakStack.back(); + Block *postGateBlock = new Block(); + loopOp.getBody().getBlocks().insert(std::next(loopOp.getBody().begin(), ++postGateBlockInsertionPoint), postGateBlock); builder.create(loc, load(breakGate), merge, ArrayRef(), postGateBlock, ArrayRef()); + builder.setInsertionPointToStart(postGateBlock); } - builder.setInsertionPointToStart(postGateBlock); - if (breakDetected) { setBoolVar(breakStack.back(), false); } diff --git a/test/CodeGen/cf_loops_while_continue.glsl b/test/CodeGen/cf_loops_while_continue.glsl index 5b8f376..da83235 100644 --- a/test/CodeGen/cf_loops_while_continue.glsl +++ b/test/CodeGen/cf_loops_while_continue.glsl @@ -24,22 +24,21 @@ void main() { // CHECK: spirv.mlir.merge // CHECK-NEXT: } // CHECK-NEXT: %3 = spirv.Load "Function" %1 : i1 - // CHECK-NEXT: spirv.BranchConditional %3, ^bb5, ^bb4 + // CHECK-NEXT: spirv.BranchConditional %3, ^bb5, ^bb3 + // CHECK-NEXT: ^bb3: // pred: ^bb2 + // CHECK-NEXT: %4 = spirv.Load "Function" %0 : i1 + // CHECK-NEXT: spirv.BranchConditional %4, ^bb6, ^bb4 // Reset continue/break control vars - // CHECK: ^bb3: // pred: ^bb4 + // CHECK: ^bb4: // pred: ^bb3 // CHECK-NEXT: %false = spirv.Constant false // CHECK-NEXT: spirv.Store "Function" %0, %false : i1 // CHECK-NEXT: %false_1 = spirv.Constant false // CHECK-NEXT: spirv.Store "Function" %1, %false_1 : i1 int someVarAfter = 1; - - // CHECK: ^bb4: // pred: ^bb2 - // CHECK-NEXT: %5 = spirv.Load "Function" %0 : i1 - // CHECK-NEXT: spirv.BranchConditional %5, ^bb6, ^bb3 } - // CHECK: ^bb6: // 2 preds: ^bb1, ^bb4 - // CHECK-NEXT: spirv.mlir.merge + // CHECK: ^bb6: // 2 preds: ^bb1, ^bb3 + // CHECK-NEXT: spirv.mlir.merge // CHECK-NEXT: } }