Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
vpuri3 committed May 13, 2024
1 parent 600d51e commit 13492b7
Show file tree
Hide file tree
Showing 89 changed files with 1,182 additions and 0 deletions.
231 changes: 231 additions & 0 deletions examples/regularization/kan.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
#
using NeuralROMs
using LinearAlgebra, ComponentArrays # arrays
using Random, Lux, MLUtils, ParameterSchedulers # ML
using OptimizationOptimJL, OptimizationOptimisers # opt
using LinearSolve, NonlinearSolve, LineSearches # num
using JLD2 # vis / save
using CUDA, LuxCUDA, KernelAbstractions # GPU
using LaTeXStrings

using KolmogorovArnold

CUDA.allowscalar(false)

# using FFTW
begin
nt = Sys.CPU_THREADS
nc = min(nt, length(Sys.cpu_info()))

BLAS.set_num_threads(nc)
# FFTW.set_num_threads(nt)
end

include(joinpath(pkgdir(NeuralROMs), "examples", "cases.jl"))

#======================================================#
function uData(x; σ = 1.0f0)
pi32 = Float32(pi)

# @. tanh(2f0 * x)
# @. sin(1f0 * x)

# @. sin(5f0 * x^1) * exp(-(x/σ)^2)
# @. sin(3f0 * x^2) * exp(-(x/σ)^2)

@. (x - pi32/2f0) * sin(x) * exp(-(x/σ)^2)
end

function datagen_reg(_N, datafile; N_ = 32768)
pi32 = Float32(pi)
L = 2pi32

_x = LinRange(-L, L, _N) |> Array
x_ = LinRange(-L, L, N_) |> Array

_u = uData(_x)
u_ = uData(x_)
metadata = (;)

_data = (_x, _u)
data_ = (x_, u_)

jldsave(datafile; _data, data_, metadata)

filename = joinpath(dirname(datafile), "plt_data")

plt = plot(_x, _u, w = 3)
png(plt, filename)

plt
end
#======================================================#

function post_kan(
datafile::String,
modelfile::String,
)
data = jldopen(datafile)
x, ũ = data["data_"]
close(data)

model = jldopen(modelfile)
NN, p, st = model["model"]
md = model["metadata"]
close(model)

@show Lux.parameterlength(NN)
@show md

xbatch = reshape(x, 1, :)
model = NeuralModel(NN, st, md)

autodiff = AutoForwardDiff()
ϵ = nothing

u, ud1x = dudx4_1D(model, xbatch, p; autodiff, ϵ) .|> vec
ũ, ũd1x = forwarddiff_deriv4(uData, x)

begin
ud0_den = mse(u , 0*u) |> sqrt
ud1_den = mse(ũd1x, 0*u) |> sqrt

ud0x_relrmse_er = sqrt(mse(u , ũ )) / ud0_den
ud1x_relrmse_er = sqrt(mse(ud1x, ũd1x)) / ud1_den

ud0x_relinf_er = norm(u - ũ , Inf) / ud0_den
ud1x_relinf_er = norm(ud1x - ũd1x, Inf) / ud1_den

@show round.((ud0x_relrmse_er, ud0x_relinf_er), sigdigits = 8)
@show round.((ud1x_relrmse_er, ud1x_relinf_er), sigdigits = 8)
end

p0 = plot(xabel = "x", title = "u(x,t)" , legend = false)
p1 = plot(xabel = "x", title = "u'(x,t)", legend = false)

plot!(p0, x, ũ, label = "Ground Truth" , w = 4, c = :black)
plot!(p0, x, u, label = "Prediction" , w = 2, c = :red)

# plot!(p1, x, ũd1x, label = "Ground Truth", w = 4, c = :black)
# plot!(p1, x, ud1x, label = "Prediction", w = 2, c = :red)

plot(p0)
end

#======================================================#
function train_kan(
datafile::String,
dir::String;
rng::Random.AbstractRNG = Random.default_rng(),
device = Lux.cpu_device(),
)
#--------------------------------------------#
# get data
#--------------------------------------------#

data = jldopen(datafile)
_data = data["_data"]
data_ = data["data_"]
md_data = data["metadata"]
close(data)

_x, _u = reshape.(_data, 1, :)
x_, u_ = reshape.(data_, 1, :)

# normalize
_x, x̄, σx = normalize_x(_x)
_u, ū, σu = normalize_u(_u)

