Skip to content

Commit

Permalink
Fix lower kernel global bug
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 9, 2025
1 parent d222454 commit dce4301
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/enzyme_ad/jax/Passes/LowerKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ CallInfo CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
if (auto op2 = cop.resolveCallable())
tocopy.push_back(op2);
});
op->walk([&](LLVM::AddressOfOp cop) {
cur->walk([&](LLVM::AddressOfOp cop) {
if (auto op2 = cop.getGlobal(symbolTable))
tocopy.push_back(op2);
else if (auto op2 = cop.getFunction(symbolTable))
Expand Down Expand Up @@ -507,7 +507,10 @@ CallInfo CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
options.hostUseBarePtrCallConv = false;
buildLowerToNVVMPassPipeline(pm, options, toolkitPath, linkFiles);

pm.run(submod);
auto subres = pm.run(submod);
if (!subres.succeeded()) {
return {};
}

OpBuilder builder(submod);

Expand Down

0 comments on commit dce4301

Please sign in to comment.