From 7fb7fd33209a3c59bed1a90bb58b5629befec500 Mon Sep 17 00:00:00 2001 From: TT Date: Fri, 2 Feb 2024 16:28:16 +0100 Subject: [PATCH] WIP --- src/FMI2.jl | 39 ++++++++++++++++++++++++-------- test/FMI2/jacobians_gradients.jl | 2 +- test/FMI2/solution.jl | 3 --- 3 files changed, 31 insertions(+), 13 deletions(-) diff --git a/src/FMI2.jl b/src/FMI2.jl index 9852ff7..78b29a4 100644 --- a/src/FMI2.jl +++ b/src/FMI2.jl @@ -1256,12 +1256,33 @@ function color!(sens::FMU2Sensitivities) return nothing end +function ref_length(ref::AbstractArray) + return length(ref) +end + +function ref_length(ref::Symbol) + if ref == :time + return 1 + else + @assert false "unknwon ref symbol: $(ref)" + end +end + +function ref_length(ref::Tuple) + @assert length(ref) == 2 "tuple ref length is $(length(ref)) != 2" + if ref[1] == :indicators + return length(ref[2]) + else + @assert false "unknwon tuple ref $(ref)" + end +end + function update!(jac::FMU2Jacobian, x) - if size(jac.mtx) != (length(jac.f_refs), length(jac.x_refs)) - jac.mtx = similar(jac.mtx, length(jac.f_refs), length(jac.x_refs)) - jac.jvp = similar(jac.jvp, length(jac.f_refs)) - jac.vjp = similar(jac.vjp, length(jac.x_refs)) + if size(jac.mtx) != (ref_length(jac.f_refs), ref_length(jac.x_refs)) + jac.mtx = similar(jac.mtx, ref_length(jac.f_refs), ref_length(jac.x_refs)) + jac.jvp = similar(jac.jvp, ref_length(jac.f_refs)) + jac.vjp = similar(jac.vjp, ref_length(jac.x_refs)) jac.valid = false end @@ -1278,12 +1299,12 @@ end function update!(gra::FMU2Gradient, x) - if length(gra.vec) != length(jac.f_refs) - gra.vec = similar(gra.vec, length(jac.f_refs)) - gra.gvp = similar(gra.gvp, length(jac.f_refs)) - gra.vgp = similar(gra.vgp, length(jac.x_refs)) + if length(gra.vec) != ref_length(gra.f_refs) + gra.vec = similar(gra.vec, ref_length(jac.f_refs)) + gra.gvp = similar(gra.gvp, ref_length(jac.f_refs)) + gra.vgp = similar(gra.vgp, ref_length(jac.x_refs)) - jac.valid = false + gra.valid = false end if !gra.valid diff --git a/test/FMI2/jacobians_gradients.jl b/test/FMI2/jacobians_gradients.jl index f3ee3da..089ed41 100644 --- a/test/FMI2/jacobians_gradients.jl +++ b/test/FMI2/jacobians_gradients.jl @@ -31,7 +31,7 @@ p = fmi2GetReal(c, p_refs) e = fmi2GetEventIndicators(c) t = 0.0 -function reset!(c::FMIImport.FMU2Component) +reset! = function(c::FMIImport.FMU2Component) c.solution.evals_∂ẋ_∂x = 0 c.solution.evals_∂ẋ_∂u = 0 c.solution.evals_∂ẋ_∂p = 0 diff --git a/test/FMI2/solution.jl b/test/FMI2/solution.jl index e3b08e7..089ed41 100644 --- a/test/FMI2/solution.jl +++ b/test/FMI2/solution.jl @@ -10,9 +10,6 @@ import FMISensitivity.ReverseDiff CHECK_ZYGOTE = false -function euler_integrate(tStart, tStop) -end - # load demo FMU fmu = fmi2Load("SpringPendulumExtForce1D", EXPORTINGTOOL, EXPORTINGVERSION; type=:ME)