Skip to content

Commit

Permalink
Merge pull request #75 from ParasPuneetSingh/main
Browse files Browse the repository at this point in the history
Added MOO functionality to functions.jl
  • Loading branch information
Vaibhavdixit02 authored Aug 12, 2024
2 parents 4fb7a80 + 32c0a1a commit 8f6be9d
Showing 1 changed file with 67 additions and 0 deletions.
67 changes: 67 additions & 0 deletions src/function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,71 @@ function that is not defined, an error is thrown.
For more information on the use of automatic differentiation, see the
documentation of the `AbstractADType` types.
"""


function instantiate_function(f::MultiObjectiveOptimizationFunction, x, ::SciMLBase.NoAD,
p, num_cons = 0)
jac = f.jac === nothing ? nothing : (J, x, args...) -> f.jac(J, x, p, args...)
hess = f.hess === nothing ? nothing : [(H, x, args...) -> h(H, x, p, args...) for h in f.hess]
hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, p, args...)
cons = f.cons === nothing ? nothing : (res, x) -> f.cons(res, x, p)
cons_j = f.cons_j === nothing ? nothing : (res, x) -> f.cons_j(res, x, p)
cons_jvp = f.cons_jvp === nothing ? nothing : (res, x) -> f.cons_jvp(res, x, p)
cons_vjp = f.cons_vjp === nothing ? nothing : (res, x) -> f.cons_vjp(res, x, p)
cons_h = f.cons_h === nothing ? nothing : (res, x) -> f.cons_h(res, x, p)
hess_prototype = f.hess_prototype === nothing ? nothing :
convert.(eltype(x), f.hess_prototype)
cons_jac_prototype = f.cons_jac_prototype === nothing ? nothing :
convert.(eltype(x), f.cons_jac_prototype)
cons_hess_prototype = f.cons_hess_prototype === nothing ? nothing :
[convert.(eltype(x), f.cons_hess_prototype[i])
for i in 1:num_cons]
expr = symbolify(f.expr)
cons_expr = symbolify.(f.cons_expr)

return MultiObjectiveOptimizationFunction{true}(f.f, SciMLBase.NoAD(); jac = jac, hess = hess,
hv = hv,
cons = cons, cons_j = cons_j, cons_jvp = cons_jvp, cons_vjp = cons_vjp, cons_h = cons_h,
hess_prototype = hess_prototype,
cons_jac_prototype = cons_jac_prototype,
cons_hess_prototype = cons_hess_prototype,
expr = expr, cons_expr = cons_expr,
sys = f.sys,
observed = f.observed)
end

function instantiate_function(f::MultiObjectiveOptimizationFunction, cache::ReInitCache, ::SciMLBase.NoAD,
num_cons = 0)
jac = f.jac === nothing ? nothing : (J, x, args...) -> f.jac(J, x, cache.p, args...)
hess = f.hess === nothing ? nothing : [(H, x, args...) -> h(H, x, cache.p, args...) for h in f.hess]
hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, cache.p, args...)
cons = f.cons === nothing ? nothing : (res, x) -> f.cons(res, x, cache.p)
cons_j = f.cons_j === nothing ? nothing : (res, x) -> f.cons_j(res, x, cache.p)
cons_jvp = f.cons_jvp === nothing ? nothing : (res, x) -> f.cons_jvp(res, x, cache.p)
cons_vjp = f.cons_vjp === nothing ? nothing : (res, x) -> f.cons_vjp(res, x, cache.p)
cons_h = f.cons_h === nothing ? nothing : (res, x) -> f.cons_h(res, x, cache.p)
hess_prototype = f.hess_prototype === nothing ? nothing :
convert.(eltype(cache.u0), f.hess_prototype)
cons_jac_prototype = f.cons_jac_prototype === nothing ? nothing :
convert.(eltype(cache.u0), f.cons_jac_prototype)
cons_hess_prototype = f.cons_hess_prototype === nothing ? nothing :
[convert.(eltype(cache.u0), f.cons_hess_prototype[i])
for i in 1:num_cons]
expr = symbolify(f.expr)
cons_expr = symbolify.(f.cons_expr)

return MultiObjectiveOptimizationFunction{true}(f.f, SciMLBase.NoAD(); jac = jac, hess = hess,
hv = hv,
cons = cons, cons_j = cons_j, cons_jvp = cons_jvp, cons_vjp = cons_vjp, cons_h = cons_h,
hess_prototype = hess_prototype,
cons_jac_prototype = cons_jac_prototype,
cons_hess_prototype = cons_hess_prototype,
expr = expr, cons_expr = cons_expr,
sys = f.sys,
observed = f.observed)
end


function instantiate_function(f, x, ::SciMLBase.NoAD,
p, num_cons = 0)
grad = f.grad === nothing ? nothing : (G, x, args...) -> f.grad(G, x, p, args...)
Expand Down Expand Up @@ -113,3 +178,5 @@ function instantiate_function(f, x, adtype::ADTypes.AbstractADType,
adpkg = adtypestr[strtind:(open_brkt_ind - 1)]
throw(ArgumentError("The passed automatic differentiation backend choice is not available. Please load the corresponding AD package $adpkg."))
end


0 comments on commit 8f6be9d

Please sign in to comment.