Skip to content

Commit

Permalink
Declare function arguments in symbol table, add function codegen tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wpmed92 committed Aug 25, 2024
1 parent d200ffe commit 4077b86
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 9 deletions.
4 changes: 3 additions & 1 deletion include/CodeGen/MLIRCodeGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ struct SymbolTableEntry {
VariableDeclaration* variable = nullptr;
spirv::PointerType ptrType = nullptr;
bool isGlobal = false;
bool isFunctionParam = false;
Type* type = nullptr;
};

class MLIRCodeGen : public ASTVisitor {
Expand Down Expand Up @@ -92,7 +94,7 @@ class MLIRCodeGen : public ASTVisitor {
SymbolTableScopeT globalScope;
SmallVector<Attribute, 4> interface;

void declare(SymbolTableEntry);
void declare(StringRef name, SymbolTableEntry entry);
void createVariable(shaderpulse::Type *, VariableDeclaration *);
void insertEntryPoint();

Expand Down
24 changes: 16 additions & 8 deletions lib/CodeGen/MLIRCodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,11 +284,12 @@ void MLIRCodeGen::visit(UnaryExpression *unExp) {
}
}

void MLIRCodeGen::declare(SymbolTableEntry entry) {
if (symbolTable.count(entry.variable->getIdentifierName()))
void MLIRCodeGen::declare(StringRef name, SymbolTableEntry entry) {
if (symbolTable.count(name)) {
return;
}

symbolTable.insert(entry.variable->getIdentifierName(), entry);
symbolTable.insert(name, entry);
}

void MLIRCodeGen::visit(VariableDeclarationList *varDeclList) {
Expand Down Expand Up @@ -335,7 +336,7 @@ void MLIRCodeGen::createVariable(shaderpulse::Type *type,
varOp->setAttr("location", *locationOpt);
}

declare({ mlir::Value(), varDecl, ptrType, /*isGlobal*/ true});
declare(varDecl->getIdentifierName(), { mlir::Value(), varDecl, ptrType, /*isGlobal*/ true});
// Set OpDecorate through attributes
// example:
// varOp->setAttr(spirv::stringifyDecoration(spirv::Decoration::Invariant),
Expand All @@ -361,7 +362,7 @@ void MLIRCodeGen::createVariable(shaderpulse::Type *type,
builder.create<spirv::StoreOp>(builder.getUnknownLoc(), var, val);
}

declare({ var, varDecl, nullptr, /*isGlobal*/ false});
declare(varDecl->getIdentifierName(), { var, varDecl, nullptr, /*isGlobal*/ false});
}
}

Expand Down Expand Up @@ -649,7 +650,9 @@ void MLIRCodeGen::visit(CallExpression *callExp) {
void MLIRCodeGen::visit(VariableExpression *varExp) {
auto entry = symbolTable.lookup(varExp->getName());

if (entry.variable) {
if (entry.isFunctionParam) {
expressionStack.push_back(std::make_pair(entry.type, entry.value));
} else if (entry.variable) {
Value val;

if (entry.isGlobal) {
Expand Down Expand Up @@ -743,7 +746,6 @@ void MLIRCodeGen::visit(DiscardStatement *discardStmt) {}

void MLIRCodeGen::visit(FunctionDeclaration *funcDecl) {
insideEntryPoint = funcDecl->getName() == "main";

SymbolTableScopeT varScope(symbolTable);
std::vector<mlir::Type> paramTypes;

Expand All @@ -759,7 +761,7 @@ void MLIRCodeGen::visit(FunctionDeclaration *funcDecl) {
}
}

auto funcOp = builder.create<spirv::FuncOp>(
spirv::FuncOp funcOp = builder.create<spirv::FuncOp>(
builder.getUnknownLoc(), funcDecl->getName(),
builder.getFunctionType(
paramTypes,
Expand All @@ -772,6 +774,12 @@ void MLIRCodeGen::visit(FunctionDeclaration *funcDecl) {
auto entryBlock = funcOp.addEntryBlock();
builder.setInsertionPointToStart(entryBlock);

// Declare params as variables in the current scope
for (int i = 0; i < funcDecl->getParams().size(); i++) {
auto &param = funcDecl->getParams()[i];
declare(param->getName(), {funcOp.getArgument(i), nullptr, nullptr, false, true, param->getType()});
}

funcDecl->getBody()->accept(this);

// Return insertion for void functions
Expand Down
17 changes: 17 additions & 0 deletions test/CodeGen/functions.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// CHECK: spirv.func @add(%arg0: si32, %arg1: si32) -> si32 "None" {
// CHECK-NEXT: %0 = spirv.IAdd %arg0, %arg1 : si32
// CHECK-NEXT: spirv.ReturnValue %0 : si32
int add(int a, int b) {
return a + b;
}

// CHECK: spirv.func @main() "None" {
void main() {
int a = 1;
int b = 2;

// CHECK: %2 = spirv.Load "Function" %0 : si32
// CHECK-NEXT: %3 = spirv.Load "Function" %1 : si32
// CHECK-NEXT: %4 = spirv.FunctionCall @add(%2, %3) : (si32, si32) -> si32
int c = add(a, b);
}

0 comments on commit 4077b86

Please sign in to comment.