Skip to content

Commit

Permalink
updated figure plots
Browse files Browse the repository at this point in the history
  • Loading branch information
vpuri3 committed Apr 22, 2024
1 parent 53c2d14 commit 3a6afd3
Show file tree
Hide file tree
Showing 36 changed files with 158 additions and 65 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ Manifest.toml
*.jld2
*.hdf5
*.h5
*.eps
1 change: 1 addition & 0 deletions examples/cases.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#
using NeuralROMs
using JLD2, TSne
using Random, Lux, NNlib, MLUtils
using Plots, ColorSchemes, LaTeXStrings

#======================================================#
Expand Down
221 changes: 156 additions & 65 deletions figs/makefigs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,52 @@ using CairoMakie
function makeplots(
datafile,
outdir::String,
casename::AbstractString,
casename::AbstractString;
ifcrom::Bool = false,
ifdt::Bool = false,
)

data = h5open(datafile)
xFOM = data["xFOM"] |> Array # [in_dim, grid...]
tFOM = data["tFOM"] |> Array # [Nt]
uFOM = data["uFOM"] |> Array # [out_dim, grid..., Nt]
#
uPCA = data["uPCA"] |> Array
uCAE = data["uCAE"] |> Array
uSNL = data["uSNL"] |> Array
uSNW = data["uSNW"] |> Array
pCAE = data["pCAE"] |> Array
#
pCAE = data["pCAE"] |> Array # dynamics solve
pSNL = data["pSNL"] |> Array
pSNW = data["pSNW"] |> Array
#
qCAE = data["qCAE"] |> Array # encder prediction
qSNL = data["qSNL"] |> Array
qSNW = data["qSNW"] |> Array

# C-ROM
uCRM = uFOM * NaN
pCRM = pSNW * NaN
qCRM = qSNW * NaN

if ifcrom
uCRM = data["uCRM"] |> Array
pCRM = data["pCRM"] |> Array
qCRM = data["qCRM"] |> Array
end

# DT
pdtCAE, pdtSNL, pdtSNW = if ifdt
pdtCAE = data["pdtCAE"] |> Array
pdtSNL = data["pdtSNL"] |> Array
pdtSNW = data["pdtSNW"] |> Array

pdtCAE, pdtSNL, pdtSNW
else
nothing, nothing, nothing
end

#======================================================#

