Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
albop committed Jun 1, 2024
1 parent 2f6d73e commit 4a6d0ba
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 12 deletions.
30 changes: 22 additions & 8 deletions misc/dev_float32.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,40 @@ model_32 = include("$(root_dir)/misc/rbc_float32.jl")

# Convert `model_32` to Float32 precision.
model32 = Dolo.convert_precision(Float32, model_32)

# Discretize the Float32 model. `dm32` is assigned twice in a row, so only
# the second call (with `:endo => [100000]`) is in effect afterwards.
dm32 = Dolo.discretize(model32, Dict(:endo=>[1000]) )
dm32 = Dolo.discretize(model32, Dict(:endo=>[100000]) )

# Inspect the concrete type of the discretized model.
typeof(dm32)


using Adapt
import oneAPI: oneArray
# import oneAPI: oneArray
import CUDA: CuArray
# import Cu
import Adapt: adapt_structure
# import CUDA: CuArray

import Dolo
# Array type passed as `dest` for the GPU workspace, and the interpolation
# mode shared by the workspaces built just below.
gpuArray = CuArray
interp_mode = :linear

# Build the workspaces. `wk0` is assigned twice; the second call (which
# passes `interp_mode`) replaces the first. `wk_gpu` is the only one built
# with an explicit `dest`.
wk0 = Dolo.time_iteration_workspace(dm32);
wk0 = Dolo.time_iteration_workspace(dm32; interp_mode=interp_mode);
wk1 = Dolo.time_iteration_workspace(dm32; interp_mode=interp_mode);
wk_gpu = Dolo.time_iteration_workspace(dm32, dest=gpuArray; interp_mode=interp_mode);

# wk = Dolo.time_iteration_workspace(dm32, dest=CuArray)
# Timed construction of a workspace with `dest=oneArray` and cubic interpolation.
@time wk = Dolo.time_iteration_workspace(dm32, dest=oneArray; interp_mode=:cubic);
# Timed runs with identical options (`improve=false`); they differ only in
# the workspace used and in the `engine` value.
@time sol1 = time_iteration(dm32, wk0; engine=:nothing, tol_η=1e-5, verbose=true, improve_wait=0, improve=false);
@time sol3 = time_iteration(dm32, wk1; engine=:cpu, tol_η=1e-5, verbose=true, improve_wait=0, improve=false);
@time sol2 = time_iteration(dm32, wk_gpu; engine=:gpu, tol_η=1e-5, verbose=true, improve_wait=0, improve=false);


# Untimed runs with `engine=:gpu` and `improve=true`; results are not stored.
# NOTE(review): `wk0` is built without a `dest` argument but is used here with
# `engine=:gpu` — confirm this combination is intended.
time_iteration(dm32, wk0; engine=:gpu, tol_η=1e-5, improve_wait=3, improve=true)
time_iteration(dm32, wk; engine=:gpu, tol_η=1e-5, improve_wait=3, improve=true)


# Rebuild the three workspaces with cubic interpolation; `wk_gpu` again uses
# `dest=gpuArray`.
wk0 = Dolo.time_iteration_workspace(dm32; interp_mode=:cubic);
wk1 = Dolo.time_iteration_workspace(dm32; interp_mode=:cubic);
wk_gpu = Dolo.time_iteration_workspace(dm32, dest=gpuArray; interp_mode=:cubic);

# Timed runs with `improve=true`, `improve_wait=10` and `improve_K=100`; they
# differ only in the workspace used and in the `engine` value.
@time sol1 = time_iteration(dm32, wk0; engine=:nothing, tol_η=1e-5, verbose=true, improve_wait=10, improve_K=100,improve=true);
@time sol3 = time_iteration(dm32, wk1; engine=:cpu, tol_η=1e-5, verbose=true, improve_wait=10, improve_K=100, improve=true);
# that one stops early
@time sol2 = time_iteration(dm32, wk_gpu; engine=:gpu, tol_η=1e-5, verbose=true, improve_wait=10, improve_K=100, improve=true);



Expand Down
8 changes: 4 additions & 4 deletions src/adapt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ function adapt_structure(to, L::LL{G,D,F}) where G where D where F
LL(L.grid, adapt(to, L.D), adapt(to, L.φ))
end


"""
    maxabs(u::Number, v::Number)

Return the larger of the absolute values of `u` and `v`.

The absolute value is taken of each argument *before* comparing, so a large
negative argument is not masked by a smaller positive one
(`maxabs(-5, 1) == 5`).
"""
maxabs(u::Number, v::Number) = max(abs(u), abs(v))
# should it be merged with the general definition?
# import CUDA: CuArray
# distance(x::GVector{G, A}, y::GVector{G,A}) where G where A<:CuArray = Base.mapreduce(u->maximum(u), max, x.data-y.data)
import CUDA: CuArray
"""
    distance(x::GVector{G,A}, y::GVector{G,A}) where A<:CuArray

Return the largest absolute component found in the difference
`x.data - y.data`: each element of the difference is reduced to the maximum
of its absolute components, and those per-element values are combined with
`max`.
"""
function distance(x::GVector{G, A}, y::GVector{G,A}) where G where A<:CuArray
    gap = x.data - y.data
    return Base.mapreduce(u -> maximum(abs.(u)), max, gap)
end

import oneAPI: oneArray
"""
    distance(x::GVector{G,A}, y::GVector{G,A}) where A<:oneArray

Return the largest absolute component found in the difference
`x.data - y.data`: each element of the difference is reduced to the maximum
of its absolute components, and those per-element values are combined with
`max`.

Only one definition is kept for this signature. Absolute values are taken so
that large negative deviations count towards the distance.
"""
function distance(x::GVector{G, A}, y::GVector{G,A}) where G where A<:oneArray
    gap = x.data - y.data
    return Base.mapreduce(u -> maximum(abs.(u)), max, gap)
end

0 comments on commit 4a6d0ba

Please sign in to comment.