Skip to content

Commit

Permalink
add user data to EnzymeRegisterCallHandler and add to header
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Jan 16, 2025
1 parent 5651636 commit a801448
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
8 changes: 4 additions & 4 deletions enzyme/Enzyme/CApi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -340,9 +340,9 @@ void EnzymeRegisterAllocationHandler(char *Name, CustomShadowAlloc AHandle,
};
}

void EnzymeRegisterCallHandler(char *Name,
void EnzymeRegisterCallHandler(const char *Name,
CustomAugmentedFunctionForward FwdHandle,
CustomFunctionReverse RevHandle) {
CustomFunctionReverse RevHandle, void *data) {
auto &pair = customCallHandlers[Name];
pair.first = [=](IRBuilder<> &B, CallInst *CI, GradientUtils &gutils,
Value *&normalReturn, Value *&shadowReturn,
Expand All @@ -351,15 +351,15 @@ void EnzymeRegisterCallHandler(char *Name,
LLVMValueRef shadowR = wrap(shadowReturn);
LLVMValueRef tapeR = wrap(tape);
uint8_t noMod =
FwdHandle(wrap(&B), wrap(CI), &gutils, &normalR, &shadowR, &tapeR);
FwdHandle(wrap(&B), wrap(CI), &gutils, &normalR, &shadowR, &tapeR, data);
normalReturn = unwrap(normalR);
shadowReturn = unwrap(shadowR);
tape = unwrap(tapeR);
return noMod != 0;
};
pair.second = [=](IRBuilder<> &B, CallInst *CI, DiffeGradientUtils &gutils,
Value *tape) {
RevHandle(wrap(&B), wrap(CI), &gutils, wrap(tape));
RevHandle(wrap(&B), wrap(CI), &gutils, wrap(tape), data);
};
}

Expand Down
13 changes: 11 additions & 2 deletions enzyme/Enzyme/CApi.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,11 @@ typedef uint8_t (*CustomAugmentedFunctionForward)(LLVMBuilderRef, LLVMValueRef,
GradientUtils *,
LLVMValueRef *,
LLVMValueRef *,
LLVMValueRef *);
LLVMValueRef *,
void *);

typedef void (*CustomFunctionReverse)(LLVMBuilderRef, LLVMValueRef,
DiffeGradientUtils *, LLVMValueRef);
DiffeGradientUtils *, LLVMValueRef, void *);

LLVMValueRef EnzymeCreateForwardDiff(
EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip,
Expand All @@ -222,6 +223,14 @@ LLVMValueRef EnzymeCreatePrimalAndGradient(
uint8_t *_overwritten_args, size_t overwritten_args_size,
EnzymeAugmentedReturnPtr augmented, uint8_t AtomicAdd);

void EnzymeRegisterCallHandler(const char *Name,
CustomAugmentedFunctionForward FwdHandle,
CustomFunctionReverse RevHandle,
void *data);

LLVMValueRef EnzymeGradientUtilsNewFromOriginal(GradientUtils *gutils,
LLVMValueRef val);

#ifdef __cplusplus
}
#endif
Expand Down

0 comments on commit a801448

Please sign in to comment.