Skip to content

Commit

Permalink
advect, burgers
Browse files Browse the repository at this point in the history
  • Loading branch information
vpuri3 committed Aug 4, 2023
1 parent 048505e commit 8d75909
Show file tree
Hide file tree
Showing 15 changed files with 36 additions and 45 deletions.
18 changes: 10 additions & 8 deletions examples/diffusion_fourier/exp_bilinear_scale/bilin.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ include("../datagen.jl")
N = 128 # problem size
K1 = 32 # ν-samples
K2 = 32 # f-samples
E = 200 # epochs
E = 100 # epochs

rng = Random.default_rng()
Random.seed!(rng, 199)

# datagen
_V, _data, _, _ = datagen1D(rng, N, K1, K2) # train
# V_, data_, _, _ = datagen1D(rng, N, K1, K2) # train
V_, data_, _, _ = datagen1D(rng, N, K1, K2; mode = :test) # test
V_, data_, _, _ = datagen1D(rng, N, K1, K2) # train
# V_, data_, _, _ = datagen1D(rng, N, K1, K2; mode = :test) # test

__data = combine_data1D(_data)
data__ = combine_data1D(data_)
Expand Down Expand Up @@ -100,7 +100,7 @@ w2 = 16 # width linear
wo = 8 # width project
m = (32,) # modes

split = SplitRows(1:2, 3)
split = SplitRows(1:2, 3)
nonlin = Chain(PermutedBatchNorm(c1, 3), Dense(c1, w1, tanh), OpKernel(w1, w1, m, tanh))
linear = Dense(c2, w2, use_bias = false)
bilin = OpConvBilinear(w1, w2, o, m)
Expand All @@ -109,13 +109,15 @@ bilin = OpConvBilinear(w1, w2, o, m)
NN = linear_nonlinear(split, nonlin, linear, bilin)

opt = Optimisers.Adam()
batchsize = size(__data[1])[end]
batchsize = 256 # size(__data[1])[end] # 1024
learning_rates = (1f-3,)
nepochs = E .* (1.00,) .|> Int
nepochs = E .* (1.00,) .|> Int
# learning_rates = (1f-3, 5f-4, 2.5f-4, 1.25f-4,)
# nepochs = E .* (0.25, 0.25, 0.25, 0.25,) .|> Int
dir = joinpath(@__DIR__, "exp_FNO_linear_nonlinear")
device = Lux.gpu
device = Lux.cpu

model, _ = train_model(rng, NN, __data, data__, _V, opt;
model, ST = train_model(rng, NN, __data, data__, _V, opt;
batchsize, learning_rates, nepochs, dir, cbstep = 1, device)

end
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
TRAIN LOSS: 0.0259869 TEST LOSS: 281.77728
TRAIN LOSS: 0.02262737 TEST LOSS: 0.03570642
#======================#
TRAIN STATS
R² score: 0.9733658
mean SQR error: 0.02598686
mean ABS error: 0.12003216
max ABS error: 1.0625379
R² score: 0.9730648
MSE (mean SQR error): 0.02510134
RMSE (root mean SQR error): 0.15843402
MAE (mean ABS error): 0.11758655
maxAE (max ABS error) 1.0394883

#======================#
#======================#
TEST STATS
R² score: 0.9745712
mean SQR error: 281.7768
mean ABS error: 10.313357
max ABS error: 168.86465
R² score: 0.97130257
MSE (mean SQR error): 0.03251882
RMSE (root mean SQR error): 0.18032974
MAE (mean ABS error): 0.13228929
maxAE (max ABS error) 1.1780059

#======================#
4 changes: 2 additions & 2 deletions examples/pdebench/burgers1d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ K_ = 64

# get data
dir = @__DIR__
filename = joinpath(dir, "1D_Burgers_Sols_Nu0.01.hdf5")
filename = joinpath(dir, "1D_Burgers_Sols_Nu0.001.hdf5")
include(joinpath(dir, "pdebench.jl"))
_data, data_ = burgers1D(filename, _K, K_, rng)

Expand Down Expand Up @@ -83,7 +83,7 @@ nepochs = E .* (0.25, 0.25, 0.25, 0.25) .|> Int
# learning_rates = (1f-3, 5f-4, 2.5f-4, 1.25f-4)
# nepochs = E .* (0.25, 0.25, 0.25, 0.25) .|> Int

dir = joinpath(@__DIR__, "model_burgers1D")
dir = joinpath(@__DIR__, "dump")
device = Lux.cpu

model, ST = train_model(rng, NN, _data, data_, V, opt;
Expand Down
Binary file removed examples/pdebench/model_burgers1D/plt_r2_test.png
Binary file not shown.
Binary file removed examples/pdebench/model_burgers1D/plt_r2_train.png
Binary file not shown.
Binary file removed examples/pdebench/model_burgers1D/plt_training.png
Binary file not shown.
Binary file removed examples/pdebench/model_burgers1D/plt_traj_test.png
Binary file not shown.
Binary file removed examples/pdebench/model_burgers1D/plt_traj_train.png
Binary file not shown.
19 changes: 0 additions & 19 deletions examples/pdebench/model_burgers1D/statistics.txt

This file was deleted.

20 changes: 13 additions & 7 deletions src/operator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@ function OpKernel(ch_in::Int, ch_out::Int, modes::NTuple{D, Int},
conv = OpConv(ch_in, ch_out, modes; transform, init)
loc = Dense(ch_in, ch_out; init_weight = init, use_bias)

# Parallel(+, loc, conv), # x -> lox(x) + conv(x)

Chain(
Lux.Parallel(+, loc, conv),
Lux.WrappedFunction(activation),
BranchLayer(loc, conv), # x -> (loc(x), conv(x))
WrappedFunction(sum), # (x1, x2) -> x1 + x2
WrappedFunction(activation), # x -> act(x)
)
end

Expand All @@ -37,12 +40,15 @@ function OpKernelBilinear(ch_in1::Int, ch_in2::Int, ch_out::Int,

activation = fastify(activation)

null = NoOpLayer()
conv = OpConvBilinear(ch_in1, ch_in2, ch_out, modes; transform, init)
loc = Bilinear((ch_in1, ch_in2) => ch_out; init_weight = init, use_bias = false)

# Parallel(+, null, null),
Chain(
Lux.Parallel(.+, loc, conv),
Lux.WrappedFunction(activation),
BranchLayer(loc, conv), # x -> (loc(x), conv(x))
WrappedFunction(sum), # (x1, x2) -> x1 + x2
WrappedFunction(activation), # x -> act(x)
)
end

Expand All @@ -66,9 +72,9 @@ x2 → linear → y2 ↗
"""
    linear_nonlinear(split, nonlin, linear, bilinear, project = NoOpLayer())

Assemble a linear/nonlinear operator network as a `Chain`:
the input is split into two branches, the first passed through the
nonlinear path `nonlin` and the second through the linear path `linear`;
the branch outputs are fused by `bilinear` and optionally mapped by
`project` (identity by default, via `NoOpLayer()`).

NOTE(review): the rendered diff duplicated the first three chain stages
(uncommented old lines next to commented new lines); the reconstruction
below keeps each stage exactly once, matching the commit's "new" side.
"""
function linear_nonlinear(split, nonlin, linear, bilinear, project = NoOpLayer())

    Chain(
        split,                             # x -> (x1, x2)
        Parallel(nothing, nonlin, linear), # (x1, x2) -> (f(x1), g(x2))
        bilinear,                          # (f(x1), g(x2)) -> y
        project,
    )
end
Expand Down

0 comments on commit 8d75909

Please sign in to comment.