Skip to content

Commit

Permalink
Fix loop continue/break gate block ordering
Browse files Browse the repository at this point in the history
  • Loading branch information
wpmed92 committed Sep 29, 2024
1 parent 19f0f77 commit c661820
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 15 deletions.
18 changes: 11 additions & 7 deletions lib/CodeGen/MLIRCodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1268,27 +1268,31 @@ void MLIRCodeGen::generateLoop(Statement* initStmt, Expression* conditionExpr, E
stmt->accept(this);

if (breakDetected || continueDetected) {
Block *postGateBlock = new Block();
loopOp.getBody().getBlocks().insert(std::next(loopOp.getBody().begin(), ++postGateBlockInsertionPoint), postGateBlock);

if (breakDetected && continueDetected) {
Block *breakCheckBlock = new Block();
auto continueGate = continueStack.back();
auto breakGate = breakStack.back();
loopOp.getBody().getBlocks().insert(std::next(loopOp.getBody().begin(), postGateBlockInsertionPoint), breakCheckBlock);
Block *breakCheckBlock = new Block();
loopOp.getBody().getBlocks().insert(std::next(loopOp.getBody().begin(), ++postGateBlockInsertionPoint), breakCheckBlock);
builder.create<spirv::BranchConditionalOp>(loc, load(continueGate), loopOp.getContinueBlock(), ArrayRef<mlir::Value>(), breakCheckBlock, ArrayRef<mlir::Value>());
Block *postGateBlock = new Block();
loopOp.getBody().getBlocks().insert(std::next(loopOp.getBody().begin(), ++postGateBlockInsertionPoint), postGateBlock);
builder.setInsertionPointToStart(breakCheckBlock);
builder.create<spirv::BranchConditionalOp>(loc, load(breakGate), merge, ArrayRef<mlir::Value>(), postGateBlock, ArrayRef<mlir::Value>());
builder.setInsertionPointToStart(postGateBlock);
} else if (continueDetected) {
auto continueGate = continueStack.back();
Block *postGateBlock = new Block();
loopOp.getBody().getBlocks().insert(std::next(loopOp.getBody().begin(), ++postGateBlockInsertionPoint), postGateBlock);
builder.create<spirv::BranchConditionalOp>(loc, load(continueGate), loopOp.getContinueBlock(), ArrayRef<mlir::Value>(), postGateBlock, ArrayRef<mlir::Value>());
builder.setInsertionPointToStart(postGateBlock);
} else if (breakDetected) {
auto breakGate = breakStack.back();
Block *postGateBlock = new Block();
loopOp.getBody().getBlocks().insert(std::next(loopOp.getBody().begin(), ++postGateBlockInsertionPoint), postGateBlock);
builder.create<spirv::BranchConditionalOp>(loc, load(breakGate), merge, ArrayRef<mlir::Value>(), postGateBlock, ArrayRef<mlir::Value>());
builder.setInsertionPointToStart(postGateBlock);
}

builder.setInsertionPointToStart(postGateBlock);

if (breakDetected) {
setBoolVar(breakStack.back(), false);
}
Expand Down
15 changes: 7 additions & 8 deletions test/CodeGen/cf_loops_while_continue.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,21 @@ void main() {
// CHECK: spirv.mlir.merge
// CHECK-NEXT: }
// CHECK-NEXT: %3 = spirv.Load "Function" %1 : i1
// CHECK-NEXT: spirv.BranchConditional %3, ^bb5, ^bb4
// CHECK-NEXT: spirv.BranchConditional %3, ^bb5, ^bb3
// CHECK-NEXT: ^bb3: // pred: ^bb2
// CHECK-NEXT: %4 = spirv.Load "Function" %0 : i1
// CHECK-NEXT: spirv.BranchConditional %4, ^bb6, ^bb4

// Reset continue/break control vars
// CHECK: ^bb3: // pred: ^bb4
// CHECK: ^bb4: // pred: ^bb3
// CHECK-NEXT: %false = spirv.Constant false
// CHECK-NEXT: spirv.Store "Function" %0, %false : i1
// CHECK-NEXT: %false_1 = spirv.Constant false
// CHECK-NEXT: spirv.Store "Function" %1, %false_1 : i1
int someVarAfter = 1;

// CHECK: ^bb4: // pred: ^bb2
// CHECK-NEXT: %5 = spirv.Load "Function" %0 : i1
// CHECK-NEXT: spirv.BranchConditional %5, ^bb6, ^bb3
}

// CHECK: ^bb6: // 2 preds: ^bb1, ^bb4
// CHECK-NEXT: spirv.mlir.merge
// CHECK: ^bb6: // 2 preds: ^bb1, ^bb3
// CHECK-NEXT: spirv.mlir.merge
// CHECK-NEXT: }
}

0 comments on commit c661820

Please sign in to comment.