Skip to content

Commit

Permalink
Performance improvements (#97)
Browse files Browse the repository at this point in the history
* PrepareValueReference multiple dispatch

* Reduce allocations

* Unsense-related allocations

* Minor change

* Optimise Jacobian invalidation (requires update in FMICore)

* Type stability

* Keyword and non-keyword definitions

* Avoid allocations

* Avoid c.x pointing to external array
  • Loading branch information
CasBex authored Sep 5, 2023
1 parent 7f75ca1 commit 6772986
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 97 deletions.
26 changes: 17 additions & 9 deletions src/FMI2/c.jl
Original file line number Diff line number Diff line change
Expand Up @@ -597,18 +597,26 @@ function fmi2SetReal(c::FMU2Component,
c.compAddr, vr, nvr, value)
checkStatus(c, status)

if track
if status == fmi2StatusOK
for j in (c.A, c.B, c.C, c.D, c.E, c.F)
if any(collect(v in j.∂f_refs for v in vr))
FMICore.invalidate!(j)
end
end
end
if track && status == fmi2StatusOK
track_jac(vr, c.A)
track_jac(vr, c.B)
track_jac(vr, c.C)
track_jac(vr, c.D)
track_jac(vr, c.E)
track_jac(vr, c.F)
end

return status
end
function track_jac(vr::A, M::FMICore.FMUJacobian{V,R}) where {A<:AbstractArray{fmi2ValueReference},V,R}
for v in vr
if v in M.∂f_refsset
FMICore.invalidate!(M)
return
end
end
return
end

"""
fmi2GetInteger!(c::FMU2Component, vr::AbstractArray{fmi2ValueReference}, nvr::Csize_t, value::AbstractArray{fmi2Integer})
Expand Down Expand Up @@ -1618,7 +1626,7 @@ function fmi2SetContinuousStates(c::FMU2Component,

if track
if status == fmi2StatusOK
c.x = copy(x)
isnothing(c.x) ? (c.x = copy(x);) : copyto!(c.x, x)

FMICore.invalidate!(c.A)
FMICore.invalidate!(c.C)
Expand Down
55 changes: 21 additions & 34 deletions src/FMI2/convert.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,43 +6,30 @@
using ChainRulesCore: ignore_derivatives
import SciMLSensitivity.ForwardDiff

# ToDo: Replace by multiple dispatch version ...
# Receives one or an array of value references in an arbitrary format (see fmi2ValueReferenceFormat) and converts it into an Array{fmi2ValueReference} (if not already).
function prepareValueReference(md::fmi2ModelDescription, vr::fmi2ValueReferenceFormat)
tvr = typeof(vr)
if isa(vr, AbstractArray{fmi2ValueReference,1})
return vr
elseif tvr == fmi2ValueReference
return [vr]
elseif tvr == String
return [fmi2StringToValueReference(md, vr)]
elseif isa(vr, AbstractArray{String,1})
return fmi2StringToValueReference(md, vr)
elseif tvr == Int64
return [fmi2ValueReference(vr)]
elseif isa(vr, AbstractArray{Int64,1})
return fmi2ValueReference.(vr)
elseif tvr == Nothing
prepareValueReference(md::fmi2ModelDescription, vr::AbstractVector{fmi2ValueReference}) = vr
prepareValueReference(md::fmi2ModelDescription, vr::fmi2ValueReference) = [vr]
prepareValueReference(md::fmi2ModelDescription, vr::String) = [fmi2StringToValueReference(md, vr)]
prepareValueReference(md::fmi2ModelDescription, vr::AbstractVector{String}) = fmi2StringToValueReference(md, vr)
prepareValueReference(md::fmi2ModelDescription, vr::AbstractVector{<:Integer}) = fmi2ValueReference.(vr)
prepareValueReference(md::fmi2ModelDescription, vr::Integer) = [fmi2ValueReference(vr)]
prepareValueReference(md::fmi2ModelDescription, vr::Nothing) = fmi2ValueReference[]
function prepareValueReference(md::fmi2ModelDescription, vr::Symbol)
if vr == :states
return md.stateValueReferences
elseif vr == :derivatives
return md.derivativeValueReferences
elseif vr == :inputs
return md.inputValueReferences
elseif vr == :outputs
return md.outputValueReferences
elseif vr == :all
return md.valueReferences
elseif vr == :none
return Array{fmi2ValueReference,1}()
elseif tvr == Symbol
if vr == :states
return md.stateValueReferences
elseif vr == :derivatives
return md.derivativeValueReferences
elseif vr == :inputs
return md.inputValueReferences
elseif vr == :outputs
return md.outputValueReferences
elseif vr == :all
return md.valueReferences
elseif vr == :none
return Array{fmi2ValueReference,1}()
else
@assert false "Unknwon symbol `$vr`, can't convert to value reference."
end
else
@assert false "Unknwon symbol `$vr`, can't convert to value reference."
end

@assert false "prepareValueReference(...): Unknown value reference structure `$tvr`."
end
function prepareValueReference(fmu::FMU2, vr::fmi2ValueReferenceFormat)
prepareValueReference(fmu.modelDescription, vr)
Expand Down
27 changes: 11 additions & 16 deletions src/FMI2/int.jl
Original file line number Diff line number Diff line change
Expand Up @@ -227,15 +227,12 @@ More detailed:
- FMISpec2.0.2[p.18]: 2.1.3 Status Returned by Functions
See also [`fmi2SetReal`](@ref).
"""
function fmi2SetReal(c::FMU2Component, vr::fmi2ValueReferenceFormat, values::Union{AbstractArray{<:Real}, <:Real}; kwargs...)

vr = prepareValueReference(c, vr)
values = prepareValue(values)
function fmi2SetReal(c::FMU2Component, vr::fmi2ValueReferenceFormat, values::AbstractVector{fmi2Real}; kwargs...)
@assert length(vr) == length(values) "fmi2SetReal(...): `vr` ($(length(vr))) and `values` ($(length(values))) need to be the same length."

nvr = Csize_t(length(vr))
fmi2SetReal(c, vr, nvr, Array{fmi2Real}(values); kwargs...)
fmi2SetReal(c, prepareValueReference(c, vr), nvr, prepareValue(values); kwargs...)
end
fmi2SetReal(c::FMU2Component, vr::fmi2ValueReferenceFormat, values::Real; kwargs...) = fmi2SetReal(c, prepareValueReference(c, vr), prepareValue(values); kwargs...)

"""
fmi2GetInteger(c::FMU2Component, vr::fmi2ValueReferenceFormat)
Expand Down Expand Up @@ -1118,14 +1115,15 @@ More detailed:
- FMISpec2.0.2[p.83]: 3.2.1 Providing Independent Variables and Re-initialization of Caching
See also [`fmi2SetContinuousStates`](@ref).
"""
function fmi2SetContinuousStates(c::FMU2Component, x::Union{AbstractArray{Float32}, AbstractArray{Float64}}; kwargs...)
function fmi2SetContinuousStates(c::FMU2Component, x::AbstractArray{fmi2Real}; kwargs...)
nx = Csize_t(length(x))
status = fmi2SetContinuousStates(c, Array{fmi2Real}(x), nx; kwargs...)
status = fmi2SetContinuousStates(c, x, nx; kwargs...)
if status == fmi2StatusOK
c.x = x
isnothing(c.x) ? (c.x = copy(x);) : copyto!(c.x, x)
end
return status
end
fmi2SetContinuousStates(c::FMU2Component, x::AbstractArray{Float32}; kwargs...) = fmi2SetContinuousStates(c, Array{fmi2Real}(x); kwargs...)

"""
fmi2NewDiscreteStates(c::FMU2Component)
Expand Down Expand Up @@ -1187,15 +1185,12 @@ See also [`fmi2CompletedIntegratorStep`](@ref).
"""
function fmi2CompletedIntegratorStep(c::FMU2Component,
noSetFMUStatePriorToCurrentPoint::fmi2Boolean)
enterEventMode = zeros(fmi2Boolean, 1)
terminateSimulation = zeros(fmi2Boolean, 1)

status = fmi2CompletedIntegratorStep!(c,
noSetFMUStatePriorToCurrentPoint,
pointer(enterEventMode),
pointer(terminateSimulation))
c.ptr_stepEnterEventMode,
c.ptr_terminateSimulation)

