diff --git a/src/function.jl b/src/function.jl index 8f9dc46..bbf41ec 100644 --- a/src/function.jl +++ b/src/function.jl @@ -43,6 +43,71 @@ function that is not defined, an error is thrown. For more information on the use of automatic differentiation, see the documentation of the `AbstractADType` types. """ + + +function instantiate_function(f::MultiObjectiveOptimizationFunction, x, ::SciMLBase.NoAD, + p, num_cons = 0) + jac = f.jac === nothing ? nothing : (J, x, args...) -> f.jac(J, x, p, args...) + hess = f.hess === nothing ? nothing : [(H, x, args...) -> h(H, x, p, args...) for h in f.hess] + hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, p, args...) + cons = f.cons === nothing ? nothing : (res, x) -> f.cons(res, x, p) + cons_j = f.cons_j === nothing ? nothing : (res, x) -> f.cons_j(res, x, p) + cons_jvp = f.cons_jvp === nothing ? nothing : (res, x) -> f.cons_jvp(res, x, p) + cons_vjp = f.cons_vjp === nothing ? nothing : (res, x) -> f.cons_vjp(res, x, p) + cons_h = f.cons_h === nothing ? nothing : (res, x) -> f.cons_h(res, x, p) + hess_prototype = f.hess_prototype === nothing ? nothing : + convert.(eltype(x), f.hess_prototype) + cons_jac_prototype = f.cons_jac_prototype === nothing ? nothing : + convert.(eltype(x), f.cons_jac_prototype) + cons_hess_prototype = f.cons_hess_prototype === nothing ? nothing : + [convert.(eltype(x), f.cons_hess_prototype[i]) + for i in 1:num_cons] + expr = symbolify(f.expr) + cons_expr = symbolify.(f.cons_expr) + + return MultiObjectiveOptimizationFunction{true}(f.f, SciMLBase.NoAD(); jac = jac, hess = hess, + hv = hv, + cons = cons, cons_j = cons_j, cons_jvp = cons_jvp, cons_vjp = cons_vjp, cons_h = cons_h, + hess_prototype = hess_prototype, + cons_jac_prototype = cons_jac_prototype, + cons_hess_prototype = cons_hess_prototype, + expr = expr, cons_expr = cons_expr, + sys = f.sys, + observed = f.observed) +end + +function instantiate_function(f::MultiObjectiveOptimizationFunction, cache::ReInitCache, ::SciMLBase.NoAD, + num_cons = 0) + jac = f.jac === nothing ? nothing : (J, x, args...) -> f.jac(J, x, cache.p, args...) + hess = f.hess === nothing ? nothing : [(H, x, args...) -> h(H, x, cache.p, args...) for h in f.hess] + hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, cache.p, args...) + cons = f.cons === nothing ? nothing : (res, x) -> f.cons(res, x, cache.p) + cons_j = f.cons_j === nothing ? nothing : (res, x) -> f.cons_j(res, x, cache.p) + cons_jvp = f.cons_jvp === nothing ? nothing : (res, x) -> f.cons_jvp(res, x, cache.p) + cons_vjp = f.cons_vjp === nothing ? nothing : (res, x) -> f.cons_vjp(res, x, cache.p) + cons_h = f.cons_h === nothing ? nothing : (res, x) -> f.cons_h(res, x, cache.p) + hess_prototype = f.hess_prototype === nothing ? nothing : + convert.(eltype(cache.u0), f.hess_prototype) + cons_jac_prototype = f.cons_jac_prototype === nothing ? nothing : + convert.(eltype(cache.u0), f.cons_jac_prototype) + cons_hess_prototype = f.cons_hess_prototype === nothing ? nothing : + [convert.(eltype(cache.u0), f.cons_hess_prototype[i]) + for i in 1:num_cons] + expr = symbolify(f.expr) + cons_expr = symbolify.(f.cons_expr) + + return MultiObjectiveOptimizationFunction{true}(f.f, SciMLBase.NoAD(); jac = jac, hess = hess, + hv = hv, + cons = cons, cons_j = cons_j, cons_jvp = cons_jvp, cons_vjp = cons_vjp, cons_h = cons_h, + hess_prototype = hess_prototype, + cons_jac_prototype = cons_jac_prototype, + cons_hess_prototype = cons_hess_prototype, + expr = expr, cons_expr = cons_expr, + sys = f.sys, + observed = f.observed) +end + + function instantiate_function(f, x, ::SciMLBase.NoAD, p, num_cons = 0) grad = f.grad === nothing ? nothing : (G, x, args...) -> f.grad(G, x, p, args...) @@ -113,3 +178,5 @@ function instantiate_function(f, x, adtype::ADTypes.AbstractADType, adpkg = adtypestr[strtind:(open_brkt_ind - 1)] throw(ArgumentError("The passed automatic differentiation backend choice is not available. Please load the corresponding AD package $adpkg.")) end + +