x_, x̄, σx = normalize_x(x_)
u_, ū, σu = normalize_u(u_)

# metadata
metadata = (; md_data, x̄, ū, σx, σu)

_data = (_x, _u)

#--------------------------------------------#
# architecture hyper-params
#--------------------------------------------#

wi, wo = 1, 1

G = 5
h = 2
wh = 10

in_layer = KDense(wi, wh, G; use_base_act = false)
hd_layer = KDense(wh, wh, G; use_base_act = false)
fn_layer = KDense(wh, wo, G; use_base_act = false)

NN = Chain(in_layer, fill(hd_layer, h)..., fn_layer)

#--------------------------------------------#
# training hyper-params
#--------------------------------------------#
_batchsize = 64
E = 500

lossfun = mse
_batchsize = 128

# lrs = (1f-3, 5f-4, 2f-4, 1f-4, 5f-5, 2f-5, 1f-5,)
lrs = (1f-3, 5f-4, 2f-4, 1f-4, 5f-5, 2f-5, 1f-5,)
opts = Tuple(Optimisers.Adam(lr) for lr in lrs)
Nlrs = length(lrs)

nepochs = (round.(Int, E / (Nlrs) * ones(Nlrs))...,)
schedules = Step.(lrs, 1f0, Inf32)
early_stoppings = (fill(true, Nlrs)...,)

# BFGS
nepochs = (nepochs..., E,)
opts = (opts..., LBFGS(),)
schedules = (schedules..., Step(0f0, 1f0, Inf32),)
early_stoppings = (early_stoppings..., true)

#--------------------------------------------#
# train
#--------------------------------------------#
display(NN)

train_args = (; G, h, wh, E, _batchsize)
metadata = (; metadata..., train_args)

@show metadata

@time model, ST = train_model(NN, _data; rng,
_batchsize, opts, nepochs, schedules, early_stoppings,
device, dir, metadata, lossfun,
)

# @show metadata

model, ST
end

#======================================================#
# main
#======================================================#
rng = Random.default_rng()
Random.seed!(rng, 123)

datafile = joinpath(@__DIR__, "data_reg.jld2")
modeldir = joinpath(@__DIR__, "kan")
modelfile = joinpath(modeldir, "model_04.jld2")
device = Lux.gpu_device()

E = 100
_N, N_ = 1024, 8192 # 512, 32768
_batchsize = 32

datagen_reg(_N, datafile; N_) |> display

isdir(modeldir) && rm(modeldir, recursive = true)

model, ST = train_kan(datafile, modeldir; rng, device)
plt = post_kan(datafile, modelfile)
display(plt)

#======================================================#
nothing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/regularization/kan/plt_training_02.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/regularization/kan/plt_training_03.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/regularization/kan/plt_training_05.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/regularization/kan/plt_training_06.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/regularization/kan/plt_training_07.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/regularization/kan/plt_training_08.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
21 changes: 21 additions & 0 deletions examples/regularization/kan/statistics_01.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
Epoch [0 / 0] TRAIN LOSS: 5.1711064e-5 || TEST LOSS: 5.1711064e-5
#======================#
TRAIN STATS
R² score: 0.99994826
MSE (mean SQR error): 5.1711064e-5
RMSE (Root MSE): 0.007191041
MAE (mean ABS error): 0.0053346306
maxAE (max ABS error) 0.026876211
Lipschitz bound: 1.0

#======================#
#======================#
TEST STATS
R² score: 0.99994826
MSE (mean SQR error): 5.171106e-5
RMSE (Root MSE): 0.00719104
MAE (mean ABS error): 0.0053346306
maxAE (max ABS error) 0.026876211
Lipschitz bound: 1.0

#======================#
21 changes: 21 additions & 0 deletions examples/regularization/kan/statistics_02.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
Epoch [0 / 0] TRAIN LOSS: 2.9061643e-5 || TEST LOSS: 2.9061643e-5
#======================#
TRAIN STATS
R² score: 0.99997085
MSE (mean SQR error): 2.9061643e-5
RMSE (Root MSE): 0.0053908853
MAE (mean ABS error): 0.003771622
maxAE (max ABS error) 0.01990056
Lipschitz bound: 1.0

