# Patch overview: unified git diff touching five files under ext/cuda/.
#
#   * data_layouts_copyto.jl
#       - adds `knl_copyto!(dest, src)`, which takes (i, j) from threadIdx x/y,
#         h from blockIdx x, and v from blockDim z, blockIdx y and threadIdx z,
#         and only writes when `v <= size(dest, 4)`;
#       - adds `Base.copyto!` methods for IJFH and VIJFH that launch that kernel
#         through `auto_launch!` with explicit `threads_s` / `blocks_s`;
#       - adds `knl_copyto_flat!` and `cuda_copyto!`, and routes the new IFH and
#         VIFH `Base.copyto!` methods through `cuda_copyto!`.
#   * matrix_fields_multiple_field_solve.jl, matrix_fields_single_field_solve.jl,
#     operators_integral.jl, operators_thomas_algorithm.jl
#       - replace `columnwise_partition` / `multiple_field_solve_partition` with
#         `linear_partition(nitems, n_max_threads)`;
#       - each kernel now derives (i, j, h) (and `iname` for the multiple-field
#         solve) from a linear thread index instead of a universal index.
#
# NOTE(review): in the VIJFH method, `fld(256, Nij * Nij)` evaluates to 0 whenever
#   Nij * Nij > 256, so `Nv_per_block` becomes 0 and `cld(Nv, Nv_per_block)`
#   divides by zero (assuming integer type parameters) -- confirm Nij is bounded
#   by the callers, or clamp `Nv_per_block` to at least 1.
# NOTE(review): the previous `multiple_field_solve_partition` call is kept as a
#   commented-out line; consider deleting it rather than committing dead code.
# NOTE(review): the kernels mix `CUDA.threadIdx()` with unqualified `threadIdx()`
#   and `CartesianIndices(...)[idx]` with `cart_ind(...)`; confirm the
#   unqualified names are in scope in those files and consider unifying them.
# NOTE(review): `isascalar` is imported but not referenced in the visible hunks --
#   confirm it is needed.
diff --git a/ext/cuda/data_layouts_copyto.jl b/ext/cuda/data_layouts_copyto.jl index 8f14df0bc4..dac1519c5c 100644 --- a/ext/cuda/data_layouts_copyto.jl +++ b/ext/cuda/data_layouts_copyto.jl @@ -1,5 +1,90 @@ DataLayouts.device_dispatch(x::CUDA.CuArray) = ToCUDA() +##### Multi-dimensional launch configuration kernels + +function knl_copyto!(dest, src) + + i = CUDA.threadIdx().x + j = CUDA.threadIdx().y + + h = CUDA.blockIdx().x + v = CUDA.blockDim().z * (CUDA.blockIdx().y - 1) + CUDA.threadIdx().z + + if v <= size(dest, 4) + I = CartesianIndex((i, j, 1, v, h)) + @inbounds dest[I] = src[I] + end + return nothing +end + +function Base.copyto!( + dest::IJFH{S, Nij, Nh}, + bc::DataLayouts.BroadcastedUnionIJFH{S, Nij, Nh}, + ::ToCUDA, +) where {S, Nij, Nh} + if Nh > 0 + auto_launch!( + knl_copyto!, + (dest, bc); + threads_s = (Nij, Nij), + blocks_s = (Nh, 1), + ) + end + return dest +end + +function Base.copyto!( + dest::VIJFH{S, Nv, Nij, Nh}, + bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij, Nh}, + ::ToCUDA, +) where {S, Nv, Nij, Nh} + if Nv > 0 && Nh > 0 + Nv_per_block = min(Nv, fld(256, Nij * Nij)) + Nv_blocks = cld(Nv, Nv_per_block) + auto_launch!( + knl_copyto!, + (dest, bc); + threads_s = (Nij, Nij, Nv_per_block), + blocks_s = (Nh, Nv_blocks), + ) + end + return dest +end + +import ClimaCore.DataLayouts: isascalar +function knl_copyto_flat!(dest::AbstractData, bc, us) + @inbounds begin + tidx = thread_index() + if tidx ≤ get_N(us) + n = size(dest) + I = kernel_indexes(tidx, n) + dest[I] = bc[I] + end + end + return nothing +end + +function cuda_copyto!(dest::AbstractData, bc) + (_, _, Nv, _, Nh) = DataLayouts.universal_size(dest) + us = DataLayouts.UniversalSize(dest) + if Nv > 0 && Nh > 0 + nitems = prod(DataLayouts.universal_size(dest)) + auto_launch!(knl_copyto_flat!, (dest, bc, us), nitems; auto = true) + end + return dest +end +Base.copyto!( + dest::IFH{S, Ni, Nh}, + bc::DataLayouts.BroadcastedUnionIFH{S, Ni, Nh}, + ::ToCUDA, +) where {S, Ni, Nh} = 
cuda_copyto!(dest, bc) +Base.copyto!( + dest::VIFH{S, Nv, Ni, Nh}, + bc::DataLayouts.BroadcastedUnionVIFH{S, Nv, Ni, Nh}, + ::ToCUDA, +) where {S, Nv, Ni, Nh} = cuda_copyto!(dest, bc) +##### + function knl_copyto!(dest, src, us) I = universal_index(dest) if is_valid_index(dest, I, us) diff --git a/ext/cuda/matrix_fields_multiple_field_solve.jl b/ext/cuda/matrix_fields_multiple_field_solve.jl index 3955aabaa7..fc1dd462e2 100644 --- a/ext/cuda/matrix_fields_multiple_field_solve.jl +++ b/ext/cuda/matrix_fields_multiple_field_solve.jl @@ -39,7 +39,8 @@ NVTX.@annotate function multiple_field_solve!( nitems = Ni * Nj * Nh * Nnames threads = threads_via_occupancy(multiple_field_solve_kernel!, args) n_max_threads = min(threads, nitems) - p = multiple_field_solve_partition(us, n_max_threads; Nnames) + # p = multiple_field_solve_partition(us, n_max_threads; Nnames) + p = linear_partition(nitems, n_max_threads) auto_launch!( multiple_field_solve_kernel!, @@ -89,9 +90,11 @@ function multiple_field_solve_kernel!( ::Val{Nnames}, ) where {Nnames} @inbounds begin - (I, iname) = multiple_field_solve_universal_index(us) - if multiple_field_solve_is_valid_index(I, us) - (i, j, _, _, h) = I.I + (Ni, Nj, _, _, Nh) = size(Fields.field_values(x1)) + tidx = thread_index() + n = (Ni, Nj, Nh, Nnames) + if valid_range(tidx, prod(n)) + (i, j, h, iname) = kernel_indexes(tidx, n).I generated_single_field_solve!( device, caches, diff --git a/ext/cuda/matrix_fields_single_field_solve.jl b/ext/cuda/matrix_fields_single_field_solve.jl index b486ef9041..13f52929b0 100644 --- a/ext/cuda/matrix_fields_single_field_solve.jl +++ b/ext/cuda/matrix_fields_single_field_solve.jl @@ -20,7 +20,7 @@ function single_field_solve!(device::ClimaComms.CUDADevice, cache, x, A, b) threads = threads_via_occupancy(single_field_solve_kernel!, args) nitems = Ni * Nj * Nh n_max_threads = min(threads, nitems) - p = columnwise_partition(us, n_max_threads) + p = linear_partition(nitems, n_max_threads) auto_launch!( 
single_field_solve_kernel!, args; @@ -30,9 +30,10 @@ function single_field_solve!(device::ClimaComms.CUDADevice, cache, x, A, b) end function single_field_solve_kernel!(device, cache, x, A, b, us) - I = columnwise_universal_index(us) - if columnwise_is_valid_index(I, us) - (i, j, _, _, h) = I.I + idx = CUDA.threadIdx().x + (CUDA.blockIdx().x - 1) * CUDA.blockDim().x + Ni, Nj, _, _, Nh = size(Fields.field_values(A)) + if idx <= Ni * Nj * Nh + (i, j, h) = CartesianIndices((1:Ni, 1:Nj, 1:Nh))[idx].I _single_field_solve!( device, Spaces.column(cache, i, j, h), diff --git a/ext/cuda/operators_integral.jl b/ext/cuda/operators_integral.jl index 7b344b511e..e132385bf9 100644 --- a/ext/cuda/operators_integral.jl +++ b/ext/cuda/operators_integral.jl @@ -33,7 +33,7 @@ function column_reduce_device!( nitems = Ni * Nj * Nh threads = threads_via_occupancy(bycolumn_kernel!, args) n_max_threads = min(threads, nitems) - p = columnwise_partition(us, n_max_threads) + p = linear_partition(nitems, n_max_threads) auto_launch!( bycolumn_kernel!, args; @@ -67,7 +67,7 @@ function column_accumulate_device!( nitems = Ni * Nj * Nh threads = threads_via_occupancy(bycolumn_kernel!, args) n_max_threads = min(threads, nitems) - p = columnwise_partition(us, n_max_threads) + p = linear_partition(nitems, n_max_threads) auto_launch!( bycolumn_kernel!, args; @@ -89,9 +89,10 @@ bycolumn_kernel!( if space isa Spaces.FiniteDifferenceSpace single_column_function!(f, transform, output, input, init, space) else - I = columnwise_universal_index(us) - if columnwise_is_valid_index(I, us) - (i, j, _, _, h) = I.I + idx = threadIdx().x + (blockIdx().x - 1) * blockDim().x + Ni, Nj, _, _, Nh = size(Fields.field_values(output)) + if idx <= Ni * Nj * Nh + i, j, h = cart_ind((Ni, Nj, Nh), idx).I single_column_function!( f, transform, diff --git a/ext/cuda/operators_thomas_algorithm.jl b/ext/cuda/operators_thomas_algorithm.jl index d25e26fd3d..246836966e 100644 --- a/ext/cuda/operators_thomas_algorithm.jl +++ 
b/ext/cuda/operators_thomas_algorithm.jl @@ -11,7 +11,7 @@ function column_thomas_solve!(::ClimaComms.CUDADevice, A, b) threads = threads_via_occupancy(thomas_algorithm_kernel!, args) nitems = Ni * Nj * Nh n_max_threads = min(threads, nitems) - p = columnwise_partition(us, n_max_threads) + p = linear_partition(nitems, n_max_threads) auto_launch!( thomas_algorithm_kernel!, args; @@ -25,9 +25,10 @@ function thomas_algorithm_kernel!( b::Fields.ExtrudedFiniteDifferenceField, us::DataLayouts.UniversalSize, ) - I = columnwise_universal_index(us) - if columnwise_is_valid_index(I, us) - (i, j, _, _, h) = I.I + idx = threadIdx().x + (blockIdx().x - 1) * blockDim().x + Ni, Nj, _, _, Nh = size(Fields.field_values(A)) + if idx <= Ni * Nj * Nh + i, j, h = cart_ind((Ni, Nj, Nh), idx).I thomas_algorithm!(Spaces.column(A, i, j, h), Spaces.column(b, i, j, h)) end return nothing