Skip to content

Commit

Permalink
jvp vjp with DI
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 committed Aug 4, 2024
1 parent 9813fd4 commit 6c15524
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 11 deletions.
70 changes: 66 additions & 4 deletions src/OptimizationDIExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ end

function instantiate_function(
f::OptimizationFunction{true}, x, adtype::ADTypes.AbstractADType,
p = SciMLBase.NullParameters(), num_cons = 0)
p = SciMLBase.NullParameters(), num_cons = 0;
fg = false, fgh = false, cons_vjp = false, cons_jvp = false)
function _f(θ)
return f(θ, p)[1]
end
Expand All @@ -38,6 +39,13 @@ function instantiate_function(
grad = (G, θ) -> f.grad(G, θ, p)
end

if fg == true
function fg!(res, θ)
(y, _) = value_and_gradient!(_f, res, adtype, θ, extras_grad)
return y
end
end

hess_sparsity = f.hess_prototype
hess_colors = f.hess_colorvec
if f.hess === nothing
Expand All @@ -49,6 +57,13 @@ function instantiate_function(
hess = (H, θ) -> f.hess(H, θ, p)
end

if fgh == true
function fgh!(G, H, θ)
(y, _, _) = value_derivative_and_second_derivative!(_f, G, H, θ, extras_hess)
return y
end
end

if f.hv === nothing
extras_hvp = prepare_hvp(_f, soadtype, x, zeros(eltype(x), size(x)))
hv = function (H, θ, v)
Expand Down Expand Up @@ -86,6 +101,24 @@ function instantiate_function(
cons_j = (J, θ) -> f.cons_j(J, θ, p)
end

if f.cons_vjp === nothing && cons_vjp == true
extras_pullback = prepare_pullback(cons_oop, adtype, x)
function cons_vjp!(J, θ, v)
pullback!(cons_oop, J, adtype, θ, v, extras_pullback)
end
else
cons_vjp! = nothing
end

if f.cons_jvp === nothing && cons_jvp == true
extras_pushforward = prepare_pushforward(cons_oop, adtype, x)
function cons_jvp!(J, θ, v)
pushforward!(cons_oop, J, adtype, θ, v, extras_pushforward)
end
else
cons_jvp! = nothing
end

conshess_sparsity = f.cons_hess_prototype
conshess_colors = f.cons_hess_colorvec
if cons !== nothing && f.cons_h === nothing
Expand All @@ -101,22 +134,51 @@ function instantiate_function(
cons_h = (res, θ) -> f.cons_h(res, θ, p)
end

function lagrangian(x, σ = one(eltype(x)), λ = ones(eltype(x), num_cons))
return σ * _f(x) + dot(λ, cons_oop(x))
end

lag_hess_prototype = f.lag_hess_prototype

if f.lag_h === nothing
lag_h = nothing # Consider implementing this
lag_extras = prepare_hessian(lagrangian, soadtype, x)
lag_hess_prototype = zeros(Bool, length(x), length(x))

function lag_h(H::AbstractMatrix, θ, σ, λ)
if σ == zero(eltype(θ))
cons_h(H, θ)
H *= λ
else
hessian!(x -> lagrangian(x, σ, λ), H, soadtype, θ, lag_extras)
end
end

function lag_h(h, θ, σ, λ)
H = eltype(θ).(lag_hess_prototype)
hessian!(x -> lagrangian(x, σ, λ), H, soadtype, θ, lag_extras)
k = 0
rows, cols, _ = findnz(H)
for (i, j) in zip(rows, cols)
if i <= j
k += 1
h[k] = H[i, j]
end
end
end
else
lag_h = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p)
end

return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv,
cons = cons, cons_j = cons_j, cons_h = cons_h,
cons = cons, cons_j = cons_j, cons_h = cons_h, cons_jvp = cons_jvp!, cons_jvp = cons_jvp!,
hess_prototype = hess_sparsity,
hess_colorvec = hess_colors,
cons_jac_prototype = cons_jac_prototype,
cons_jac_colorvec = cons_jac_colorvec,
cons_hess_prototype = conshess_sparsity,
cons_hess_colorvec = conshess_colors,
lag_h,
lag_hess_prototype = f.lag_hess_prototype,
lag_hess_prototype = lag_hess_prototype,
sys = f.sys,
expr = f.expr,
cons_expr = f.cons_expr)
Expand Down
14 changes: 7 additions & 7 deletions src/OptimizationDISparseExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,15 +174,15 @@ function instantiate_function(
conshess_colors = f.cons_hess_colorvec
if cons !== nothing && f.cons_h === nothing
fncs = [@closure (x) -> cons_oop(x)[i] for i in 1:num_cons]
# extras_cons_hess = Vector(undef, length(fncs))
# for ind in 1:num_cons
# extras_cons_hess[ind] = prepare_hessian(fncs[ind], soadtype, x)
# end
# conshess_sparsity = getfield.(extras_cons_hess, :sparsity)
# conshess_colors = getfield.(extras_cons_hess, :colors)
extras_cons_hess = Vector(undef, length(fncs))
for ind in 1:num_cons
extras_cons_hess[ind] = prepare_hessian(fncs[ind], soadtype, x)
end
conshess_sparsity = getfield.(extras_cons_hess, :sparsity)
conshess_colors = getfield.(extras_cons_hess, :colors)
function cons_h(H, θ)
for i in 1:num_cons
hessian!(fncs[i], H[i], soadtype, θ)
hessian!(fncs[i], H[i], soadtype, θ, extras_cons_hess[i])
end
end
else
Expand Down

0 comments on commit 6c15524

Please sign in to comment.