Nested AD for Parameter Gradient/Jacobian #610
Comments
downgrading to |
Error:
ERROR: LoadError: MethodError: no method matching fast_materialize(::Static.False, ::Static.False, ::Matrix{Float32})
Closest candidates are:
fast_materialize(::SB, ::DB, ::Base.Broadcast.Broadcasted{S}) where {S, SB, DB}
@ FastBroadcast C:\Users\prbzr\.julia\packages\FastBroadcast\ux5mz\src\FastBroadcast.jl:22
Stacktrace:
[1] macro expansion
@ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0 [inlined]
[2] _pullback(::Zygote.Context{false}, ::typeof(FastBroadcast.fast_materialize), ::Static.False, ::Static.False, ::Matrix{Float32})
@ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:81
[3] __activation_gradient
@ C:\Users\prbzr\.julia\packages\LuxLib\t9w7i\src\utils.jl:187 [inlined]
[4] #44
@ C:\Users\prbzr\.julia\packages\LuxLib\t9w7i\src\impl\fused_dense.jl:45 [inlined]
[5] _pullback(ctx::Zygote.Context{false}, f::LuxLib.var"#44#47"{typeof(tanh_fast), typeof(identity), Matrix{…}, Base.ReshapedArray{…}, Matrix{…}, SubArray{…}}, args::Matrix{Float32})
@ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
[6] ZBack
@ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\chainrules.jl:211 [inlined]
[7] Pullback
@ C:\Users\prbzr\.julia\packages\LuxLib\t9w7i\src\api\dense.jl:46 [inlined]
[8] Pullback
@ C:\Users\prbzr\.julia\packages\LuxLib\t9w7i\src\api\dense.jl:38 [inlined]
[9] Pullback
@ C:\Users\prbzr\.julia\packages\Lux\ANzxX\src\layers\basic.jl:218 [inlined]
[10] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Tuple{Matrix{…}, Nothing})
@ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
[11] Pullback
@ D:\Codes\Mine\bug-report\br-3\br-3.jl:9 [inlined]
[12] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Matrix{Float32})
@ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
[13] #291
@ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:206 [inlined]
[14] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#291#292"{Tuple{Tuple{…}}, Zygote.Pullback{Tuple{…}, Tuple{…}}}, args::Matrix{Float32})
@ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
[15] #2169#back
@ C:\Users\prbzr\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72 [inlined]
[16] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{Tuple{…}}, Zygote.Pullback{Tuple{…}, Tuple{…}}}}, args::Matrix{Float32})
@ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
[17] Pullback
@ .\operators.jl:1045 [inlined]
[18] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{var"#1#2"}, Tuple{ComponentVector{Float32, Vector{…}, Tuple{…}}}, @Kwargs{}}, Any}, args::Matrix{Float32})
@ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
[19] Pullback
@ .\operators.jl:1044 [inlined]
[20] Pullback
@ .\operators.jl:1041 [inlined]
[21] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{Base.var"##_#103", @Kwargs{}, ComposedFunction{…}, ComponentVector{…}}, Tuple{Zygote.Pullback{…}, Zygote.Pullback{…}}}, args::Vector{Float32})
@ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
[22] #291
@ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:206 [inlined]
[23] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#291#292"{Tuple{Tuple{…}, Tuple{…}}, Zygote.Pullback{Tuple{…}, Tuple{…}}}, args::Vector{Float32})
@ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
[24] #2169#back
@ C:\Users\prbzr\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72 [inlined]
[25] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{Tuple{…}, Tuple{…}}, Zygote.Pullback{Tuple{…}, Tuple{…}}}}, args::Vector{Float32})
@ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
[26] Pullback
@ .\operators.jl:1041 [inlined]
[27] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Vector{Float32})
@ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
[28] #75
@ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:91 [inlined]
[29] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}}, args::Vector{Float32})
@ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
[30] withjacobian
@ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\lib\grad.jl:150 [inlined]
[31] _pullback(::Zygote.Context{false}, ::typeof(withjacobian), ::var"#1#2", ::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{…}}})
@ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
[32] _apply(::Function, ::Vararg{Any})
@ Core .\boot.jl:838
[33] adjoint
@ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:203 [inlined]
[34] _pullback
@ C:\Users\prbzr\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:67 [inlined]
[35] jacobian
@ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\lib\grad.jl:128 [inlined]
[36] _pullback(::Zygote.Context{false}, ::typeof(jacobian), ::var"#1#2", ::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{…}}})
@ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
[37] fn1
@ D:\Codes\Mine\bug-report\br-3\br-3.jl:9 [inlined]
[38] _pullback(ctx::Zygote.Context{false}, f::typeof(fn1), args::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(weight = ViewAxis(1:4, ShapedAxis((2, 2))), bias = ViewAxis(5:6, ShapedAxis((2, 1))))}}})
@ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
[39] pullback(f::Function, cx::Zygote.Context{false}, args::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(weight = ViewAxis(1:4, ShapedAxis((2, 2))), bias = ViewAxis(5:6, ShapedAxis((2, 1))))}}})
@ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:90
[40] pullback
@ C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:88 [inlined]
[41] gradient(f::Function, args::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(weight = ViewAxis(1:4, ShapedAxis((2, 2))), bias = ViewAxis(5:6, ShapedAxis((2, 1))))}}})
@ Zygote C:\Users\prbzr\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:147
[42] top-level scope
@ D:\Codes\Mine\bug-report\br-3\br-3.jl:13
[43] include(fname::String)
@ Base.MainInclude .\client.jl:489
[44] top-level scope
@ REPL[1]:1
in expression starting at D:\Codes\Mine\bug-report\br-3\br-3.jl:13
Some type information was truncated. Use `show(err)` to see complete types.
MRE:
using ComponentArrays, Lux, Random, Zygote
nn = Dense(2, 2, tanh)
r = rand(Float32, 2, 4)
ps, st = Lux.setup(Random.default_rng(), nn)
ps = ComponentArray(ps)
# Sum of the Jacobian of the layer output with respect to the parameters
function fn1(z)
    sum(first(Zygote.jacobian(x -> first(nn(r, x, st)), z)))
end
fn1(ps)
# Reverse-over-reverse: this call throws the MethodError above
Zygote.gradient(fn1, ps)
Environment:
Status `D:\Codes\Mine\bug-report\br-3\Project.toml`
[b0b7db55] ComponentArrays v0.15.11
[b2108857] Lux v0.5.40
[e88e6eb3] Zygote v0.6.69
[9a3f8284] Random |
Yeah, that is some weird Zygote broadcast handling quirk. From v0.5.40 we use completely different backend operations, which are faster and allocate significantly less, but come at the cost of sacrificing nested reverse-over-reverse Zygote AD (which, to be fair, worked only in very limited cases and was never documented, for good reason). https://lux.csail.mit.edu/stable/manual/nested_autodiff does nested AD for the inputs, but the same for parameters hasn't been implemented yet. |
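For small parameter counts, one way to sidestep reverse-over-reverse entirely is to take both derivative levels in forward mode. The sketch below is an illustration under stated assumptions, not something confirmed in this thread: it uses ForwardDiff.jl (which is not in the MRE environment) and a hypothetical hand-written `tiny_dense` stand-in for `Dense(2, 2, tanh)`, so it stays self-contained and does not depend on Lux internals. Whether the same pattern works with a Lux model on a given version would need to be checked separately.

```julia
using ForwardDiff

# Hypothetical stand-in for Dense(2, 2, tanh); this is NOT Lux's implementation.
# p packs a 2×2 weight (entries 1:4) followed by a length-2 bias (entries 5:6).
function tiny_dense(p, x)
    W = reshape(p[1:4], 2, 2)
    b = p[5:6]
    return tanh.(W * x .+ b)
end

r = rand(Float32, 2, 4)
p0 = randn(Float32, 6)

# Inner level: Jacobian of the flattened output with respect to the parameters.
param_jacobian(p) = ForwardDiff.jacobian(q -> vec(tiny_dense(q, r)), p)

# Same kind of scalar as fn1 in the MRE: the sum of all Jacobian entries.
fn1_fwd(p) = sum(param_jacobian(p))

# Outer level: differentiate through the Jacobian with a second forward-mode pass.
g = ForwardDiff.gradient(fn1_fwd, p0)
```

Forward mode scales with the number of parameters, so this is only reasonable for small models like the one in the MRE.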
I hit the same issue when taking the Zygote gradient of a pullback:
using Lux, Zygote, ComponentArrays, Random
X = collect(range(0, 1, length = 10)) |> permutedims
Y = zeros(axes(X))
nn = Chain(Dense(1 => 20, tanh), Dense(20 => 1))
ps, st = Lux.setup(Xoshiro(0), nn)
pv = ComponentArray(ps)
function loss(p)
    u(x) = 1 .+ nn(x, p, st)[1] .* x
    # Vector-Jacobian product of u at x with a cotangent of ones (inner reverse pass)
    ux(x) = Zygote.pullback(u, x)[2](ones(size(x)))[1]
    pred = ux(X)
    loss = sum(abs2, pred)
    return loss
end
# Outer reverse pass over the inner pullback
Zygote.gradient(loss, pv)
The code hasn't worked since v0.5.40. |
I added
[Lux]
DisableAutomaticNestedADSwitching = true
as |
And also I didn't use |
This is a separate issue. See The core problem is still the same. Zygote modifies Broadcast.broadcasted operations in a strange way that doesn't allow using FastBroadcast and such. (This part has nothing to do with the Nested AD rules that were introduced but rather #591). To fix this we just need to introduce an rrule for the parameter jacobian on |
I'm getting this error with code that was working a month ago:
I will update this!