diff --git a/.gitignore b/.gitignore index 009fe3b..d4089e0 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ build .vscode test/CodeGen/*.mlir -example/*.mlir \ No newline at end of file +example/*.mlir +*.spv \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 4e2e582..1db5df4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -84,6 +84,7 @@ if(ENABLE_CODEGEN) MLIRLLVMToLLVMIRTranslation MLIRMemRefDialect MLIRSPIRVDialect + MLIRSPIRVSerialization MLIRParser MLIRPass MLIRSideEffectInterfaces diff --git a/compiler/shaderpulse.cpp b/compiler/shaderpulse.cpp index b6a508b..e89033a 100644 --- a/compiler/shaderpulse.cpp +++ b/compiler/shaderpulse.cpp @@ -84,13 +84,21 @@ int main(int argc, char** argv) { mlirCodeGen.print(); if (!mlirCodeGen.verify()) { - std::cout << "Error verifying the SPIR-V module" << std::endl; - return -1; + std::cout << "Error verifying the SPIR-V module." << std::endl; + return -1; + } + + bool success = mlirCodeGen.saveToFile(outputPath); + + if (!success) { + std::cout << "Failed to save spirv mlir to file."; + return -1; } - bool succes = mlirCodeGen.saveToFile(outputPath); + success = mlirCodeGen.emitSpirv(outputPath.replace_extension(".spv")); - if (!succes) { + if (!success) { + std::cout << "Failed to emit spirv binary."; return -1; } } diff --git a/example/square.glsl b/example/square.glsl new file mode 100644 index 0000000..2ec47fb --- /dev/null +++ b/example/square.glsl @@ -0,0 +1,14 @@ +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + +layout(binding = 0) buffer InBuffer { + int data[]; +}; + +layout(binding = 1) buffer OutBuffer { + int result[]; +}; + +void main() { + uint idx = gl_GlobalInvocationID.x; + result[idx] = data[idx] * data[idx]; +} diff --git a/include/CodeGen/MLIRCodeGen.h b/include/CodeGen/MLIRCodeGen.h index f020ee6..3b69b35 100644 --- a/include/CodeGen/MLIRCodeGen.h +++ b/include/CodeGen/MLIRCodeGen.h @@ -42,6 +42,7 @@ class MLIRCodeGen : public ASTVisitor { void initModuleOp(); void print(); bool saveToFile(const std::filesystem::path& outputPath); + bool emitSpirv(const std::filesystem::path& outputPath); bool verify(); void visit(TranslationUnit *) override; void visit(BinaryExpression *) override; diff --git a/lib/CodeGen/MLIRCodeGen.cpp b/lib/CodeGen/MLIRCodeGen.cpp index d7555b2..5243856 100644 --- a/lib/CodeGen/MLIRCodeGen.cpp +++ b/lib/CodeGen/MLIRCodeGen.cpp @@ -5,6 +5,7 @@ #include #include #include +#include "mlir/Target/SPIRV/Serialization.h" #include #include #include @@ -213,6 +214,19 @@ bool MLIRCodeGen::saveToFile(const std::filesystem::path& outputPath) { return true; } +bool MLIRCodeGen::emitSpirv(const std::filesystem::path& outputPath) { + llvm::SmallVector spirvBinary; + mlir::LogicalResult result = mlir::spirv::serialize(spirvModule, spirvBinary); + if (failed(result)) { + std::cerr << "Failed to serialize SPIR-V module." << std::endl; + return false; + } + + std::ofstream outFile(outputPath, std::ios::binary); + outFile.write(reinterpret_cast(spirvBinary.data()), spirvBinary.size() * sizeof(uint32_t)); + return true; +} + bool MLIRCodeGen::verify() { return !failed(mlir::verify(spirvModule)); } void MLIRCodeGen::insertEntryPoint() {