return (status, enterEventMode[1], terminateSimulation[1])
return (status, c.stepEnterEventMode, c.terminateSimulation)
end

"""
Expand Down Expand Up @@ -1362,4 +1357,4 @@ function fmi2GetStatus(c::FMU2Component, s::fmi2StatusKind)
end

status, value[1]
end
end
87 changes: 50 additions & 37 deletions src/FMI2/sens.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,15 @@ Not all options are available for any FMU type, e.g. setting state is not suppor
- `y::Union{AbstractVector{<:Real}, Nothing}`: The system output `y` (if requested, otherwise `nothing`).
- `dx::Union{AbstractVector{<:Real}, Nothing}`: The system state-derivaitve (if ME-FMU, otherwise `nothing`).
"""
function (fmu::FMU2)(;dx::AbstractVector{<:Real}=Vector{fmi2Real}(),
y::AbstractVector{<:Real}=Vector{fmi2Real}(),
y_refs::AbstractVector{<:fmi2ValueReference}=Vector{fmi2ValueReference}(),
x::AbstractVector{<:Real}=Vector{fmi2Real}(),
u::AbstractVector{<:Real}=Vector{fmi2Real}(),
u_refs::AbstractVector{<:fmi2ValueReference}=Vector{fmi2ValueReference}(),
p::AbstractVector{<:Real}=fmu.optim_p,
p_refs::AbstractVector{<:fmi2ValueReference}=fmu.optim_p_refs,
t::Real=-1.0)

