Skip to content

Commit

Permalink
minor updates
Browse files Browse the repository at this point in the history
  • Loading branch information
vpuri3 committed Jun 24, 2024
1 parent 36947fb commit 38efec8
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 50 deletions.
39 changes: 31 additions & 8 deletions experiments_SNFROM/cases.jl
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,33 @@ end
function eval_model(
model::NeuralROMs.AbstractNeuralModel,
x::AbstractArray,
p::AbstractArray;
p::AbstractMatrix,
ax::ComponentArrays.Axis;
batchsize = 1,
device = Lux.gpu_device(),
)
us = []

x = x |> device
p = p |> device
model = model |> device

for i in axes(p, 2)
q = ComponentArray(p[:, i], ax)
u = eval_model(model, x, q; batchsize, device)

push!(us, u)
end

cat(us...; dims = 3)
end

function eval_model(
model::NeuralROMs.AbstractNeuralModel,
x::AbstractArray,
p::AbstractVector;
batchsize = numobs(x) ÷ 100,
device = Lux.cpu_device(),
device = Lux.gpu_device(),
)
loader = MLUtils.DataLoader(x; batchsize, shuffle = false, partial = true)

Expand All @@ -323,21 +347,20 @@ function eval_model(

y = ()
for batch in loader
yy = model(batch, p)
yy = model(batch, p) |> Lux.cpu_device()
y = (y..., yy)
end

hcat(y...) |> Lux.cpu_device()
hcat(y...)
end

function eval_model(
model::NTuple{3, Any},
x;
batchsize = numobs(x) ÷ 100,
device = Lux.cpu_device(),
device = Lux.gpu_device(),
)
NN, p, st = model

loader = MLUtils.DataLoader(x; batchsize, shuffle = false, partial = true)

p, st = (p, st) |> device
Expand All @@ -349,11 +372,11 @@ function eval_model(

y = ()
for batch in loader
yy = NN(batch, p, st)[1]
yy = NN(batch, p, st)[1] |> Lux.cpu_device()
y = (y..., yy)
end

hcat(y...) |> Lux.cpu_device()
hcat(y...)
end

#======================================================#
Expand Down
4 changes: 2 additions & 2 deletions experiments_SNFROM/convAE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ function makedata_CAE(

# make arrays
_u = permutedims(_u, (2, 1, 3, 4)) # [Nx, out_dim, Nbatch, Ntime]
u_ = permutedims(_u, (2, 1, 3, 4))
u_ = permutedims(u_, (2, 1, 3, 4))

_u = reshape(_u, grid..., out_dim, _Ns)
u_ = reshape(u_, grid..., out_dim, Ns_)
Expand Down Expand Up @@ -235,7 +235,7 @@ function postprocess_CAE(
# parameter plots
linewidth = 2.0
palette = :tab10
colors = (:reds, :greens, :blues, cgrad(:thermal), cgrad(:acton), cgrad(:viridis))
colors = (:reds, :greens, :blues, cgrad(:thermal), cgrad(:acton), cgrad(:viridis), cgrad(:thermal), cgrad(:viridis))
shapes = (:circle, :square, :star,)

plt = plot(; title = "Parameter scatter plot")
Expand Down
50 changes: 10 additions & 40 deletions experiments_SNFROM/smoothNF.jl
Original file line number Diff line number Diff line change
Expand Up @@ -540,23 +540,7 @@ function evolve_SNF(
Ud = @view Udata[:, :, case, :]

# query decoder to get output field
Up = begin
Ups = []
ax = getaxes(p0)

_X = Xdata |> device
_ps = ps |> device
_model = model |> device

for i in axes(ps, 2)
_p = ComponentArray(_ps[:, i], ax)
_u = _model(_X, _p)
push!(Ups, _u)
end

Ups = Ups |> Lux.cpu_device()
cat(Ups...; dims = 3)
end
Up = eval_model(model, Xdata, ps, getaxes(p0); device)

# print error metrics
begin
Expand Down Expand Up @@ -599,7 +583,7 @@ function hyperreduction_idx(
rng = Random.default_rng(),
tol::Real = 1f-2,
Q::Integer = 10,
maxsamples::Integer = 100,
maxsamples::Integer = 64,
verbose::Bool = false,
device = Lux.gpu_device(),
)
Expand Down Expand Up @@ -627,7 +611,6 @@ function hyperreduction_idx(
rm = residual_metric(r)

println("HYPERREDUCTION_IDX: |IX| = $(length(IX)), metric: $(rm)")
println("HYPERREDUCTION_IDX: $(IX')")

if residual_metric(r) < tol
println("HYPERREDUCTION_IDX: Tolerance has been met with $(length(IX)) points.")
Expand Down Expand Up @@ -672,8 +655,8 @@ function compute_residual(

# make data
Nt = length(Tdata)
# It = LinRange(1, Nt, 2) .|> Base.Fix1(round, Int)
It = LinRange(1, Nt, 10) .|> Base.Fix1(round, Int)
It = LinRange(1, Nt, 2) .|> Base.Fix1(round, Int)
# It = LinRange(1, Nt, 10) .|> Base.Fix1(round, Int)

Td = @view Tdata[It]
Xd = @view Xdata[:, IX]
Expand All @@ -690,23 +673,10 @@ function compute_residual(
adaptive, autodiff_xyz, ϵ_xyz, learn_ic, verbose, device,
)

# Compute residual
Up = begin
Ups = []
_X = Xdata |> device
_ps = ps |> device
_model = model |> device

for i in axes(ps, 2)
_p = ComponentArray(_ps[:, i], ax)
_u = _model(_X, _p)
push!(Ups, _u)
end

Ups = Ups |> Lux.cpu_device()
cat(Ups...; dims = 3)
end
# get prediction values
Up = eval_model(model, Xdata, ps, getaxes(p0); device)

# compute residual
err = Up - Ud
res += sum(abs2, err; dims = (1, 3)) .|> sqrt |> vec
end
Expand All @@ -716,9 +686,9 @@ end

function residual_metric(r::AbstractVector)
mn = sum(r) / length(r)
mx = maximum(r)

mn + mx
# mx = maximum(r)
#
# mn + mx
end
#===========================================================#
#

0 comments on commit 38efec8

Please sign in to comment.