Skip to content

Commit

Permalink
optimized tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ThummeTo committed Nov 8, 2023
1 parent ae31687 commit 782047a
Show file tree
Hide file tree
Showing 10 changed files with 29 additions and 23 deletions.
4 changes: 2 additions & 2 deletions test/batching.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import Random
Random.seed!(1234);

t_start = 0.0
t_step = 0.1
t_stop = 150.0
t_step = 0.01
t_stop = 50.0
tData = t_start:t_step:t_stop

# generate training data
Expand Down
6 changes: 3 additions & 3 deletions test/fmu_params.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import Random
Random.seed!(1234);

t_start = 0.0
t_step = 0.1
t_stop = 15.0
t_step = 0.01
t_stop = 5.0
tData = t_start:t_step:t_stop

# generate training data
Expand Down Expand Up @@ -66,7 +66,7 @@ numStates = length(fmu.modelDescription.stateValueReferences)

# the "Chain" for training
net = Chain(FMUParameterRegistrator(fmu, p_refs, p),
x -> fmu(x=x, dx=:all)) # , fmuLayer(p))
x -> fmu(x=x, dx_refs=:all)) # , fmuLayer(p))

optim = Adam(ETA)
solver = Tsit5()
Expand Down
4 changes: 2 additions & 2 deletions test/hybrid_CS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import Random
Random.seed!(1234);

t_start = 0.0
t_step = 0.1
t_stop = 15.0
t_step = 0.01
t_stop = 5.0
tData = t_start:t_step:t_stop

# generate training data
Expand Down
4 changes: 2 additions & 2 deletions test/hybrid_ME.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import Random
Random.seed!(1234);

t_start = 0.0
t_step = 0.1
t_stop = 15.0
t_step = 0.01
t_stop = 5.0
tData = t_start:t_step:t_stop

# generate training data
Expand Down
12 changes: 9 additions & 3 deletions test/hybrid_ME_dis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import Random
Random.seed!(5678);

t_start = 0.0
t_step = 0.1
t_stop = 15.0
t_step = 0.01
t_stop = 5.0
tData = t_start:t_step:t_stop

# generate training data
Expand Down Expand Up @@ -163,6 +163,7 @@ for i in 1:length(nets)

# train it ...
p_net = Flux.params(problem)
@test length(p_net) == 1

@test problem !== nothing

Expand All @@ -175,7 +176,12 @@ for i in 1:length(nets)

iterCB = 0
lastLoss = losssum(p_net[1])
@info "[ $(iterCB)] Loss: $lastLoss"
@info "Start-Loss for net #$i: $lastLoss"

if length(p_net[1]) == 0
@info "The following warning is not an issue, because training on zero parameters must throw a warning:"
end

FMIFlux.train!(losssum, p_net, Iterators.repeated((), NUMSTEPS), optim; cb=()->callb(p_net), gradient=GRADIENT)

# check results
Expand Down
4 changes: 2 additions & 2 deletions test/multi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import Random
Random.seed!(1234);

t_start = 0.0
t_step = 0.1
t_stop = 15.0
t_step = 0.01
t_stop = 5.0
tData = t_start:t_step:t_stop

# generate training data
Expand Down
4 changes: 2 additions & 2 deletions test/multi_threading.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import Random
Random.seed!(5678);

t_start = 0.0
t_step = 0.1
t_stop = 15.0
t_step = 0.01
t_stop = 5.0
tData = t_start:t_step:t_stop

# generate training data
Expand Down
4 changes: 2 additions & 2 deletions test/optim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import Random
Random.seed!(1234);

t_start = 0.0
t_step = 0.1
t_stop = 15.0
t_step = 0.01
t_stop = 5.0
tData = t_start:t_step:t_stop

# generate training data
Expand Down
6 changes: 3 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ global x0 = [2.0, 0.0]

# training data for pendulum experiment
function syntTrainingData(tData)
posData = cos.(tData)* 1.0
velData = sin.(tData)*-1.0
accData = cos.(tData)*-1.0
posData = cos.(tData*3.0)* 2.0
velData = sin.(tData*3.0)*-6.0
accData = cos.(tData*3.0)*-18.0
return posData, velData, accData
end

Expand Down
4 changes: 2 additions & 2 deletions test/train_modes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import Random
Random.seed!(5678);

t_start = 0.0
t_step = 0.1
t_stop = 15.0
t_step = 0.01
t_stop = 5.0
tData = t_start:t_step:t_stop

# generate training data
Expand Down

0 comments on commit 782047a

Please sign in to comment.