Skip to content

Commit

Permalink
Implement do-while loops
Browse files Browse the repository at this point in the history
  • Loading branch information
wpmed92 committed Oct 9, 2024
1 parent 4b47402 commit a7f0832
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 8 deletions.
2 changes: 1 addition & 1 deletion include/CodeGen/MLIRCodeGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ class MLIRCodeGen : public ASTVisitor {
void initBuiltinFuncMap();
bool callBuiltIn(CallExpression* exp);
void createBuiltinComputeVar(const std::string &varName, const std::string &mlirName);
void generateLoop(Statement* initStmt, Expression* conditionExpr, Expression* inductionExpr, Statement* bodyStmt);
void generateLoop(Statement* initStmt, Expression* conditionExpr, Expression* inductionExpr, Statement* bodyStmt, bool isDoWhile = false);
void setBoolVar(mlir::spirv::VariableOp var, bool val);
mlir::Value load(mlir::Value);
mlir::Value popExpressionStack();
Expand Down
4 changes: 3 additions & 1 deletion lib/AST/PrinterASTVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,12 @@ void PrinterASTVisitor::visit(WhileStatement *whileStmt) {
void PrinterASTVisitor::visit(DoStatement *doStmt) {
print("|-DoStatement " + loc(doStmt->getSourceLocation()));

doStmt->getCondition()->accept(this);
indent();
doStmt->getBody()->accept(this);
resetIndent();
indent();
doStmt->getCondition()->accept(this);
resetIndent();
}

void PrinterASTVisitor::visit(IfStatement *ifStmt) {
Expand Down
24 changes: 18 additions & 6 deletions lib/CodeGen/MLIRCodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -943,7 +943,7 @@ void MLIRCodeGen::visit(InterfaceBlock *interfaceBlock) {
}

void MLIRCodeGen::visit(DoStatement *doStmt) {
// TODO: implement me
generateLoop(nullptr, doStmt->getCondition(), nullptr, doStmt->getBody(), /*isDoWhile*/ true);
}

void MLIRCodeGen::visit(IfStatement *ifStmt) {
Expand Down Expand Up @@ -1230,7 +1230,7 @@ void MLIRCodeGen::visit(DefaultLabel *defaultLabel) {}

void MLIRCodeGen::visit(CaseLabel *defaultLabel) {}

void MLIRCodeGen::generateLoop(Statement* initStmt, Expression* conditionExpr, Expression* inductionExpr, Statement* bodyStmt) {
void MLIRCodeGen::generateLoop(Statement* initStmt, Expression* conditionExpr, Expression* inductionExpr, Statement* bodyStmt, bool isDoWhile) {
Block *restoreInsertionBlock = builder.getInsertionBlock();
SymbolTableScopeT varScope(symbolTable);

Expand Down Expand Up @@ -1277,9 +1277,14 @@ void MLIRCodeGen::generateLoop(Statement* initStmt, Expression* conditionExpr, E
builder.setInsertionPointToEnd(header);
Block *merge = loopOp.getMergeBlock();

conditionExpr->accept(this);
auto conditionOp = load(popExpressionStack());
builder.create<spirv::BranchConditionalOp>(loc, conditionOp, body, ArrayRef<mlir::Value>(), merge, ArrayRef<mlir::Value>());
if (isDoWhile) {
builder.create<spirv::BranchOp>(loc, &*std::next(loopOp.getBody().begin(), 2));
} else {
conditionExpr->accept(this);
auto conditionOp = load(popExpressionStack());
builder.create<spirv::BranchConditionalOp>(loc, conditionOp, body, ArrayRef<mlir::Value>(), merge, ArrayRef<mlir::Value>());
}


builder.setInsertionPointToStart(body);

Expand Down Expand Up @@ -1332,7 +1337,14 @@ void MLIRCodeGen::generateLoop(Statement* initStmt, Expression* conditionExpr, E
inductionExpr->accept(this);
}

builder.create<spirv::BranchOp>(loc, header);
if (isDoWhile) {
conditionExpr->accept(this);
auto conditionOp = load(popExpressionStack());
builder.create<spirv::BranchConditionalOp>(loc, conditionOp, header, ArrayRef<mlir::Value>(), merge, ArrayRef<mlir::Value>());
} else {
builder.create<spirv::BranchOp>(loc, header);
}

builder.setInsertionPointToEnd(restoreInsertionBlock);
breakStack.pop_back();
continueStack.pop_back();
Expand Down
29 changes: 29 additions & 0 deletions test/CodeGen/cf_do_while.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
void main() {
bool a = true;
int b;

// Header jumps unconditionally to loop body

// CHECK: spirv.mlir.loop {
// CHECK-NEXT: spirv.Branch ^bb1
// CHECK-NEXT: ^bb1: // 2 preds: ^bb0, ^bb3
// CHECK-NEXT: spirv.Branch ^bb2
do {
int c = 2;
int d = 3;
b = c + d;

// The condition check happens in the continue block

// CHECK: ^bb3: // pred: ^bb2
// CHECK-NEXT: %false_1 = spirv.Constant false
// CHECK-NEXT: spirv.Store "Function" %3, %false_1 : i1
// CHECK-NEXT: %10 = spirv.Load "Function" %0 : i1
// CHECK-NEXT: spirv.BranchConditional %10, ^bb1, ^bb4
// CHECK-NEXT: ^bb4: // pred: ^bb3
// CHECK-NEXT: spirv.mlir.merge
// CHECK-NEXT: }
} while(a);

int someVarAfter = 12;
}

0 comments on commit a7f0832

Please sign in to comment.