Skip to content

Commit

Permalink
Add statistics
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 10, 2023
1 parent 8671a2d commit 8e4cd87
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 10 deletions.
7 changes: 4 additions & 3 deletions src/neural_de.jl
Original file line number Diff line number Diff line change
Expand Up @@ -254,15 +254,16 @@ function (n::NeuralDAE)(u_du::Tuple, p, st)
nn_out = model(vcat(u, du), p)
alg_out = n.constraints_model(u, p, t)
iter_nn, iter_const = 0, 0
map(n.differential_vars) do isdiff
res = map(n.differential_vars) do isdiff
if isdiff
iter_nn += 1
selectdim(nn_out, 1, iter_nn)
nn_out[iter_nn]
else
iter_const += 1
selectdim(alg_out, 1, iter_const)
alg_out[iter_const]
end
end
return res
end

prob = DAEProblem{false}(f, du0, u0, n.tspan, p; n.differential_vars)
Expand Down
10 changes: 5 additions & 5 deletions test/neural_dae.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using ComponentArrays, DiffEqFlux, Zygote, Optimization, OrdinaryDiffEq, Random
using ComponentArrays,
DiffEqFlux, Zygote, Optimization, OptimizationOptimJL, OrdinaryDiffEq, Random

#A desired MWE for now, not a test yet.

Expand Down Expand Up @@ -27,6 +28,7 @@ tspan = (0.0, 10.0)
ndae = NeuralDAE(dudt2, (u, p, t) -> [u[1] + u[2] + u[3] - 1], tspan, DImplicitEuler();
differential_vars = [true, true, false])
ps, st = Lux.setup(Xoshiro(0), ndae)
ps = ComponentArray(ps)
truedu0 = similar(u₀)

ndae((u₀, truedu0), ps, st)
Expand All @@ -36,13 +38,11 @@ predict_n_dae(p) = first(ndae(u₀, p, st))
function loss(p)
pred = predict_n_dae(p)
loss = sum(abs2, sol .- pred)
loss, pred
return loss, pred
end

p = p .+ rand(3) .* p

optfunc = Optimization.OptimizationFunction((x, p) -> loss(x), Optimization.AutoZygote())
optprob = Optimization.OptimizationProblem(optfunc, p)
optprob = Optimization.OptimizationProblem(optfunc, ps)
res = Optimization.solve(optprob, BFGS(; initial_stepnorm = 0.0001))

# Same stuff with Lux
Expand Down
2 changes: 1 addition & 1 deletion test/neural_gde.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using DiffEqFlux, ComponentArrays, GeometricFlux, GraphSignals, OrdinaryDiffEq, Random,
Test, OptimizationOptimisers, Optimization
Test, OptimizationOptimisers, Optimization, Statistics
import Flux

# Fully Connected Graph
Expand Down
3 changes: 2 additions & 1 deletion test/second_order_ode.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using ComponentArrays, DiffEqFlux, Lux, Zygote, Random, Optimization, OrdinaryDiffEq
using ComponentArrays,
DiffEqFlux, Lux, Zygote, Random, Optimization, OptimizationOptimisers, OrdinaryDiffEq

rng = Random.default_rng()

Expand Down

0 comments on commit 8e4cd87

Please sign in to comment.