in_dim = size(xFOM, 1)
out_dim = size(uFOM, 1)
Expand All @@ -42,6 +74,7 @@ function makeplots(
uCAE = uCAE[1, ii...]
uSNL = uSNL[1, ii...]
uSNW = uSNW[1, ii...]
uCRM = uCRM[1, ii...]

## normalize
nr = sum(abs2, uFOM; dims = 1:in_dim) ./ prod(size(uFOM)[1:in_dim]) .|> sqrt
Expand All @@ -50,32 +83,42 @@ function makeplots(
eCAE = (uFOM - uCAE) ./ nr
eSNL = (uFOM - uSNL) ./ nr
eSNW = (uFOM - uSNW) ./ nr
eCRM = (uFOM - uCRM) ./ nr

e2tPCA = sum(abs2, ePCA; dims = 1:in_dim) / Nxyz |> vec
e2tCAE = sum(abs2, eCAE; dims = 1:in_dim) / Nxyz |> vec
e2tSNL = sum(abs2, eSNL; dims = 1:in_dim) / Nxyz |> vec
e2tSNW = sum(abs2, eSNW; dims = 1:in_dim) / Nxyz |> vec
e2tCRM = sum(abs2, eCRM; dims = 1:in_dim) / Nxyz |> vec

e2tPCA = sqrt.(e2tPCA) .+ 1f-12
e2tCAE = sqrt.(e2tCAE) .+ 1f-12
e2tSNL = sqrt.(e2tSNL) .+ 1f-12
e2tSNW = sqrt.(e2tSNW) .+ 1f-12
e2tCRM = sqrt.(e2tCRM) .+ 1f-12

idx = collect(Colon() for _ in 1:in_dim)
eitPCA = collect(norm(ePCA[idx..., i]) for i in 1:Nt)
eitCAE = collect(norm(eCAE[idx..., i]) for i in 1:Nt)
eitSNL = collect(norm(eSNL[idx..., i]) for i in 1:Nt)
eitSNW = collect(norm(eSNW[idx..., i]) for i in 1:Nt)
eitCRM = collect(norm(eCRM[idx..., i]) for i in 1:Nt)

upreds = (uPCA, uCAE, uSNL, uSNW,)
epreds = (ePCA, eCAE, eSNL, eSNW,)
upreds = (uPCA, uCAE, uSNL, uSNW,)
epreds = (ePCA, eCAE, eSNL, eSNW,)
eitpreds = (eitPCA, eitCAE, eitSNL, eitSNW,)
e2tpreds = (e2tPCA, e2tCAE, e2tSNL, e2tSNW,)

if ifcrom
upreds = (upreds..., uCRM,)
epreds = (epreds..., uCRM,)
eitpreds = (eitpreds..., eitCRM,)
e2tpreds = (e2tpreds..., e2tCRM,)
end

figt = Figure(; size = ( 900, 400), backgroundcolor = :white, grid = :off)
figc = Figure(; size = (1000, 800), backgroundcolor = :white, grid = :off)
fige = Figure(; size = ( 600, 400), backgroundcolor = :white, grid = :off)
# fige = Figure(; size = ( 900, 400), backgroundcolor = :white, grid = :off)
figp = Figure(; size = (1200, 400), backgroundcolor = :white, grid = :off)

axt0 = Axis(figt[1,1]; xlabel = L"x", ylabel = L"u(x, t)", xlabelsize = 20, ylabelsize = 20)
Expand All @@ -85,42 +128,50 @@ function makeplots(
# axe2 = Axis(fige[1,2]; xlabel = L"t", ylabel = L"ε_\infty(t)", yscale = log10, xlabelsize = 20, ylabelsize = 20)
axe2 = Axis(Figure()[1,1])

axp1, axp2, axp3 = nothing, nothing, nothing
#===============================#
# FIGP
#===============================#

sc1_kw = (; color = tFOM, colormap = :reds, label = "Dynamic solve")
sckwq = (; color = :black, markersize = 25, marker = :star5 , label = L"$\tilde{u}(t=0)$ prediction")
lnkwq = (; color = :red , linewidth = 4, linestyle = :solid, label = L"\text{Prediction}")
lnkwp = (; color = :blue , linewidth = 6, linestyle = :dot , label = L"\text{Dynamics solve}")
lnkwt = (; color = :green, linewidth = 6, linestyle = :dash , label = L"\text{Dynamics solve (Large }\Delta t)")

if size(pCAE, 1) == 1
axp1 = Axis(figp[1,1]; xlabel = L"t", ylabel = L"\tilde{u}(t)", xlabelsize = 20, ylabelsize = 20)
axp2 = Axis(figp[1,2]; xlabel = L"t", ylabel = L"\tilde{u}(t)", xlabelsize = 20, ylabelsize = 20)
axp3 = Axis(figp[1,3]; xlabel = L"t", ylabel = L"\tilde{u}(t)", xlabelsize = 20, ylabelsize = 20)
axkwp = if size(pCAE, 1) == 2
(; xlabel = L"\tilde{u}_1(t)", ylabel = L"\tilde{u}_2(t)", xlabelsize = 20, ylabelsize = 20)
elseif size(pCAE, 1) == 1
axkwp = (; xlabel = L"t", ylabel = L"\tilde{u}(t)", xlabelsize = 20, ylabelsize = 20)
end

s1 = scatter!(axp1, vec(pCAE); sc1_kw...)
s2 = scatter!(axp2, vec(pSNL); sc1_kw...)
s3 = scatter!(axp3, vec(pSNW); sc1_kw...)
axp1 = Axis(figp[1,1]; axkwp...)
axp2 = Axis(figp[1,2]; axkwp...)
axp3 = Axis(figp[1,3]; axkwp...)

axislegend(axp1)
axislegend(axp2)
axislegend(axp3)
pplot!(axp1, tFOM, pCAE, qCAE, pdtCAE; ifdt, sckwq, lnkwq, lnkwp, lnkwt)
pplot!(axp2, tFOM, pSNL, qSNL, pdtSNL; ifdt, sckwq, lnkwq, lnkwp, lnkwt)
pplot!(axp3, tFOM, pSNW, qSNW, pdtSNW; ifdt, sckwq, lnkwq, lnkwp, lnkwt)

elseif size(pCAE, 1) == 2
axp1 = Axis(figp[1,1]; xlabel = L"\tilde{u}_1(t)", ylabel = L"\tilde{u}_2(t)", xlabelsize = 20, ylabelsize = 20)
axp2 = Axis(figp[1,2]; xlabel = L"\tilde{u}_1(t)", ylabel = L"\tilde{u}_2(t)", xlabelsize = 20, ylabelsize = 20)
axp3 = Axis(figp[1,3]; xlabel = L"\tilde{u}_1(t)", ylabel = L"\tilde{u}_2(t)", xlabelsize = 20, ylabelsize = 20)
Label(figp[2,1], L"(a)")
Label(figp[2,2], L"(b)")
Label(figp[2,3], L"(c)")

s1 = scatter!(axp1, pCAE; sc1_kw...)
s2 = scatter!(axp2, pSNL; sc1_kw...)
s3 = scatter!(axp3, pSNW; sc1_kw...)
colsize!(figp.layout, 1, Relative(0.33))
colsize!(figp.layout, 2, Relative(0.33))
colsize!(figp.layout, 3, Relative(0.33))

# axislegend(axp1)
# axislegend(axp2)
# axislegend(axp3)

figp[3,:] = Legend(figp, axp1; orientation = :horizontal, patchsize = (50, 10))

#===============================#
# FIGT, FIGE, FIGC
#===============================#

axislegend(axp1)#; position = :lt, patchsize = (30, 10))
axislegend(axp2)#; position = :lt, patchsize = (30, 10))
axislegend(axp3)#; position = :lt, patchsize = (30, 10))
else
@error "latent size size(p, 1) == $(size(pCAE, 1)) not supported."
end

colors = (:orange, :green, :blue, :red, :brown,)
styles = (:solid, :dash, :dashdot, :dashdotdot,)
labels = ("POD", "CAE", "SNFL", "SNFW")
styles = (:solid, :dash, :dashdot, :dashdotdot, :dot)
labels = ("POD", "CAE", "SNFL", "SNFW", "CROM")

levels = if occursin("exp2", casename)
n = 11
Expand Down Expand Up @@ -258,12 +309,12 @@ function makeplots(

linkaxes!(axt0, axt1)

save(joinpath(outdir, casename * "p1.png"), figt)
save(joinpath(outdir, casename * "p2.png"), fige)

if in_dim == 2
save(joinpath(outdir, casename * "p3.png"), figc)
end
# save(joinpath(outdir, casename * "p1.eps"), figt)
# save(joinpath(outdir, casename * "p2.eps"), fige)
#
# if in_dim == 2
# save(joinpath(outdir, casename * "p3.eps"), figc)
# end

save(joinpath(outdir, casename * "p4.png"), figp)

Expand All @@ -274,29 +325,23 @@ end
function makeplot_exp3(
datafiles::String...;
outdir::String,
ifdt::Bool = false,
)
figp = Figure(; size = (1200, 400), backgroundcolor = :white, grid = :off)

mshape = (:circle, :utriangle, :diamond, :dtriangle, :star5,)

ax_kw = (;
axkwp = (;
xlabel = L"\tilde{u}_1(t)",
ylabel = L"\tilde{u}_2(t)",
xlabelsize = 20,
ylabelsize = 20,
)

## TODO: are there 2D color gradients?
## TODO: fix scatter plot legend. It all looks black rn.
## TODO: if cannot be fixed, replace with lines + markers

axp1 = Axis(figp[1,1]; ax_kw...)
axp2 = Axis(figp[1,2]; ax_kw...)
axp3 = Axis(figp[1,3]; ax_kw...)
axp1 = Axis(figp[1,1]; axkwp...)
axp2 = Axis(figp[1,2]; axkwp...)
axp3 = Axis(figp[1,3]; axkwp...)

# is there a 2D colormap ??
cmaps = (:reds, :greens, :winter, :viridis, :vik, :blues,)
label = (
labels = (
L"$\mu = 0.500$ (Training)",
L"$\mu = 0.525$ (Interpolation)",
L"$\mu = 0.550$ (Training)",
Expand All @@ -305,24 +350,70 @@ function makeplot_exp3(
L"$\mu = 0.625$ (Extrapolation)",
)

colors = (:blue, :orange, :green, :red, :purple, :brown,)

for (i, datafile) in enumerate(datafiles)
data = h5open(datafile)
tFOM = data["tFOM"] |> Array # [Nt]
#
pCAE = data["pCAE"] |> Array
pSNL = data["pSNL"] |> Array
pSNW = data["pSNW"] |> Array
#
qCAE = data["qCAE"] |> Array # encder prediction
qSNL = data["qSNL"] |> Array
qSNW = data["qSNW"] |> Array

sc_kw = (; color = tFOM, colormap = cmaps[i], label = label[i])
sc_kw = (; label = label[i])
color = colors[i]
label = labels[i]
sckwq = (; color = :black, markersize = 15, marker = :star5,)
lnkwq = (; color, label, linewidth = 4, linestyle = :solid,)
lnkwp = (; color, linewidth = 6, linestyle = :dot,)

scatter!(axp1, pCAE; sc_kw...)
scatter!(axp2, pSNL; sc_kw...)
scatter!(axp3, pSNW; sc_kw...)
pplot!(axp1, tFOM, pCAE, qCAE; sckwq, lnkwq, lnkwp)
pplot!(axp2, tFOM, pSNL, qSNL; sckwq, lnkwq, lnkwp)
pplot!(axp3, tFOM, pSNW, qSNW; sckwq, lnkwq, lnkwp)
end

figp[2,:] = Legend(figp, axp1; orientation = :horizontal,) # patchsize = (100, 10))
Label(figp[2,1], L"(a)")
Label(figp[2,2], L"(b)")
Label(figp[2,3], L"(c)")

colsize!(figp.layout, 1, Relative(0.33))
colsize!(figp.layout, 2, Relative(0.33))
colsize!(figp.layout, 3, Relative(0.33))

figp[3,:] = Legend(figp, axp1; orientation = :horizontal,) # patchsize = (100, 10))

save(joinpath(outdir, "exp3p.eps"), figp)
end

#======================================================#
function pplot!(ax, t, p, q, pdt = nothing;
ifdt = false,
sckwq = (;),
lnkwq = (;),
lnkwp = (;),
lnkwt = (;),
)
if size(p, 1) == 2
scatter!(ax, q[:, 1:1]; sckwq...)
lines!(ax, q; lnkwq...)
lines!(ax, p; lnkwp...)
if ifdt
lines!(ax, pdt; lnkwt...)
end
elseif size(p, 1) == 1
scatter!(ax, [0f0,], q[:, 1]; sckwq...)
lines!(ax, t, vec(q); lnkwq...)
lines!(ax, t, vec(p); lnkwp...)

save(joinpath(outdir, "exp3p.png"), figp)
if ifdt
lines!(ax, t, vec(pdt); lnkwt...)
end
else
@error "latent size size(p, 1) == $(size(p, 1)) not supported."
end
end

#======================================================#
Expand Down Expand Up @@ -364,19 +455,19 @@ e3file4 = joinpath(h5dir, "burgers1dcase4.h5")
e3file5 = joinpath(h5dir, "burgers1dcase5.h5")
e3file6 = joinpath(h5dir, "burgers1dcase6.h5")

# makeplots(e1file, outdir, "exp1")
makeplots(e2file, outdir, "exp2")
makeplots(e4file, outdir, "exp4")
makeplots(e1file, outdir, "exp1"; ifdt = true)
# makeplots(e2file, outdir, "exp2")
# makeplots(e4file, outdir, "exp4")
# makeplots(e5file, outdir, "exp5")
#
# makeplots(e3file1, outdir, "exp3case1")
# makeplots(e3file2, outdir, "exp3case2")
# makeplots(e3file3, outdir, "exp3case3")
# # makeplots(e3file1, outdir, "exp3case1")
# # makeplots(e3file3, outdir, "exp3case3")
# # makeplots(e3file2, outdir, "exp3case2")
# makeplots(e3file4, outdir, "exp3case4")
# makeplots(e3file5, outdir, "exp3case5")
# makeplots(e3file6, outdir, "exp3case6")

# makeplot_exp3(e3file1, e3file2, e3file3, e3file4, e3file5, e3file6; outdir)
makeplot_exp3(e3file1, e3file2, e3file3, e3file4, e3file5, e3file6; outdir)

#======================================================#
nothing
Binary file removed figs/results/exp1p1.png
Binary file not shown.
Binary file removed figs/results/exp1p2.png
Binary file not shown.
Binary file modified figs/results/exp1p4.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed figs/results/exp2p1.png
Binary file not shown.
Binary file removed figs/results/exp2p2.png
Binary file not shown.
Binary file removed figs/results/exp2p3.png
Binary file not shown.
Binary file removed figs/results/exp2p4.png
Binary file not shown.
Binary file removed figs/results/exp3case1p1.png
Binary file not shown.
Binary file removed figs/results/exp3case1p2.png
Binary file not shown.
Binary file removed figs/results/exp3case1p4.png
Binary file not shown.
Binary file removed figs/results/exp3case2p1.png
Binary file not shown.
Binary file removed figs/results/exp3case2p2.png
Binary file not shown.
Binary file removed figs/results/exp3case2p4.png
Binary file not shown.
Binary file removed figs/results/exp3case3p1.png
Binary file not shown.
Binary file removed figs/results/exp3case3p2.png
Binary file not shown.
Binary file removed figs/results/exp3case3p4.png
Binary file not shown.
Binary file removed figs/results/exp3case4p1.png
Binary file not shown.
Binary file removed figs/results/exp3case4p2.png
Binary file not shown.
Binary file removed figs/results/exp3case4p4.png
Binary file not shown.
Binary file removed figs/results/exp3case5p1.png
Binary file not shown.
Binary file removed figs/results/exp3case5p2.png
Binary file not shown.
Binary file removed figs/results/exp3case5p4.png
Binary file not shown.
Binary file removed figs/results/exp3case6p1.png
Binary file not shown.
Binary file removed figs/results/exp3case6p2.png
Binary file not shown.
Binary file removed figs/results/exp3case6p4.png
Binary file not shown.
Binary file removed figs/results/exp3p.png
Diff not rendered.
Binary file removed figs/results/exp4p1.png
Diff not rendered.
Binary file removed figs/results/exp4p2.png
Diff not rendered.
Binary file removed figs/results/exp4p3.png
Diff not rendered.
Binary file removed figs/results/exp4p4.png
Diff not rendered.
Binary file removed figs/results/exp5p1.png
Diff not rendered.
Binary file removed figs/results/exp5p2.png
Diff not rendered.
Binary file removed figs/results/exp5p4.png
Diff not rendered.

0 comments on commit 3a6afd3

Please sign in to comment.