From a801448d670df477405826fcb25d6d7cb65ed19a Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Thu, 16 Jan 2025 21:57:29 +0000 Subject: [PATCH] add user data to EnzymeRegisterCallHandler and add to header --- enzyme/Enzyme/CApi.cpp | 8 ++++---- enzyme/Enzyme/CApi.h | 13 +++++++++++-- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/enzyme/Enzyme/CApi.cpp b/enzyme/Enzyme/CApi.cpp index ca71867462f..f72fa3b4f4d 100644 --- a/enzyme/Enzyme/CApi.cpp +++ b/enzyme/Enzyme/CApi.cpp @@ -340,9 +340,9 @@ void EnzymeRegisterAllocationHandler(char *Name, CustomShadowAlloc AHandle, }; } -void EnzymeRegisterCallHandler(char *Name, +void EnzymeRegisterCallHandler(const char *Name, CustomAugmentedFunctionForward FwdHandle, - CustomFunctionReverse RevHandle) { + CustomFunctionReverse RevHandle, void *data) { auto &pair = customCallHandlers[Name]; pair.first = [=](IRBuilder<> &B, CallInst *CI, GradientUtils &gutils, Value *&normalReturn, Value *&shadowReturn, @@ -351,7 +351,7 @@ void EnzymeRegisterCallHandler(char *Name, LLVMValueRef shadowR = wrap(shadowReturn); LLVMValueRef tapeR = wrap(tape); uint8_t noMod = - FwdHandle(wrap(&B), wrap(CI), &gutils, &normalR, &shadowR, &tapeR); + FwdHandle(wrap(&B), wrap(CI), &gutils, &normalR, &shadowR, &tapeR, data); normalReturn = unwrap(normalR); shadowReturn = unwrap(shadowR); tape = unwrap(tapeR); @@ -359,7 +359,7 @@ void EnzymeRegisterCallHandler(char *Name, }; pair.second = [=](IRBuilder<> &B, CallInst *CI, DiffeGradientUtils &gutils, Value *tape) { - RevHandle(wrap(&B), wrap(CI), &gutils, wrap(tape)); + RevHandle(wrap(&B), wrap(CI), &gutils, wrap(tape), data); }; } diff --git a/enzyme/Enzyme/CApi.h b/enzyme/Enzyme/CApi.h index 3a38a68c4c7..dc4296403d5 100644 --- a/enzyme/Enzyme/CApi.h +++ b/enzyme/Enzyme/CApi.h @@ -198,10 +198,11 @@ typedef uint8_t (*CustomAugmentedFunctionForward)(LLVMBuilderRef, LLVMValueRef, GradientUtils *, LLVMValueRef *, LLVMValueRef *, - LLVMValueRef *); + LLVMValueRef *, + void *); typedef void (*CustomFunctionReverse)(LLVMBuilderRef, LLVMValueRef, - DiffeGradientUtils *, LLVMValueRef); + DiffeGradientUtils *, LLVMValueRef, void *); LLVMValueRef EnzymeCreateForwardDiff( EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip, @@ -222,6 +223,14 @@ LLVMValueRef EnzymeCreatePrimalAndGradient( uint8_t *_overwritten_args, size_t overwritten_args_size, EnzymeAugmentedReturnPtr augmented, uint8_t AtomicAdd); +void EnzymeRegisterCallHandler(const char *Name, + CustomAugmentedFunctionForward FwdHandle, + CustomFunctionReverse RevHandle, + void *data); + +LLVMValueRef EnzymeGradientUtilsNewFromOriginal(GradientUtils *gutils, + LLVMValueRef val); + #ifdef __cplusplus } #endif