Skip to content

Commit

Permalink
Add GPU support for extruded 1D spaces
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisYatunin committed Feb 4, 2025
1 parent 1d9f991 commit ada7571
Show file tree
Hide file tree
Showing 20 changed files with 654 additions and 284 deletions.
10 changes: 10 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,16 @@ steps:
- "julia --color=yes --check-bounds=yes --project=.buildkite test/Operators/hybrid/unit_2d.jl"
- "julia --color=yes --check-bounds=yes --project=.buildkite test/Operators/hybrid/convergence_2d.jl"

- label: "Unit: hyb ops 2d CUDA"
key: unit_hyb_ops_2d_cuda
command:
- "julia --color=yes --check-bounds=yes --project=.buildkite test/Operators/hybrid/unit_2d.jl"
- "julia --color=yes --check-bounds=yes --project=.buildkite test/Operators/hybrid/convergence_2d.jl"
env:
CLIMACOMMS_DEVICE: "CUDA"
agents:
slurm_gpus: 1

- label: "Unit: hyb ops 3d"
key: unit_hyb_ops_3d
command:
Expand Down
12 changes: 3 additions & 9 deletions ext/cuda/data_layouts_mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,7 @@ end
function mapreduce_cuda(
f,
op,
data::Union{
DataLayouts.VF,
DataLayouts.IJFH,
DataLayouts.IJHF,
DataLayouts.VIJFH,
DataLayouts.VIJHF,
};
data::DataLayouts.AbstractData;
weighted_jacobian = OnesArray(parent(data)),
opargs...,
)
Expand Down Expand Up @@ -132,9 +126,9 @@ function mapreduce_cuda_kernel!(
gidx = _get_gidx(tidx, bidx, effective_blksize)
reduction = CUDA.CuStaticSharedArray(T, shmemsize)
reduction[tidx] = 0
(Nij, _, _, Nv, Nh) = DataLayouts.universal_size(us)
(Ni, Nj, _, Nv, Nh) = DataLayouts.universal_size(us)
Nf = 1 # a view into `fidx` always gives a size of Nf = 1
nitems = Nv * Nij * Nij * Nf * Nh
nitems = Nv * Ni * Nj * Nf * Nh

# load shmem
if gidx nitems
Expand Down
24 changes: 12 additions & 12 deletions ext/cuda/data_layouts_threadblock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,15 +213,15 @@ end
us::DataLayouts.UniversalSize,
n_max_threads::Integer,
)
(Nij, _, _, _, Nh) = DataLayouts.universal_size(us)
(Ni, Nj, _, _, Nh) = DataLayouts.universal_size(us)
Nh_thread = min(
Int(fld(n_max_threads, Nij * Nij)),
Int(fld(n_max_threads, Ni * Nj)),
maximum_allowable_threads()[3],
Nh,
)
Nh_blocks = cld(Nh, Nh_thread)
@assert prod((Nij, Nij, Nh_thread)) n_max_threads "threads,n_max_threads=($(prod((Nij, Nij, Nh_thread))),$n_max_threads)"
return (; threads = (Nij, Nij, Nh_thread), blocks = (Nh_blocks,))
@assert prod((Ni, Nj, Nh_thread)) n_max_threads "threads,n_max_threads=($(prod((Ni, Nj, Nh_thread))),$n_max_threads)"
return (; threads = (Ni, Nj, Nh_thread), blocks = (Nh_blocks,))
end
@inline function columnwise_universal_index(us::UniversalSize)
(i, j, th) = CUDA.threadIdx()
Expand All @@ -241,9 +241,9 @@ end
n_max_threads::Integer;
Nnames,
)
(Nij, _, _, _, Nh) = DataLayouts.universal_size(us)
@assert prod((Nij, Nij, Nnames)) n_max_threads "threads,n_max_threads=($(prod((Nij, Nij, Nnames))),$n_max_threads)"
return (; threads = (Nij, Nij, Nnames), blocks = (Nh,))
(Ni, Nj, _, _, Nh) = DataLayouts.universal_size(us)
@assert prod((Ni, Nj, Nnames)) n_max_threads "threads,n_max_threads=($(prod((Ni, Nj, Nnames))),$n_max_threads)"
return (; threads = (Ni, Nj, Nnames), blocks = (Nh,))
end
@inline function multiple_field_solve_universal_index(us::UniversalSize)
(i, j, iname) = CUDA.threadIdx()
Expand All @@ -258,12 +258,12 @@ end
us::DataLayouts.UniversalSize,
n_max_threads::Integer = 256;
)
(Nq, _, _, Nv, Nh) = DataLayouts.universal_size(us)
Nvthreads = min(fld(n_max_threads, Nq * Nq), maximum_allowable_threads()[3])
(Ni, Nj, _, Nv, Nh) = DataLayouts.universal_size(us)
Nvthreads = min(fld(n_max_threads, Ni * Nj), maximum_allowable_threads()[3])
Nvblocks = cld(Nv, Nvthreads)
@assert prod((Nq, Nq, Nvthreads)) n_max_threads "threads,n_max_threads=($(prod((Nq, Nq, Nvthreads))),$n_max_threads)"
@assert Nq * Nq n_max_threads
return (; threads = (Nq, Nq, Nvthreads), blocks = (Nh, Nvblocks), Nvthreads)
@assert prod((Ni, Nj, Nvthreads)) n_max_threads "threads,n_max_threads=($(prod((Ni, Nj, Nvthreads))),$n_max_threads)"
@assert Ni * Nj n_max_threads
return (; threads = (Ni, Nj, Nvthreads), blocks = (Nh, Nvblocks), Nvthreads)
end
@inline function spectral_universal_index(space::Spaces.AbstractSpace)
i = threadIdx().x
Expand Down
Loading

0 comments on commit ada7571

Please sign in to comment.