We are developing a package that uses ComponentArrays and Lux to train Neural ODEs through a simple-to-use front end. We found the following bug after one of our devs updated his packages. Upon investigation, it appears to be an issue in ComponentArrays. The following MWE:
using ComponentArrays, Lux, Random, OrdinaryDiffEq, Zygote, SciMLSensitivity

# Toy data: 2 state variables, 10 time points
df = rand(2, 10)

# Small network used as the ODE right-hand side
NN = Lux.Chain(Lux.Dense(2, 10, tanh), Lux.Dense(10, 2))
rng = Random.default_rng()
parameters, states = Lux.setup(rng, NN)
parameters = (NN = parameters, )

function derivs!(du, u, parameters, t)
    du .= NN(u, parameters.NN, states)[1]
    return du
end

u0 = zeros(2); tspan = (0.0, 0.5)
IVP = ODEProblem(derivs!, u0, tspan, parameters)

# Integrate from t to t + dt and return the state at t + dt
function predict(u, t, dt, parameters)
    tspan = (t, t + dt)
    sol = solve(IVP, Tsit5(), u0 = u, p = parameters, tspan = tspan, saveat = (t, t + dt))
    X = Array(sol)
    return X[:, end]
end

function loss(parameters)
    sum(abs2, predict(df[:, 1], 0.0, 0.05, parameters) .- df[:, 2])
end

# This is the call that fails
gradient(loss, ComponentArray(parameters))
fails with the following ]status:
⌃ [b0b7db55] ComponentArrays v0.15.14
⌃ [b2108857] Lux v0.5.61
[9a3f8284] Random
but works with the following ]status:
⌃ [b0b7db55] ComponentArrays v0.15.13
⌃ [b2108857] Lux v0.5.61
[9a3f8284] Random
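For anyone trying to reproduce this, here is a minimal sketch of how the working environment could be set up. It assumes a throwaway environment and that the resolver can satisfy these versions together with whatever it picks for the other packages; changing "0.15.13" to "0.15.14" should then give the failing setup.

```julia
using Pkg

# Assumption: use a temporary environment so the main one is left untouched
Pkg.activate(temp = true)

# Request the ComponentArrays and Lux versions from the working ]status;
# the remaining packages are left to the resolver
Pkg.add([
    Pkg.PackageSpec(name = "ComponentArrays", version = "0.15.13"),
    Pkg.PackageSpec(name = "Lux", version = "0.5.61"),
    Pkg.PackageSpec(name = "OrdinaryDiffEq"),
    Pkg.PackageSpec(name = "Zygote"),
    Pkg.PackageSpec(name = "SciMLSensitivity"),
])

# Check which versions were actually installed
Pkg.status()
```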
I'm testing on Julia 1.10.4. When it fails, it throws the following stacktrace: