Externalize gradient computations to DifferentiationInterface.jl? #544
Comments
Yeah, I am all for purging that code, depending on
That should be enough to get a start.
Tracker works and is tested, ComponentArrays is my next target
That's great. This week is a bit busy, but I can possibly try early next month
No rush! And always happy to help debug.
Do you strictly require a vector input? ComponentArrays has an overhead for smallish arrays (see #49), so having an
At the moment yes. We're thinking about how to be more flexible in order to accommodate Flux's needs; you can track JuliaDiff/DifferentiationInterface.jl#87 to see how it evolves
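For context, here is a minimal sketch of what the array-input requirement means in practice: a ComponentVector is a flat AbstractVector with named axes, so structured parameters can be passed to a single DI gradient call. The names (`ps`, `loss`) and the toy loss are illustrative assumptions, not code from either package.

```julia
import DifferentiationInterface as DI
using ADTypes: AutoZygote
using ComponentArrays
import Zygote  # the backend package must be loaded for AutoZygote() to work

# A ComponentVector is a flat vector with named axes, so it satisfies an
# "array input" requirement while keeping the parameter structure accessible.
ps = ComponentVector(weight = randn(4, 3), bias = zeros(4))
loss(p) = sum(abs2, p.weight * ones(3) .+ p.bias)  # toy loss over the named fields

grad = DI.gradient(loss, AutoZygote(), ps)  # grad keeps the same ComponentVector layout
```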
DI v0.3 should be out sometime next week (I've been busy with sparse Jacobians & Hessians), but I don't think I'll have much time in the near future to revamp the Lux tests. Still, I think it would make sense to offer DI at least as a high-level interface, even if it is not yet used in the package internals / tests. It might also help you figure out #605
Note that for DI to work in full generality with ComponentArrays, I need SciML/ComponentArrays.jl#254 to be fixed. Otherwise Jacobians and Hessians will stay broken (the rest, in particular gradient, is independent of stacking)
Yes, I want to roll it out first as a high-level interface when the inputs are
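To make the "high-level interface" idea concrete, here is a hedged sketch (not an existing Lux or DI API) of calling DI directly on a Lux model's flattened parameters; the model, data, and loss function are illustrative assumptions.

```julia
import DifferentiationInterface as DI
using ADTypes: AutoZygote
using ComponentArrays, Lux, Random
import Zygote

model = Chain(Dense(2 => 8, tanh), Dense(8 => 1))
ps, st = Lux.setup(Random.default_rng(), model)
ps = ComponentArray(ps)                           # flatten the NamedTuple parameters

x = rand(Float32, 2, 16)
y = rand(Float32, 1, 16)
loss(p) = sum(abs2, first(model(x, p, st)) .- y)  # first(...) drops the returned state

grads = DI.gradient(loss, AutoZygote(), ps)       # backend chosen via the ADTypes struct
```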
Hey there Avik!
As you may know, I have been busy developing DifferentiationInterface.jl, and it's really starting to take shape.
I was wondering whether it would be useful to Lux.jl as a dependency, in order to support the wider variety of autodiff backends defined by ADTypes.jl?
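For readers unfamiliar with ADTypes.jl: backends are selected by passing lightweight struct instances, so the same DI call can dispatch to different AD packages. A minimal sketch with an illustrative function and data:

```julia
import DifferentiationInterface as DI
using ADTypes: AutoForwardDiff, AutoZygote
import ForwardDiff, Zygote  # the backend packages themselves must be loaded

f(x) = sum(abs2, x)
x = rand(5)

for backend in (AutoForwardDiff(), AutoZygote())
    @show backend DI.gradient(f, backend, x)  # same call, different AD engine
end
```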
Looking at the code, it seems the main spot where AD comes up (beyond the docs and tutorials) is Lux.Training:
Lux.jl/src/contrib/training.jl (lines 93 to 106 in c27b9f5)
Gradients are only implemented in the extensions for Zygote and Tracker:
Lux.jl/ext/LuxZygoteExt.jl (lines 7 to 14 in c27b9f5)
Lux.jl/ext/LuxTrackerExt.jl (lines 25 to 33 in c27b9f5)
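As an illustration of what externalizing those extensions could look like, here is a hedged sketch of a single backend-generic method built on DI's value_and_gradient; the function name and argument layout are assumptions, not the actual Lux.Training code.

```julia
import DifferentiationInterface as DI
using ADTypes: AbstractADType

# One backend-generic method instead of separate Zygote/Tracker extensions.
# `objective`, `data`, and `ps` are placeholders, not the actual Lux.Training API.
function compute_gradients_di(ad::AbstractADType, objective, data, ps)
    loss, grads = DI.value_and_gradient(p -> objective(p, data), ad, ps)
    return grads, loss
end
```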
While DifferentiationInterface.jl is not yet ready or registered, it has a few niceties, like Enzyme support, which might pique your interest. I'm happy to discuss with you and see what other features you might need.
The main one I anticipate is compatibility with ComponentArrays.jl (JuliaDiff/DifferentiationInterface.jl#54), and I'll try to add it soon.
cc @adrhill