frule_via_ad should accept several arguments #12
Comments
@mohamed82008 any idea?
Not a fan of supporting multiple arguments with ForwardDiff. ForwardDiff's API does not support that.
Hmm, on second thought, I changed my mind.
Given the metaprogramming-heavy nature of this package, I'm not sure how to make a PR implementing this.
NonconvexUtils already supports this: https://github.com/JuliaNonconvex/NonconvexUtils.jl/blob/main/src/forwarddiff_frule.jl#L1
You can just assume the inputs are all reals/arrays and not recursive containers, and simplify the implementation. I use flatten/unflatten, but that's because I wanted to be too generic. No need to be that generic.
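For illustration, here is a rough sketch of what the simplified multi-argument case could look like when every input is a real or an array of reals. It is not the NonconvexUtils or ForwardDiffChainRules.jl implementation: the helper name `multiarg_pushforward` is hypothetical, and the approach (perturbing all arguments along their tangents with a single scalar and differentiating with `ForwardDiff.derivative`) is just one way to avoid flatten/unflatten.

```julia
using ForwardDiff

# Hypothetical helper, not part of any package mentioned in this thread.
# Returns f(args...) and the directional derivative of f along `dargs`,
# assuming every argument is a Real or an array of Reals (no nested containers).
function multiarg_pushforward(f, args::Tuple, dargs::Tuple)
    # Move all arguments at once along their tangents using one scalar t,
    # so a single scalar derivative gives the full pushforward.
    g(t) = f(map((a, da) -> a .+ t .* da, args, dargs)...)
    return f(args...), ForwardDiff.derivative(g, 0.0)
end

# Example with two array arguments.
f(x, y) = sum(x .* y)
val, tangent = multiarg_pushforward(
    f,
    ([1.0, 2.0], [3.0, 4.0]),   # primal arguments x, y
    ([1.0, 0.0], [0.0, 1.0]),   # tangents dx, dy
)
```

A multi-argument frule_via_ad could presumably forward its tangents and arguments to something like this, but zero or non-differentiable tangents would still need handling that this sketch leaves out.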
According to the official API specification of ChainRulesCore.jl, frule_via_ad should accept all the arguments of the function as a destructured tuple: https://juliadiff.org/ChainRulesCore.jl/stable/api.html#ChainRulesCore.frule_via_ad
However, it seems that ForwardDiffChainRules.jl only accepts one argument:
ForwardDiffChainRules.jl/src/ForwardDiffChainRules.jl, lines 46 to 52 in 609201f
I think this is the reason for a bug in my code. Do you think it is fixable?