Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SAE script multiple parameters #161

Merged
merged 14 commits into from
Jul 3, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Removed doctests for now. This fails for v1.11 because of a changed rng.
benedict-96 committed Jul 2, 2024
commit 672a41cdcb8d6ddbd15e5d5a9235276a2f69e1dc
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using SafeTestsets, Test, GeometricMachineLearning
using Documenter: doctest

@testset "Doc tests " begin doctest(GeometricMachineLearning; manual = false) end
# @testset "Doc tests " begin doctest(GeometricMachineLearning; manual = false) end
# reduced order modeling tests
@safetestset "PSD tests " begin include("psd_architecture_tests.jl") end
@safetestset "SymplecticAutoencoder tests " begin include("symplectic_autoencoder_tests.jl") end

Unchanged files with check annotations Beta

if autoencoder == false
DataLoader{T, typeof(data), Nothing, :TimeSeries}(data, nothing, input_dim, input_time_steps, n_params, nothing, nothing)
elseif autoencoder == true
DataLoader{T, typeof(data), Nothing, :RegularData}(data, nothing, input_dim, input_time_steps, n_params, nothing, nothing)

Check warning on line 79 in src/data_loader/data_loader.jl

Codecov / codecov/patch

src/data_loader/data_loader.jl#L78-L79

Added lines #L78 - L79 were not covered by tests
end
end
if autoencoder == false
DataLoader{T, typeof(data), Nothing, :TimeSeries}(data, nothing, dim2 * 2, time_steps, n_params, nothing, nothing)
elseif autoencoder == true
DataLoader{T, typeof(data), Nothing, :RegularData}(data, nothing, dim2 * 2, time_steps, n_params, nothing, nothing)

Check warning on line 185 in src/data_loader/data_loader.jl

Codecov / codecov/patch

src/data_loader/data_loader.jl#L184-L185

Added lines #L184 - L185 were not covered by tests
end
end
Internally this stores the data as a tensor where the third axis has length equal to the number of solutions in the ensemble.
"""
function DataLoader(ensemble_solution::EnsembleSolution{T, T1, Vector{ST}}; autoencoder = false) where {T, T1, DT <: DataSeries{T}, ST <: GeometricSolution{T, T1, NamedTuple{(:q, :p), Tuple{DT, DT}}}}

Check warning on line 258 in src/data_loader/data_loader.jl

Codecov / codecov/patch

src/data_loader/data_loader.jl#L258

Added line #L258 was not covered by tests
sys_dim, input_time_steps, n_params = length(ensemble_solution.s[1].q[0]), length(ensemble_solution.t), length(ensemble_solution.s)
data = (q = zeros(T, sys_dim, input_time_steps, n_params), p = zeros(T, sys_dim, input_time_steps, n_params))
end
end
DataLoader(data; autoencoder = autoencoder)

Check warning on line 270 in src/data_loader/data_loader.jl

Codecov / codecov/patch

src/data_loader/data_loader.jl#L270

Added line #L270 was not covered by tests
end
function map_to_new_backend(input::AbstractArray{T}, backend::KernelAbstractions.Backend) where T
There is an optional keyword argument `autoencoder`. See the docstring for [`DataLoader(data::AbstractArray{<:Number, 3})`](@ref).
"""
function DataLoader(dl::DataLoader{T1, <:QPTOAT, Nothing}, backend::KernelAbstractions.Backend=KernelAbstractions.get_backend(dl), T::DataType=T1; autoencoder = false) where T1
input =

Check warning on line 314 in src/data_loader/data_loader.jl

Codecov / codecov/patch

src/data_loader/data_loader.jl#L313-L314

Added lines #L313 - L314 were not covered by tests
if T == T1
dl.input
else
map_to_type(dl.input, T)
end
new_input =

Check warning on line 321 in src/data_loader/data_loader.jl

Codecov / codecov/patch

src/data_loader/data_loader.jl#L321

Added line #L321 was not covered by tests
if backend == KernelAbstractions.get_backend(dl)
input

Check warning on line 323 in src/data_loader/data_loader.jl

Codecov / codecov/patch

src/data_loader/data_loader.jl#L323

Added line #L323 was not covered by tests
else
map_to_new_backend(input, backend)

Check warning on line 325 in src/data_loader/data_loader.jl

Codecov / codecov/patch

src/data_loader/data_loader.jl#L325

Added line #L325 was not covered by tests
end
if autoencoder == true

Check warning on line 328 in src/data_loader/data_loader.jl

Codecov / codecov/patch

src/data_loader/data_loader.jl#L328

Added line #L328 was not covered by tests
DataLoader{T, typeof(new_input), Nothing, :RegularData}(
new_input,
nothing,
nothing,
nothing)
elseif autoencoder == false
DataLoader{T, typeof(new_input), Nothing, :TimeSeries}(

Check warning on line 338 in src/data_loader/data_loader.jl

Codecov / codecov/patch

src/data_loader/data_loader.jl#L338

Added line #L338 was not covered by tests
new_input,
nothing,
dl.input_dim,
The initial conditions and parameters are taken as the first elements in the respective vector.
"""
function HRedSys(odeensemble::HODEEnsemble, encoder::NeuralNetwork{<:SymplecticEncoder}, decoder::NeuralNetwork{<:SymplecticDecoder}; integrator=ImplicitMidpoint())
N = encoder.architecture.full_dim
n = encoder.architecture.reduced_dim
v_eq = odeensemble.equation.v
f_eq = odeensemble.equation.f
h_eq = odeensemble.equation.hamiltonian
HRedSys(N, n, encoder, decoder, v_eq, f_eq, h_eq, odeensemble.tspan, odeensemble.tstep, odeensemble.ics[1]; parameters = odeensemble.parameters[1], integrator = integrator)

Check warning on line 75 in src/reduced_system/reduced_system.jl

Codecov / codecov/patch

src/reduced_system/reduced_system.jl#L69-L75

Added lines #L69 - L75 were not covered by tests
end
# this is much more expensive than it has to be and is due to a problem with nested derivatives in ForwardDiff (should not be necessary to do this twice!)