Skip to content

Commit

Permalink
darcy problem
Browse files Browse the repository at this point in the history
  • Loading branch information
vpuri3 committed Jul 28, 2023
1 parent 12352c2 commit e48baec
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 25 deletions.
18 changes: 9 additions & 9 deletions examples/darcy2D_pdebench/darcy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
using HDF5, Random
using CalculustCore: ndgrid

function darcy2D(filename, K = 1024, rng = Random.default_rng())
function darcy2D(filename, _K = 1024, K_ = 256, rng = Random.default_rng())
file = h5open(filename, "r")

ν = read(file["nu"]) # [128, 128, 10000]
Expand All @@ -12,18 +12,18 @@ function darcy2D(filename, K = 1024, rng = Random.default_rng())
u = read(file["tensor"]) # [128, 128, 1, 10000]

N = length(x)
Kmax = size(u,)[end]
Ks = rand(rng, 1:Kmax, 2K)
_I = Ks[begin:K]
I_ = Ks[K+1:end]
Kmax = size(u)[end]
Ks = rand(rng, 1:Kmax, _K + K_)
_I = Ks[begin:_K]
I_ = Ks[_K+1:end]

x, y = ndgrid(x, y)

_x = zeros(3, N, N, K)
x_ = zeros(3, N, N, K)
_x = zeros(3, N, N, _K)
x_ = zeros(3, N, N, K_)

_u = zeros(1, N, N, K)
u_ = zeros(1, N, N, K)
_u = zeros(1, N, N, _K)
u_ = zeros(1, N, N, K_)

_x[1, :, :, :] .= x
_x[2, :, :, :] .= y
Expand Down
11 changes: 7 additions & 4 deletions examples/darcy2D_pdebench/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,17 @@ rng = Random.default_rng()
Random.seed!(rng, 345)

N = 128
K = 512 # trajectories
E = 200 # epochs
E = 300 # epochs

# trajectories
_K = 4096
K_ = 512

# get data
dir = @__DIR__
filename = joinpath(dir, "2D_DarcyFlow_beta0.01_Train.hdf5")
include(joinpath(dir, "darcy.jl"))
_data, data_ = darcy2D(filename, K, rng)
_data, data_ = darcy2D(filename, _K, K_, rng)

V = FourierSpace(N, N)

Expand All @@ -70,7 +73,7 @@ opt = Optimisers.Adam()
batchsize = 32
learning_rates = (1f-2, 1f-3,)
nepochs = E .* (0.10, 0.90,) .|> Int
dir = joinpath(@__DIR__, "FNO2")
dir = joinpath(@__DIR__, "FNO4")
device = Lux.gpu

FNO_nl = train_model(rng, NN, _data, data_, V, opt;
Expand Down
28 changes: 16 additions & 12 deletions src/train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,13 @@ function (L::Loss)(p)
L.lossfun(y, ŷ), st # Lux interface
end

function grad(loss::Loss, p)
(l, st), pb = Zygote.pullback(loss, p)
gr = pb((one.(l), nothing))[1]

l, st, gr
end

"""
$SIGNATURES
Expand All @@ -328,20 +335,18 @@ function optimize(NN, p, st, loader, nepochs;
io::Union{Nothing, IO} = stdout,
)

function grad(loss, p)
(l, st), pb = Zygote.pullback(loss, p)
gr = pb((one.(l), nothing))[1]

l, gr, st
end

# print stats
!isnothing(cb) && cb(p, st, 0, nepochs; io)

function loss(x, ŷ, p, st)
y, st = NN(x, p, st)
lossfun(y, ŷ), st
end

# warm up
begin
loss = Loss(NN, st, first(loader), lossfun)
_, _, _ = grad(loss, p)
grad(loss, p)
end

# init optimizer
Expand All @@ -350,13 +355,12 @@ function optimize(NN, p, st, loader, nepochs;
for epoch in 1:nepochs
for batch in loader
loss = Loss(NN, st, batch, lossfun)

l, g, st = grad(loss, p)
opt_st, p = Optimisers.update(opt_st, p, g)
l, st, g = grad(loss, p)
opt_st, p = Optimisers.update!(opt_st, p, g)

println(io, "Epoch [$epoch / $nepochs]" * "\t Batch loss: $l")

# GC.gc(false)
GC.gc(false)
end
# GC.gc(true)

Expand Down

0 comments on commit e48baec

Please sign in to comment.