#======================#
#======================#
TEST STATS
R² score: 0.99997085
MSE (mean SQR error): 2.906164e-5
RMSE (Root MSE): 0.005390885
MAE (mean ABS error): 0.003771622
maxAE (max ABS error) 0.01990056
Lipschitz bound: 1.0

#======================#
21 changes: 21 additions & 0 deletions examples/regularization/kan/statistics_03.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
Epoch [0 / 0] TRAIN LOSS: 2.1567148e-5 || TEST LOSS: 2.1567148e-5
#======================#
TRAIN STATS
R² score: 0.9999784
MSE (mean SQR error): 2.1567148e-5
RMSE (Root MSE): 0.004644044
MAE (mean ABS error): 0.00318697
maxAE (max ABS error) 0.017391086
Lipschitz bound: 1.0

#======================#
#======================#
TEST STATS
R² score: 0.9999784
MSE (mean SQR error): 2.156715e-5
RMSE (Root MSE): 0.0046440447
MAE (mean ABS error): 0.0031869705
maxAE (max ABS error) 0.017391086
Lipschitz bound: 1.0

#======================#
21 changes: 21 additions & 0 deletions examples/regularization/kan/statistics_04.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
Epoch [0 / 0] TRAIN LOSS: 1.7827306e-5 || TEST LOSS: 1.7827306e-5
#======================#
TRAIN STATS
R² score: 0.9999822
MSE (mean SQR error): 1.7827306e-5
RMSE (Root MSE): 0.0042222394
MAE (mean ABS error): 0.0028674833
maxAE (max ABS error) 0.015059233
Lipschitz bound: 1.0

#======================#
#======================#
TEST STATS
R² score: 0.9999822
MSE (mean SQR error): 1.7827308e-5
RMSE (Root MSE): 0.00422224
MAE (mean ABS error): 0.0028674842
maxAE (max ABS error) 0.015059233
Lipschitz bound: 1.0

#======================#
21 changes: 21 additions & 0 deletions examples/regularization/kan/statistics_05.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
Epoch [0 / 0] TRAIN LOSS: 1.5845942e-5 || TEST LOSS: 1.5845942e-5
#======================#
TRAIN STATS
R² score: 0.99998415
MSE (mean SQR error): 1.5845942e-5
RMSE (Root MSE): 0.003980696
MAE (mean ABS error): 0.0027036269
maxAE (max ABS error) 0.014859676
Lipschitz bound: 1.0

#======================#
#======================#
TEST STATS
R² score: 0.99998415
MSE (mean SQR error): 1.5845944e-5
RMSE (Root MSE): 0.0039806967
MAE (mean ABS error): 0.0027036273
maxAE (max ABS error) 0.014859676
Lipschitz bound: 1.0

#======================#
21 changes: 21 additions & 0 deletions examples/regularization/kan/statistics_06.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
Epoch [0 / 0] TRAIN LOSS: 1.4855471e-5 || TEST LOSS: 1.4855471e-5
#======================#
TRAIN STATS
R² score: 0.9999851
MSE (mean SQR error): 1.4855471e-5
RMSE (Root MSE): 0.0038542796
MAE (mean ABS error): 0.0025964177
maxAE (max ABS error) 0.014494538
Lipschitz bound: 1.0

#======================#
#======================#
TEST STATS
R² score: 0.9999851
MSE (mean SQR error): 1.485547e-5
RMSE (Root MSE): 0.0038542796
MAE (mean ABS error): 0.0025964177
maxAE (max ABS error) 0.014494538
Lipschitz bound: 1.0

#======================#
21 changes: 21 additions & 0 deletions examples/regularization/kan/statistics_07.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
Epoch [0 / 0] TRAIN LOSS: 1.4253963e-5 || TEST LOSS: 1.4253963e-5
#======================#
TRAIN STATS
R² score: 0.99998575
MSE (mean SQR error): 1.4253963e-5
RMSE (Root MSE): 0.003775442
MAE (mean ABS error): 0.002528878
maxAE (max ABS error) 0.01413095
Lipschitz bound: 1.0

#======================#
#======================#
TEST STATS
R² score: 0.99998575
MSE (mean SQR error): 1.4253962e-5
RMSE (Root MSE): 0.003775442
MAE (mean ABS error): 0.002528878
maxAE (max ABS error) 0.01413095
Lipschitz bound: 1.0

#======================#
Loading

0 comments on commit 13492b7

Please sign in to comment.