Skip to content

Commit

Permalink
minor adjustment
Browse files Browse the repository at this point in the history
  • Loading branch information
ThummeTo committed Sep 4, 2024
1 parent 32198ac commit 00197e8
Showing 1 changed file with 26 additions and 24 deletions.
50 changes: 26 additions & 24 deletions src/sense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,9 @@ function ChainRulesCore.rrule(::typeof(FMIBase.eval!),
# because they are evaluated at different points in time during ODE solving.
if length(c.solution.snapshots) > 0
sn = getSnapshot(c.solution, t)
apply!(c, sn)
if !isnothing(sn) # sometimes it is -Inf (whyever...)
apply!(c, sn)
end
end

Ω = FMIBase.eval!(cRef, dx, dx_refs, y, y_refs, x, u, u_refs, p, p_refs, ec, ec_idcs, t)
Expand Down Expand Up @@ -1003,7 +1005,7 @@ abstract type FMUSensitivities end
mutable struct FMUJacobian{C, T, F} <: FMUSensitivities
valid::Bool
colored::Bool
component::C
instance::C

mtx::Matrix{T}
jvp::Vector{T}
Expand All @@ -1021,7 +1023,7 @@ mutable struct FMUJacobian{C, T, F} <: FMUSensitivities
validations::Int
colorings::Int

function FMUJacobian{T}(component::C, f_refs::Union{Vector{UInt32}, Tuple{Symbol, Vector{UInt32}}}, x_refs::Union{Vector{UInt32}, Symbol}) where {C, T}
function FMUJacobian{T}(instance::C, f_refs::Union{Vector{UInt32}, Tuple{Symbol, Vector{UInt32}}}, x_refs::Union{Vector{UInt32}, Symbol}) where {C, T}

@assert !isa(f_refs, Tuple) || f_refs[1] == :indicators "`f_refs` is Tuple, it must be `:indicators`"
@assert !isa(x_refs, Symbol) || x_refs == :time "`x_refs` is Symbol, it must be `:time`"
Expand All @@ -1046,7 +1048,7 @@ mutable struct FMUJacobian{C, T, F} <: FMUSensitivities

inst = new{C, T, F}()
inst.f = f
inst.component = component
inst.instance = instance
inst.f_refs = f_refs
inst.f_refs_set = f_refs_set
inst.x_refs = x_refs
Expand All @@ -1068,7 +1070,7 @@ end
mutable struct FMUGradient{C, T, F} <: FMUSensitivities
valid::Bool
colored::Bool
component::C
instance::C

vec::Vector{T}
gvp::Vector{T}
Expand All @@ -1086,7 +1088,7 @@ mutable struct FMUGradient{C, T, F} <: FMUSensitivities
validations::Int
colorings::Int

function FMUGradient{T}(component::C, f_refs::Union{Vector{UInt32}, Tuple{Symbol, Vector{UInt32}}}, x_refs::Union{UInt32, Symbol}) where {C, T}
function FMUGradient{T}(instance::C, f_refs::Union{Vector{UInt32}, Tuple{Symbol, Vector{UInt32}}}, x_refs::Union{UInt32, Symbol}) where {C, T}

@assert !isa(f_refs, Tuple) || f_refs[1] == :indicators "`f_refs` is Tuple, it must be `:indicators`"
@assert !isa(x_refs, Symbol) || x_refs == :time "`x_refs` is Symbol, it must be `:time`"
Expand All @@ -1109,7 +1111,7 @@ mutable struct FMUGradient{C, T, F} <: FMUSensitivities

inst = new{C, T, F}()
inst.f = f
inst.component = component
inst.instance = instance
inst.f_refs = f_refs
inst.f_refs_set = f_refs_set
inst.x_refs = x_refs
Expand All @@ -1129,26 +1131,26 @@ mutable struct FMUGradient{C, T, F} <: FMUSensitivities
end

function f_∂v_∂v(jac::FMUJacobian, dx, x)
setReal(jac.component, jac.x_refs, x; track=false)
getReal!(jac.component, jac.f_refs, dx)
setReal(jac.instance, jac.x_refs, x; track=false)
getReal!(jac.instance, jac.f_refs, dx)
return dx
end

function f_∂e_∂v(jac::FMUJacobian, dx, x)
setReal(jac.component, jac.x_refs, x; track=false)
getEventIndicators!(jac.component, dx, jac.f_refs[2])
setReal(jac.instance, jac.x_refs, x; track=false)
getEventIndicators!(jac.instance, dx, jac.f_refs[2])
return dx
end

function f_∂e_∂t(jac::FMUGradient, dx, x)
setTime(jac.component, x; track=false)
getEventIndicators!(jac.component, dx, jac.f_refs[2])
setTime(jac.instance, x; track=false)
getEventIndicators!(jac.instance, dx, jac.f_refs[2])
return dx
end

function f_∂v_∂t(jac::FMUGradient, dx, x)
setTime(jac.component, x; track=false)
getReal!(jac.component, jac.f_refs, dx)
setTime(jac.instance, x; track=false)
getReal!(jac.instance, jac.f_refs, dx)
return dx
end

Expand Down Expand Up @@ -1191,25 +1193,25 @@ function validate!(jac::FMUJacobian, x::AbstractVector)
rows = length(jac.f_refs)
cols = length(jac.x_refs)

if jac.component.fmu.executionConfig.sensitivity_strategy == :FMIDirectionalDerivative && providesDirectionalDerivatives(jac.component.fmu) && !isa(jac.f_refs, Tuple) && !isa(jac.x_refs, Symbol)
if jac.instance.fmu.executionConfig.sensitivity_strategy == :FMIDirectionalDerivative && providesDirectionalDerivatives(jac.instance.fmu) && !isa(jac.f_refs, Tuple) && !isa(jac.x_refs, Symbol)
# ToDo: use directional derivatives with sparsitiy information!
# ToDo: Optimize allocation (onehot)
# [Note] Jacobian is sampled column by column
for i in 1:cols
getDirectionalDerivative!(jac.component, jac.f_refs, jac.x_refs, onehot(jac.component, cols, i), view(jac.mtx, 1:rows, i))
getDirectionalDerivative!(jac.instance, jac.f_refs, jac.x_refs, onehot(jac.instance, cols, i), view(jac.mtx, 1:rows, i))
end
elseif jac.component.fmu.executionConfig.sensitivity_strategy == :FMIAdjointDerivative && providesAdjointDerivatives(jac.component.fmu) && !isa(jac.f_refs, Tuple) && !isa(jac.x_refs, Symbol)
elseif jac.instance.fmu.executionConfig.sensitivity_strategy == :FMIAdjointDerivative && providesAdjointDerivatives(jac.instance.fmu) && !isa(jac.f_refs, Tuple) && !isa(jac.x_refs, Symbol)
# ToDo: use directional derivatives with sparsitiy information!
# ToDo: Optimize allocation (onehot)
# [Note] Jacobian is sampled row by row
for i in 1:rows
getAdjointDerivative!(jac.component, jac.f_refs, jac.x_refs, onehot(jac.component, rows, i), view(jac.mtx, 1:cols, i))
getAdjointDerivative!(jac.instance, jac.f_refs, jac.x_refs, onehot(jac.instance, rows, i), view(jac.mtx, 1:cols, i))
end
else #if jac.component.fmu.executionConfig.sensitivity_strategy == :FiniteDiff
else #if jac.instance.fmu.executionConfig.sensitivity_strategy == :FiniteDiff
# cache = FiniteDiff.JacobianCache(x)
FiniteDiff.finite_difference_jacobian!(jac.mtx, (_x, _dx) -> (jac.f(jac, _x, _dx)), x) # , cache)
# else
# @assert false "Unknown sensitivity strategy `$(jac.component.fmu.executionConfig.sensitivity_strategy)`."
# @assert false "Unknown sensitivity strategy `$(jac.instance.fmu.executionConfig.sensitivity_strategy)`."
end

jac.validations += 1
Expand All @@ -1219,10 +1221,10 @@ end

function validate!(grad::FMUGradient, x::Real)

if grad.component.fmu.executionConfig.sensitivity_strategy == :FMIDirectionalDerivative && providesDirectionalDerivatives(grad.component.fmu) && !isa(grad.f_refs, Tuple) && !isa(grad.x_refs, Symbol)
if grad.instance.fmu.executionConfig.sensitivity_strategy == :FMIDirectionalDerivative && providesDirectionalDerivatives(grad.instance.fmu) && !isa(grad.f_refs, Tuple) && !isa(grad.x_refs, Symbol)
# ToDo: use directional derivatives with sparsitiy information!
getDirectionalDerivative!(grad.component, grad.f_refs, grad.x_refs, ones(length(jac.f_refs)), grad.vec)
else #if grad.component.fmu.executionConfig.sensitivity_strategy == :FiniteDiff
getDirectionalDerivative!(grad.instance, grad.f_refs, grad.x_refs, ones(length(jac.f_refs)), grad.vec)
else #if grad.instance.fmu.executionConfig.sensitivity_strategy == :FiniteDiff
# cache = FiniteDiff.GradientCache(x)
FiniteDiff.finite_difference_gradient!(grad.vec, (_x, _dx) -> (grad.f(grad, _x, _dx)), x) # , cache)
end
Expand Down

0 comments on commit 00197e8

Please sign in to comment.