hessian of lagrangian
Vaibhavdixit02 committed Aug 2, 2024
1 parent eb74297 commit 6d4f6d7
Showing 2 changed files with 68 additions and 46 deletions.
49 changes: 47 additions & 2 deletions ext/OptimizationEnzymeExt.jl
@@ -3,7 +3,7 @@ module OptimizationEnzymeExt
import OptimizationBase, OptimizationBase.ArrayInterface
import OptimizationBase.SciMLBase: OptimizationFunction
import OptimizationBase.SciMLBase
import OptimizationBase.LinearAlgebra: I
import OptimizationBase.LinearAlgebra: I, dot
import OptimizationBase.ADTypes: AutoEnzyme
using Enzyme
using Core: Vararg
@@ -76,6 +76,18 @@ function cons_f2_oop(x, dx, fcons, p, i)
return nothing
end

function lagrangian(x, _f::Function, cons::Function, p, λ, σ = one(eltype(x)))::Float64
res = zeros(eltype(x), length(λ))
cons(res, x, p)
return σ * _f(x, p) + dot(λ, res)
end
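This helper evaluates the scalar σ * f(x, p) + dot(λ, c(x, p)) by writing the constraint residuals into a temporary buffer. A tiny standalone check with a made-up objective, constraint, and multiplier values (illustration only, not code from the package):

using LinearAlgebra: dot

f(x, p) = sum(abs2, x .- p)                        # toy objective
cons(res, x, p) = (res[1] = x[1] * x[2]; nothing)  # toy in-place constraint

x = [1.0, 2.0]; p = [0.0, 0.0]; λ = [3.0]; σ = 1.0
res = zeros(eltype(x), length(λ))
cons(res, x, p)
σ * f(x, p) + dot(λ, res)   # 1.0 * 5.0 + 3.0 * 2.0 == 11.0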

function lag_grad(x, dx, lagrangian::Function, _f::Function, cons::Function, p, σ, λ)
Enzyme.autodiff_deferred(Enzyme.Reverse, lagrangian, Active, Enzyme.Duplicated(x, dx),
Const(_f), Const(cons), Const(p), Const(λ), Const(σ))
return nothing
end
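lag_grad runs an Enzyme reverse pass that accumulates the gradient of the Lagrangian into dx; nesting it inside a forward-mode autodiff call then yields Hessian-vector products (forward-over-reverse), which is what the new lag_h below does with a whole batch of seed directions. A rough standalone sketch of a single such product, mirroring the call pattern used here; myf and mycons are toy placeholders, and exact activity annotations can differ across Enzyme versions:

using Enzyme
using LinearAlgebra: dot

myf(x, p) = sum(abs2, x .- p)
mycons(res, x, p) = (res[1] = x[1] * x[2]; nothing)

function lag(x, _f, cons, p, λ, σ)::Float64
    res = zeros(eltype(x), length(λ))
    cons(res, x, p)
    return σ * _f(x, p) + dot(λ, res)
end

function lag_grad!(x, dx, _f, cons, p, σ, λ)
    # reverse pass: accumulate ∇ₓL into dx
    Enzyme.autodiff_deferred(Enzyme.Reverse, lag, Active, Enzyme.Duplicated(x, dx),
        Const(_f), Const(cons), Const(p), Const(λ), Const(σ))
    return nothing
end

θ = [1.0, 2.0]; p = [0.0, 0.0]; λ = [1.0]; σ = 1.0
v  = [1.0, 0.0]      # seed direction, i.e. which Hessian column to extract
bθ = zeros(2)        # buffer that receives the gradient
hv = zeros(2)        # forward tangent of the gradient == ∇²L(θ) * v
Enzyme.autodiff(Enzyme.Forward, lag_grad!,
    Enzyme.Duplicated(θ, v), Enzyme.DuplicatedNoNeed(bθ, hv),
    Const(myf), Const(mycons), Const(p), Const(σ), Const(λ))
# for this toy problem ∇²L = [2σ λ[1]; λ[1] 2σ], so hv ≈ [2.0, 1.0]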

function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
adtype::AutoEnzyme, p,
num_cons = 0)
@@ -219,7 +231,40 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
end

if f.lag_h === nothing
lag_h = nothing # Consider implementing this
lag_vdθ = Tuple((Array(r) for r in eachrow(I(length(x)) * one(eltype(x)))))
lag_bθ = zeros(eltype(x), length(x))

if f.hess_prototype === nothing
lag_vdbθ = Tuple(zeros(eltype(x), length(x)) for i in eachindex(x))
else
#useless right now, looks like there is no way to tell Enzyme the sparsity pattern?
lag_vdbθ = Tuple((copy(r) for r in eachrow(f.hess_prototype)))
end

function lag_h(h, θ, σ, μ)
Enzyme.make_zero!.(lag_vdθ)
Enzyme.make_zero!(lag_bθ)
Enzyme.make_zero!.(lag_vdbθ)

Enzyme.autodiff(Enzyme.Forward,
lag_grad,
Enzyme.BatchDuplicated(θ, lag_vdθ),
Enzyme.BatchDuplicatedNoNeed(lag_bθ, lag_vdbθ),
Const(lagrangian),
Const(f.f),
Const(f.cons),
Const(p),
Const(σ),
Const(μ)
)
k = 0

for i in eachindex(θ)
vec_lagv = lag_vdbθ[i]
h[k+1:k+i] .= @view(vec_lagv[1:i])
k += i
end
end
else
lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p)
end
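The lag_h added above seeds every unit direction at once (lag_vdθ), so a single forward-over-reverse autodiff call fills lag_vdbθ with all columns of the Lagrangian Hessian, and the closing loop packs those columns into the flat triangular vector h that solvers consume. A small illustration of just that packing step, on a toy 2×2 Hessian rather than anything computed by the package:

# cols stands in for lag_vdbθ: one Hessian column per seed direction
function pack_triangle!(h, cols)
    k = 0
    for i in eachindex(cols)
        h[k+1:k+i] .= @view(cols[i][1:i])   # first i entries of column i == H[1:i, i]
        k += i
    end
    return h
end

H = [2.0 1.0; 1.0 2.0]
pack_triangle!(zeros(3), (H[:, 1], H[:, 2]))   # == [2.0, 1.0, 2.0], i.e. H[1,1], H[1,2], H[2,2]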
65 changes: 21 additions & 44 deletions src/OptimizationDISparseExt.jl
@@ -174,12 +174,12 @@ function instantiate_function(
conshess_colors = f.cons_hess_colorvec
if cons !== nothing && f.cons_h === nothing
fncs = [@closure (x) -> cons_oop(x)[i] for i in 1:num_cons]
extras_cons_hess = Vector{DifferentiationInterface.SparseHessianExtras}(undef, length(fncs))
for ind in 1:num_cons
extras_cons_hess[ind] = prepare_hessian(fncs[ind], soadtype, x)
end
conshess_sparsity = [sum(sparse, cons)]
conshess_colors = getfield.(extras_cons_hess, Ref(:colors))
# extras_cons_hess = Vector(undef, length(fncs))
# for ind in 1:num_cons
# extras_cons_hess[ind] = prepare_hessian(fncs[ind], soadtype, x)
# end
# conshess_sparsity = getfield.(extras_cons_hess, :sparsity)
# conshess_colors = getfield.(extras_cons_hess, :colors)
function cons_h(H, θ)
for i in 1:num_cons
hessian!(fncs[i], H[i], soadtype, θ)
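With the prepared sparse path commented out for now, cons_h falls back to one unprepared hessian! call per constraint closure. A hedged sketch of that call shape through DifferentiationInterface; the toy constraints and the AutoForwardDiff backend are stand-ins for the package's cons_oop and sparse soadtype:

using DifferentiationInterface: hessian!
using ADTypes: AutoForwardDiff
import ForwardDiff

cons_oop(x) = [x[1] * x[2], x[1]^2]          # toy out-of-place constraints
num_cons = 2
fncs = [x -> cons_oop(x)[i] for i in 1:num_cons]

θ = [1.0, 2.0]
H = [zeros(2, 2) for _ in 1:num_cons]
for i in 1:num_cons
    hessian!(fncs[i], H[i], AutoForwardDiff(), θ)   # same call shape as the loop above
end
H[1]   # ≈ [0 1; 1 0];  H[2] ≈ [2 0; 0 0]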
@@ -189,56 +189,33 @@
cons_h = (res, θ) -> f.cons_h(res, θ, p)
end

function lagrangian(x, σ = one(eltype(x)))
θ = x[1:end-num_cons]
λ = x[end-num_cons+1:end]
return σ * _f(θ) + dot(λ, cons_oop(θ))
function lagrangian(x, σ = one(eltype(x)), λ = ones(eltype(x), num_cons))
return σ * _f(x) + dot(λ, cons_oop(x))
end

lag_hess_prototype = f.lag_hess_prototype
if f.lag_h === nothing
lag_extras = prepare_hessian(lagrangian, soadtype, vcat(x, ones(eltype(x), num_cons)))
lag_extras = prepare_hessian(lagrangian, soadtype, x)
lag_hess_prototype = lag_extras.sparsity

function lag_h(H::Matrix, θ, σ, λ)
@show size(H)
@show size(θ)
@show size(λ)

function lag_h(H::AbstractMatrix, θ, σ, λ)
if σ == zero(eltype(θ))
cons_h(H, θ)
H *= λ
else
hessian!(lagrangian, H, soadtype, vcat(θ, λ), lag_extras)
hessian!(x -> lagrangian(x, σ, λ), H, soadtype, θ, lag_extras)
end
end
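In the dense method, σ and λ are now captured in a closure over x instead of being appended to the differentiated vector, so the resulting Hessian stays n-by-n. A hedged sketch of that closure-plus-hessian! pattern on a toy problem, again with AutoForwardDiff standing in for the sparse soadtype:

using DifferentiationInterface: hessian!
using ADTypes: AutoForwardDiff
import ForwardDiff
using LinearAlgebra: dot

_f(x) = sum(abs2, x)                 # toy objective
cons_oop(x) = [x[1] * x[2]]          # toy constraint
lagr(x, σ, λ) = σ * _f(x) + dot(λ, cons_oop(x))

θ = [1.0, 2.0]; σ = 1.0; λ = [3.0]
H = zeros(2, 2)
hessian!(x -> lagr(x, σ, λ), H, AutoForwardDiff(), θ)
H   # ≈ [2σ λ[1]; λ[1] 2σ] == [2.0 3.0; 3.0 2.0]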

function lag_h(h, θ, σ, λ)
# @show h
sparseHproto = findnz(lag_extras.sparsity)
H = sparse(sparseHproto[1], sparseHproto[2], zeros(eltype(θ), length(sparseHproto[1])))
if σ == zero(eltype(θ))
cons_h(H, θ)
H *= λ
else
hessian!(lagrangian, H, soadtype, vcat(θ, λ), lag_extras)
k = 0
rows, cols, _ = findnz(H)
for (i, j) in zip(rows, cols)
if i <= j
k += 1
h[k] = σ * H[i, j]
end
end
k = 0
for λi in λ
if Hi isa SparseMatrixCSC
rows, cols, _ = findnz(Hi)
for (i, j) in zip(rows, cols)
if i <= j
k += 1
h[k] += λi * Hi[i, j]
end
end
end
H = eltype(θ).(lag_hess_prototype)
hessian!(x -> lagrangian(x, σ, λ), H, soadtype, θ, lag_extras)
k = 0
rows, cols, _ = findnz(H)
for (i, j) in zip(rows, cols)
if i <= j
k += 1
h[k] = H[i, j]
end
end
end
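The vector method evaluates the same closure Hessian into a sparse buffer shaped like lag_hess_prototype, then walks findnz and keeps the entries with row <= column, copying one triangle of the symmetric matrix into h. An illustration of just that packing loop on a toy sparse matrix (SparseArrays only, not code from the package):

using SparseArrays

# copy one triangle (row <= column) of a symmetric sparse H into a flat vector,
# in the same column-major order findnz yields
function pack_sparse_triangle(H::SparseMatrixCSC)
    h = zeros(eltype(H), nnz(H))   # generous upper bound on the packed length
    k = 0
    rows, cols, _ = findnz(H)
    for (i, j) in zip(rows, cols)
        if i <= j
            k += 1
            h[k] = H[i, j]
        end
    end
    return h[1:k]
end

pack_sparse_triangle(sparse([2.0 3.0 0.0; 3.0 2.0 0.0; 0.0 0.0 4.0]))   # == [2.0, 3.0, 2.0, 4.0]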
