diff --git a/enzyme/Enzyme/CApi.cpp b/enzyme/Enzyme/CApi.cpp index 1257467edb6e..8d1f4ec85808 100644 --- a/enzyme/Enzyme/CApi.cpp +++ b/enzyme/Enzyme/CApi.cpp @@ -1586,7 +1586,8 @@ void EnzymeFixupJuliaCallingConvention(LLVMValueRef F_C) { T = FT->getParamType(i)->getPointerElementType(); #endif IRBuilder<> EB(&NewF->getEntryBlock().front()); - arg->replaceAllUsesWith(EB.CreateAlloca(T)); + auto AL = EB.CreateAlloca(T, 0, "stack_roots"); + arg->replaceAllUsesWith(AL); delete arg; } for (auto i : rroots_v) { @@ -1604,7 +1605,8 @@ void EnzymeFixupJuliaCallingConvention(LLVMValueRef F_C) { IRBuilder<> EB(&NewF->getEntryBlock().front()); Value *val = UndefValue::get(AT); for (size_t j = 0; j < AT->getNumElements(); j++) { - val = EB.CreateInsertValue(val, EB.CreateAlloca(T), j); + auto AL = EB.CreateAlloca(T, 0, "stack_roots_v"); + val = EB.CreateInsertValue(val, AL, j); } arg->replaceAllUsesWith(val); delete arg; @@ -1621,7 +1623,7 @@ void EnzymeFixupJuliaCallingConvention(LLVMValueRef F_C) { size_t nexti = 0; Value *sret = nullptr; if (sretTy) { - sret = EB.CreateAlloca(sretTy); + sret = EB.CreateAlloca(sretTy, 0, "stack_sret"); vals.push_back(sret); NewAttrs = NewAttrs.addAttribute( F->getContext(), AttributeList::FirstArgIndex + nexti, @@ -1630,7 +1632,7 @@ void EnzymeFixupJuliaCallingConvention(LLVMValueRef F_C) { } AllocaInst *roots = nullptr; if (roots_AT) { - roots = EB.CreateAlloca(roots_AT); + roots = EB.CreateAlloca(roots_AT, 0, "stack_roots_AT"); vals.push_back(roots); NewAttrs = NewAttrs.addAttribute( @@ -1675,21 +1677,66 @@ void EnzymeFixupJuliaCallingConvention(LLVMValueRef F_C) { sretCount++; } + std::function, int, Type *)> + copyNonJLValue = [&](Type *curType, Value *out, Value *in, + ArrayRef inds, int sretCount, Type *ptrTy) { + if (auto PT = dyn_cast(curType)) { + if (PT->getAddressSpace() == 10) { + return; + } + } + + if (auto AT = dyn_cast(curType)) { + for (size_t i = 0; i < AT->getNumElements(); i++) { + SmallVector next(inds.begin(), inds.end()); + next.push_back(i); + copyNonJLValue(AT->getElementType(), out, in, next, sretCount, + ptrTy); + } + return; + } + if (auto ST = dyn_cast(curType)) { + for (size_t i = 0; i < ST->getNumElements(); i++) { + SmallVector next(inds.begin(), inds.end()); + next.push_back(i); + copyNonJLValue(ST->getElementType(i), out, in, next, sretCount, + ptrTy); + } + return; + } + + SmallVector ininds; + SmallVector outinds; + auto c0 = ConstantInt::get(B.getInt64Ty(), 0); + ininds.push_back(c0); + outinds.push_back(c0); + if (sretCount >= 0) + outinds.push_back(ConstantInt::get(B.getInt32Ty(), sretCount)); + for (auto v : inds) { + ininds.push_back(ConstantInt::get(B.getInt32Ty(), v)); + outinds.push_back(ConstantInt::get(B.getInt32Ty(), v)); + } + + if (outinds.size() > 1) + out = B.CreateInBoundsGEP(sretTy, out, outinds); + if (ininds.size() > 1) + in = B.CreateInBoundsGEP(ptrTy, in, ininds); + + auto ld = B.CreateLoad(curType, in); + B.CreateStore(ld, out); + }; + for (Value *ptr : sret_vals) { - auto gep = - ST ? B.CreateConstInBoundsGEP2_32(ST, sret, 0, sretCount) : sret; - auto ld = B.CreateLoad(Types[sretCount], ptr); - B.CreateStore(ld, gep); + copyNonJLValue(Types[sretCount], sret, ptr, {}, ST ? sretCount : -1, + Types[sretCount]); sretCount++; } for (Value *ptr_v : sretv_vals) { auto AT = cast(ptr_v->getType()); for (size_t j = 0; j < AT->getNumElements(); j++) { - auto gep = ST ? B.CreateConstInBoundsGEP2_32(ST, sret, 0, sretCount + j) - : sret; auto ptr = GradientUtils::extractMeta(B, ptr_v, j); - auto ld = B.CreateLoad(Types[sretCount], ptr); - B.CreateStore(ld, gep); + copyNonJLValue(Types[sretCount], sret, ptr, {}, + ST ? (sretCount + j) : -1, Types[sretCount]); } sretCount += AT->getNumElements(); }