c = nothing
function (fmu::FMU2)(dx::AbstractVector{<:Real},
y::AbstractVector{<:Real},
y_refs::AbstractVector{<:fmi2ValueReference},
x::AbstractVector{<:Real},
u::AbstractVector{<:Real},
u_refs::AbstractVector{<:fmi2ValueReference},
p::AbstractVector{<:Real},
p_refs::AbstractVector{<:fmi2ValueReference},
t::Real)

if hasCurrentComponent(fmu)
c = getCurrentComponent(fmu)
Expand All @@ -97,6 +95,18 @@ function (fmu::FMU2)(;dx::AbstractVector{<:Real}=Vector{fmi2Real}(),
c(;dx=dx, y=y, y_refs=y_refs, x=x, u=u, u_refs=u_refs, p=p, p_refs=p_refs, t=t)
end

function (fmu::FMU2)(;dx::AbstractVector{<:Real}=Vector{fmi2Real}(),
y::AbstractVector{<:Real}=Vector{fmi2Real}(),
y_refs::AbstractVector{<:fmi2ValueReference}=Vector{fmi2ValueReference}(),
x::AbstractVector{<:Real}=Vector{fmi2Real}(),
u::AbstractVector{<:Real}=Vector{fmi2Real}(),
u_refs::AbstractVector{<:fmi2ValueReference}=Vector{fmi2ValueReference}(),
p::AbstractVector{<:Real}=fmu.optim_p,
p_refs::AbstractVector{<:fmi2ValueReference}=fmu.optim_p_refs,
t::Real=-1.0)
(fmu)(dx, y, y_refs, x, u, u_refs, p, p_refs, t)
end

"""
(c::FMU2Component)(;dx::Union{AbstractVector{<:Real}, Nothing}=nothing,
Expand Down Expand Up @@ -125,15 +135,16 @@ Not all options are available for any FMU type, e.g. setting state is not suppor
- `y::Union{AbstractVector{<:Real}, Nothing}`: The system output `y` (if requested, otherwise `nothing`).
- `dx::Union{AbstractVector{<:Real}, Nothing}`: The system state-derivaitve (if ME-FMU, otherwise `nothing`).
"""
function (c::FMU2Component)(;dx::AbstractVector{<:Real}=Vector{fmi2Real}(),
y::AbstractVector{<:Real}=Vector{fmi2Real}(),
y_refs::AbstractVector{<:fmi2ValueReference}=Vector{fmi2ValueReference}(),
x::AbstractVector{<:Real}=Vector{fmi2Real}(),
u::AbstractVector{<:Real}=Vector{fmi2Real}(),
u_refs::AbstractVector{<:fmi2ValueReference}=Vector{fmi2ValueReference}(),
p::AbstractVector{<:Real}=c.fmu.optim_p,
p_refs::AbstractVector{<:fmi2ValueReference}=c.fmu.optim_p_refs,
t::Real=c.next_t)
function (c::FMU2Component)(dx::AbstractVector{<:Real},
y::AbstractVector{<:Real},
y_refs::AbstractVector{<:fmi2ValueReference},
x::AbstractVector{<:Real},
u::AbstractVector{<:Real},
u_refs::AbstractVector{<:fmi2ValueReference},
p::AbstractVector{<:Real},
p_refs::AbstractVector{<:fmi2ValueReference},
t::Real)


if length(y_refs) > 0
if length(y) <= 0
Expand Down Expand Up @@ -172,6 +183,18 @@ function (c::FMU2Component)(;dx::AbstractVector{<:Real}=Vector{fmi2Real}(),
return eval!(cRef, dx, y, y_refs, x, u, u_refs, p, p_refs, t)
end

function (c::FMU2Component)(;dx::AbstractVector{<:Real}=Vector{fmi2Real}(),
y::AbstractVector{<:Real}=Vector{fmi2Real}(),
y_refs::AbstractVector{<:fmi2ValueReference}=Vector{fmi2ValueReference}(),
x::AbstractVector{<:Real}=Vector{fmi2Real}(),
u::AbstractVector{<:Real}=Vector{fmi2Real}(),
u_refs::AbstractVector{<:fmi2ValueReference}=Vector{fmi2ValueReference}(),
p::AbstractVector{<:Real}=c.fmu.optim_p,
p_refs::AbstractVector{<:fmi2ValueReference}=c.fmu.optim_p_refs,
t::Real=c.next_t)
(c)(dx, y, y_refs, x, u, u_refs, p, p_refs, t)
end

function eval!(cRef::UInt64,
dx::AbstractVector{<:Real},
y::AbstractVector{<:Real},
Expand All @@ -190,41 +213,31 @@ function eval!(cRef::UInt64,
@assert (!isdual(t) && !istracked(t)) "eval!(...): Wrong dispatched: `t` is ForwardDiff.Dual/ReverseDiff.TrackedReal, please open an issue with MWE."
@assert (!isdual(p) && !istracked(p)) "eval!(...): Wrong dispatched: `p` is ForwardDiff.Dual/ReverseDiff.TrackedReal, please open an issue with MWE."

x = unsense(x)
t = unsense(t)
u = unsense(u)
# p = unsense(p) # no need to unsense `p` because it is not beeing used further

# set state
if length(x) > 0 && !c.fmu.isZeroState
fmi2SetContinuousStates(c, x)
fmi2SetContinuousStates(c, unsense(x))
end

# set time
if t >= 0.0
fmi2SetTime(c, t)
fmi2SetTime(c, unsense(t))
end

# set input
if length(u) > 0
fmi2SetReal(c, u_refs, u)
fmi2SetReal(c, u_refs, unsense(u))
end

# get derivative
if length(dx) > 0
if isdual(dx)

dx_tmp = nothing

if c.fmu.isZeroState
dx_tmp = [1.0]
else
dx_tmp = collect(ForwardDiff.value(e) for e in dx)
fmi2GetDerivatives!(c, dx_tmp)
dx_tmp = ForwardDiff.value.(dx)
c.fmu.isZeroState || fmi2GetDerivatives!(c, dx_tmp)
for i = 1:length(dx)
dx[i].value = dx_tmp[i]
end

T, V, N = fd_eltypes(dx)
dx[:] = collect(ForwardDiff.Dual{T, V, N}(dx_tmp[i], ForwardDiff.partials(dx[i]) ) for i in 1:length(dx))

elseif istracked(dx)

Expand Down
5 changes: 4 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,10 @@ end

# makes Reals from ForwardDiff/ReverseDiff.TrackedXXX scalar/vector
function unsense(e::AbstractArray)
return collect(unsense(c) for c in e)
return unsense.(e)
end
function unsense(e::AbstractArray{fmi2Real})
return e
end
function unsense(e::Tuple)
return (collect(unsense(c) for c in e)...,)
Expand Down

0 comments on commit 6772986

Please sign in to comment.