Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
ThummeTo committed Feb 2, 2024
1 parent 9c79a1d commit 7fb7fd3
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 13 deletions.
39 changes: 30 additions & 9 deletions src/FMI2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1256,12 +1256,33 @@ function color!(sens::FMU2Sensitivities)
return nothing
end

function ref_length(ref::AbstractArray)
return length(ref)
end

function ref_length(ref::Symbol)
if ref == :time
return 1
else
@assert false "unknwon ref symbol: $(ref)"
end
end

function ref_length(ref::Tuple)
@assert length(ref) == 2 "tuple ref length is $(length(ref)) != 2"
if ref[1] == :indicators
return length(ref[2])
else
@assert false "unknwon tuple ref $(ref)"
end
end

function update!(jac::FMU2Jacobian, x)

if size(jac.mtx) != (length(jac.f_refs), length(jac.x_refs))
jac.mtx = similar(jac.mtx, length(jac.f_refs), length(jac.x_refs))
jac.jvp = similar(jac.jvp, length(jac.f_refs))
jac.vjp = similar(jac.vjp, length(jac.x_refs))
if size(jac.mtx) != (ref_length(jac.f_refs), ref_length(jac.x_refs))
jac.mtx = similar(jac.mtx, ref_length(jac.f_refs), ref_length(jac.x_refs))
jac.jvp = similar(jac.jvp, ref_length(jac.f_refs))
jac.vjp = similar(jac.vjp, ref_length(jac.x_refs))

jac.valid = false
end
Expand All @@ -1278,12 +1299,12 @@ end

function update!(gra::FMU2Gradient, x)

if length(gra.vec) != length(jac.f_refs)
gra.vec = similar(gra.vec, length(jac.f_refs))
gra.gvp = similar(gra.gvp, length(jac.f_refs))
gra.vgp = similar(gra.vgp, length(jac.x_refs))
if length(gra.vec) != ref_length(gra.f_refs)
gra.vec = similar(gra.vec, ref_length(jac.f_refs))
gra.gvp = similar(gra.gvp, ref_length(jac.f_refs))
gra.vgp = similar(gra.vgp, ref_length(jac.x_refs))

jac.valid = false
gra.valid = false
end

if !gra.valid
Expand Down
2 changes: 1 addition & 1 deletion test/FMI2/jacobians_gradients.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ p = fmi2GetReal(c, p_refs)
e = fmi2GetEventIndicators(c)
t = 0.0

function reset!(c::FMIImport.FMU2Component)
reset! = function(c::FMIImport.FMU2Component)
c.solution.evals_∂ẋ_∂x = 0
c.solution.evals_∂ẋ_∂u = 0
c.solution.evals_∂ẋ_∂p = 0
Expand Down
3 changes: 0 additions & 3 deletions test/FMI2/solution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@ import FMISensitivity.ReverseDiff

CHECK_ZYGOTE = false

function euler_integrate(tStart, tStop)
end

# load demo FMU
fmu = fmi2Load("SpringPendulumExtForce1D", EXPORTINGTOOL, EXPORTINGVERSION; type=:ME)

Expand Down

0 comments on commit 7fb7fd3

Please sign in to comment.