Skip to content

Commit

Permalink
Break/continue tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wpmed92 committed Sep 27, 2024
1 parent 79d4cd9 commit 19f0f77
Show file tree
Hide file tree
Showing 7 changed files with 194 additions and 85 deletions.
5 changes: 4 additions & 1 deletion include/CodeGen/MLIRCodeGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,10 @@ class MLIRCodeGen : public ASTVisitor {
std::vector<mlir::Value> expressionStack;
StructDeclaration* currentBaseComposite = nullptr;
mlir::Operation *execModeOp = nullptr;
mlir::spirv::VariableOp breakGate;
std::vector<mlir::spirv::VariableOp> breakStack;
std::vector<mlir::spirv::VariableOp> continueStack;
bool breakDetected = false;
bool continueDetected = false;

llvm::ScopedHashTable<llvm::StringRef, SymbolTableEntry>
symbolTable;
Expand All @@ -109,6 +111,7 @@ class MLIRCodeGen : public ASTVisitor {
bool callBuiltIn(CallExpression* exp);
void createBuiltinComputeVar(const std::string &varName, const std::string &mlirName);
void generateLoop(Statement* initStmt, Expression* conditionExpr, Expression* inductionExpr, Statement* bodyStmt);
void setBoolVar(mlir::spirv::VariableOp var, bool val);
mlir::Value load(mlir::Value);
mlir::Value popExpressionStack();
mlir::Value currentBasePointer;
Expand Down
73 changes: 59 additions & 14 deletions lib/CodeGen/MLIRCodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1120,14 +1120,19 @@ void MLIRCodeGen::visit(ReturnStatement *returnStmt) {
}

void MLIRCodeGen::visit(BreakStatement *breakStmt) {
auto type = builder.getIntegerType(1);
mlir::Value constTrue = builder.create<spirv::ConstantOp>(builder.getUnknownLoc(), type, IntegerAttr::get(type, APInt(1, true)));
builder.create<spirv::StoreOp>(builder.getUnknownLoc(), breakGate, constTrue);
setBoolVar(breakStack.back(), true);
breakDetected = true;
}

void MLIRCodeGen::visit(ContinueStatement *continueStmt) {
setBoolVar(continueStack.back(), true);
continueDetected = true;
}

void MLIRCodeGen::setBoolVar(mlir::spirv::VariableOp var, bool val) {
auto type = builder.getIntegerType(1);
mlir::Value constTrue = builder.create<spirv::ConstantOp>(builder.getUnknownLoc(), type, IntegerAttr::get(type, APInt(1, val)));
builder.create<spirv::StoreOp>(builder.getUnknownLoc(), var, constTrue);
}

void MLIRCodeGen::visit(DiscardStatement *discardStmt) {
Expand Down Expand Up @@ -1211,8 +1216,14 @@ void MLIRCodeGen::generateLoop(Statement* initStmt, Expression* conditionExpr, E

mlir::Type boolType = mlir::IntegerType::get(&context, 1, mlir::IntegerType::Signless);
spirv::PointerType ptrType = spirv::PointerType::get(boolType, mlir::spirv::StorageClass::Function);
breakGate = builder.create<spirv::VariableOp>(
builder.getUnknownLoc(), ptrType, spirv::StorageClass::Function, nullptr);
breakStack.push_back(
builder.create<spirv::VariableOp>(
builder.getUnknownLoc(), ptrType, spirv::StorageClass::Function, nullptr)
);
continueStack.push_back(
builder.create<spirv::VariableOp>(
builder.getUnknownLoc(), ptrType, spirv::StorageClass::Function, nullptr)
);

if (initStmt) {
initStmt->accept(this);
Expand All @@ -1229,7 +1240,11 @@ void MLIRCodeGen::generateLoop(Statement* initStmt, Expression* conditionExpr, E
// Insert the body.
Block *body = new Block();
loopOp.getBody().getBlocks().insert(std::next(loopOp.getBody().begin(), 2), body);
int i = 2;

// Insert the continue block.
Block *continueBlock = new Block();
loopOp.getBody().getBlocks().insert(std::next(loopOp.getBody().begin(), 3), continueBlock);

// Emit the entry code.
Block *entry = loopOp.getEntryBlock();
builder.setInsertionPointToEnd(entry);
Expand All @@ -1245,31 +1260,61 @@ void MLIRCodeGen::generateLoop(Statement* initStmt, Expression* conditionExpr, E

builder.setInsertionPointToStart(body);

// Detect break flag
// Detect break/continue flag
int postGateBlockInsertionPoint = 2;

if (auto body = dynamic_cast<StatementList*>(bodyStmt)) {
for (auto &stmt : body->getStatements()) {
stmt->accept(this);

if (breakDetected) {
Block *postBreakBlock = new Block();
loopOp.getBody().getBlocks().insert(std::next(loopOp.getBody().begin(), ++i), postBreakBlock);
builder.create<spirv::BranchConditionalOp>(loc, load(breakGate), merge, ArrayRef<mlir::Value>(), postBreakBlock, ArrayRef<mlir::Value>());
builder.setInsertionPointToStart(postBreakBlock);
if (breakDetected || continueDetected) {
Block *postGateBlock = new Block();
loopOp.getBody().getBlocks().insert(std::next(loopOp.getBody().begin(), ++postGateBlockInsertionPoint), postGateBlock);

if (breakDetected && continueDetected) {
Block *breakCheckBlock = new Block();
auto continueGate = continueStack.back();
auto breakGate = breakStack.back();
loopOp.getBody().getBlocks().insert(std::next(loopOp.getBody().begin(), postGateBlockInsertionPoint), breakCheckBlock);
builder.create<spirv::BranchConditionalOp>(loc, load(continueGate), loopOp.getContinueBlock(), ArrayRef<mlir::Value>(), breakCheckBlock, ArrayRef<mlir::Value>());
builder.setInsertionPointToStart(breakCheckBlock);
builder.create<spirv::BranchConditionalOp>(loc, load(breakGate), merge, ArrayRef<mlir::Value>(), postGateBlock, ArrayRef<mlir::Value>());
} else if (continueDetected) {
auto continueGate = continueStack.back();
builder.create<spirv::BranchConditionalOp>(loc, load(continueGate), loopOp.getContinueBlock(), ArrayRef<mlir::Value>(), postGateBlock, ArrayRef<mlir::Value>());
} else if (breakDetected) {
auto breakGate = breakStack.back();
builder.create<spirv::BranchConditionalOp>(loc, load(breakGate), merge, ArrayRef<mlir::Value>(), postGateBlock, ArrayRef<mlir::Value>());
}

builder.setInsertionPointToStart(postGateBlock);

if (breakDetected) {
setBoolVar(breakStack.back(), false);
}

if (continueDetected) {
setBoolVar(continueStack.back(), false);
}
breakDetected = false;
continueDetected = false;
}
}
} else {
bodyStmt->accept(this);
}

builder.create<spirv::BranchOp>(loc, loopOp.getContinueBlock());
builder.setInsertionPointToEnd(loopOp.getContinueBlock());

if (inductionExpr) {
inductionExpr->accept(this);
}

Block *continueBlock = loopOp.getContinueBlock();
builder.setInsertionPointToEnd(continueBlock);
builder.create<spirv::BranchOp>(loc, header);
builder.setInsertionPointToEnd(restoreInsertionBlock);
breakStack.pop_back();
continueStack.pop_back();
}

mlir::Value MLIRCodeGen::load(mlir::Value val) {
Expand Down
47 changes: 20 additions & 27 deletions test/CodeGen/cf_loops_for.glsl
Original file line number Diff line number Diff line change
@@ -1,36 +1,29 @@
void main() {
// CHECK: %cst0_si32 = spirv.Constant 0 : si32
// CHECK-NEXT: %0 = spirv.Variable : !spirv.ptr<si32, Function>
// CHECK-NEXT: spirv.Store "Function" %0, %cst0_si32 : si32
// CHECK-NEXT: spirv.mlir.loop {
// CHECK-NEXT: spirv.Branch ^bb1
// CHECK-NEXT: ^bb1: // 2 preds: ^bb0, ^bb2
// CHECK: %2 = spirv.Variable : !spirv.ptr<si32, Function>
// CHECK: spirv.mlir.loop {
// CHECK-NEXT: spirv.Branch ^bb1
// CHECK-NEXT: ^bb1: // 2 preds: ^bb0, ^bb3
// CHECK-NEXT: %cst10_si32 = spirv.Constant 10 : si32
// CHECK-NEXT: %2 = spirv.Load "Function" %0 : si32
// CHECK-NEXT: %3 = spirv.SLessThan %2, %cst10_si32 : si32
// CHECK-NEXT: spirv.BranchConditional %3, ^bb2, ^bb3
// CHECK-NEXT: ^bb2: // pred: ^bb1
// CHECK-NEXT: %3 = spirv.Load "Function" %2 : si32
// CHECK-NEXT: %4 = spirv.SLessThan %3, %cst10_si32 : si32
// CHECK-NEXT: spirv.BranchConditional %4, ^bb2, ^bb4
// CHECK-NEXT:^bb2: // pred: ^bb1
// CHECK-NEXT: %cst1_si32 = spirv.Constant 1 : si32
// CHECK-NEXT: %4 = spirv.Load "Function" %0 : si32
// CHECK-NEXT: %5 = spirv.IAdd %4, %cst1_si32 : si32
// CHECK-NEXT: %6 = spirv.Variable : !spirv.ptr<si32, Function>
// CHECK-NEXT: spirv.Store "Function" %6, %5 : si32
// CHECK-NEXT: %7 = spirv.Load "Function" %0 : si32
// CHECK-NEXT: %cst1_si32_1 = spirv.Constant 1 : si32
// CHECK-NEXT: %8 = spirv.IAdd %7, %cst1_si32_1 : si32
// CHECK-NEXT: spirv.Store "Function" %0, %8 : si32
// CHECK-NEXT: %5 = spirv.Load "Function" %2 : si32
// CHECK-NEXT: %6 = spirv.IAdd %5, %cst1_si32 : si32
// CHECK-NEXT: %7 = spirv.Variable : !spirv.ptr<si32, Function>
// CHECK-NEXT: spirv.Store "Function" %7, %6 : si32
// CHECK-NEXT: spirv.Branch ^bb3
// CHECK-NEXT:^bb3: // pred: ^bb2
// CHECK-NEXT: %8 = spirv.Load "Function" %2 : si32
// CHECK-NEXT: %cst1_si32_0 = spirv.Constant 1 : si32
// CHECK-NEXT: %9 = spirv.IAdd %8, %cst1_si32_0 : si32
// CHECK-NEXT: spirv.Store "Function" %2, %9 : si32
// CHECK-NEXT: spirv.Branch ^bb1
// CHECK-NEXT: ^bb3: // pred: ^bb1
// CHECK-NEXT:^bb4: // pred: ^bb1
// CHECK-NEXT: spirv.mlir.merge
// CHECK-NEXT: }
// CHECK-NEXT:}
for (int i = 0; i < 10; ++i) {
int a = i + 1;
}

// TODO: file check embedded loops
for (int i = 0; i < 10; ++i) {
for (int j = 0; j < 20; ++j) {
int a = i + j;
}
}
}
33 changes: 9 additions & 24 deletions test/CodeGen/cf_loops_while.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -4,37 +4,22 @@ void main() {

// CHECK: spirv.mlir.loop {
// CHECK-NEXT: spirv.Branch ^bb1
// CHECK-NEXT: ^bb1: // 2 preds: ^bb0, ^bb2
// CHECK-NEXT: %2 = spirv.Load "Function" %0 : i1
// CHECK-NEXT: spirv.BranchConditional %2, ^bb2, ^bb3
// CHECK-NEXT: ^bb1: // 2 preds: ^bb0, ^bb3
// CHECK-NEXT: %4 = spirv.Load "Function" %0 : i1
// CHECK-NEXT: spirv.BranchConditional %4, ^bb2, ^bb4
// CHECK-NEXT: ^bb2: // pred: ^bb1
while (a) {
int c = 2;
int d = 3;

if (!a) {
return;
}

if (a) {
int test = 2;

if (!a) {
break;
}
}

int blabla = 4;

if (!a) {
break;
}
// CHECK: spirv.Store "Function" %1, %7 : si32
// CHECK-NEXT: spirv.Branch ^bb1
// CHECK: spirv.Store "Function" %1, %9 : si32
// CHECK-NEXT: spirv.Branch ^bb3
b = c + d;
}

// CHECK: ^bb3: // pred: ^bb1
// CHECK-NEXT: spirv.mlir.merge
// CHECK-NEXT: ^bb3: // pred: ^bb2
// CHECK-NEXT: spirv.Branch ^bb1
// CHECK-NEXT: ^bb4: // pred: ^bb1
// CHECK-NEXT: spirv.mlir.merge
// CHECK-NEXT: }
}
38 changes: 38 additions & 0 deletions test/CodeGen/cf_loops_while_break.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
void main() {
// Hidden break/continue control vars

// CHECK: %0 = spirv.Variable : !spirv.ptr<i1, Function>
// CHECK-NEXT: %1 = spirv.Variable : !spirv.ptr<i1, Function>
while (true) {
// CHECK: %cst1_si32 = spirv.Constant 1 : si32
// CHECK-NEXT: %2 = spirv.Variable : !spirv.ptr<si32, Function>
// CHECK-NEXT: spirv.Store "Function" %2, %cst1_si32 : si32
int someVarBefore = 1;

// CHECK: ^bb1: // pred: ^bb0
// CHECK-NEXT: %true_2 = spirv.Constant true
// CHECK-NEXT: spirv.Store "Function" %0, %true_2 : i1
if (true) {
break;
}

// CHECK: spirv.mlir.merge
// CHECK-NEXT: }
// CHECK-NEXT: %3 = spirv.Load "Function" %0 : i1
// CHECK-NEXT: spirv.BranchConditional %3, ^bb5, ^bb3

// CHECK: ^bb3: // pred: ^bb2
// CHECK-NEXT: %false = spirv.Constant false

// Reset break control var

// CHECK-NEXT: spirv.Store "Function" %0, %false : i1
// CHECK-NEXT: %cst1_si32_1 = spirv.Constant 1 : si32
// CHECK-NEXT: %4 = spirv.Variable : !spirv.ptr<si32, Function>
int someVarAfter = 1;
}

// CHECK: ^bb5: // 2 preds: ^bb1, ^bb2
// CHECK-NEXT: spirv.mlir.merge
// CHECK-NEXT: }
}
45 changes: 45 additions & 0 deletions test/CodeGen/cf_loops_while_continue.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
void main() {
// Hidden break/continue control vars

// CHECK: %0 = spirv.Variable : !spirv.ptr<i1, Function>
// CHECK-NEXT: %1 = spirv.Variable : !spirv.ptr<i1, Function>
while (true) {
// CHECK: %cst1_si32 = spirv.Constant 1 : si32
// CHECK-NEXT: %2 = spirv.Variable : !spirv.ptr<si32, Function>
// CHECK-NEXT: spirv.Store "Function" %2, %cst1_si32 : si32
int someVarBefore = 1;

// CHECK: ^bb1: // pred: ^bb0
// CHECK-NEXT: %true_3 = spirv.Constant true
// CHECK-NEXT: spirv.Store "Function" %1, %true_3 : i1
if (true) {
continue;
// CHECK: ^bb2: // pred: ^bb0
// CHECK-NEXT: %true_4 = spirv.Constant true
// CHECK-NEXT: spirv.Store "Function" %0, %true_4 : i1
} else {
break;
}

// CHECK: spirv.mlir.merge
// CHECK-NEXT: }
// CHECK-NEXT: %3 = spirv.Load "Function" %1 : i1
// CHECK-NEXT: spirv.BranchConditional %3, ^bb5, ^bb4

// Reset continue/break control vars
// CHECK: ^bb3: // pred: ^bb4
// CHECK-NEXT: %false = spirv.Constant false
// CHECK-NEXT: spirv.Store "Function" %0, %false : i1
// CHECK-NEXT: %false_1 = spirv.Constant false
// CHECK-NEXT: spirv.Store "Function" %1, %false_1 : i1
int someVarAfter = 1;

// CHECK: ^bb4: // pred: ^bb2
// CHECK-NEXT: %5 = spirv.Load "Function" %0 : i1
// CHECK-NEXT: spirv.BranchConditional %5, ^bb6, ^bb3
}

// CHECK: ^bb6: // 2 preds: ^bb1, ^bb4
// CHECK-NEXT: spirv.mlir.merge
// CHECK-NEXT: }
}
Loading

0 comments on commit 19f0f77

Please sign in to comment.