Swap the order of the two CUDA grid (block) dimensions for the VIJFH/VIJHF and
VIFH/VIHF data layouts. In both hunks the launch configuration changes from
blocks = (Nv_blocks, Nh) to blocks = (Nh, Nv_blocks), and the matching
`universal_index` method destructures `CUDA.blockIdx()` as (h, bv) instead of
(bv, h), so the computed CartesianIndex is unchanged for a given (v, h) pair.

diff --git a/ext/cuda/data_layouts_threadblock.jl b/ext/cuda/data_layouts_threadblock.jl
index 6ff4967855..91cd0191fb 100644
--- a/ext/cuda/data_layouts_threadblock.jl
+++ b/ext/cuda/data_layouts_threadblock.jl
@@ -55,11 +55,11 @@ function is_valid_index end
     Nv_thread = min(Int(fld(n_max_threads, Nij * Nij)), Nv)
     Nv_blocks = cld(Nv, Nv_thread)
     @assert prod((Nv_thread, Nij, Nij)) ≤ n_max_threads "threads,n_max_threads=($(prod((Nv_thread, Nij, Nij))),$n_max_threads)"
-    return (; threads = (Nv_thread, Nij, Nij), blocks = (Nv_blocks, Nh))
+    return (; threads = (Nv_thread, Nij, Nij), blocks = (Nh, Nv_blocks))
 end
 @inline function universal_index(::Union{DataLayouts.VIJFH, DataLayouts.VIJHF})
     (tv, i, j) = CUDA.threadIdx()
-    (bv, h) = CUDA.blockIdx()
+    (h, bv) = CUDA.blockIdx()
     v = tv + (bv - 1) * CUDA.blockDim().x
     return CartesianIndex((i, j, 1, v, h))
 end
@@ -152,11 +152,11 @@ end
     Nv_thread = min(Int(fld(n_max_threads, Ni)), Nv)
     Nv_blocks = cld(Nv, Nv_thread)
     @assert prod((Nv_thread, Ni)) ≤ n_max_threads "threads,n_max_threads=($(prod((Nv_thread, Ni))),$n_max_threads)"
-    return (; threads = (Nv_thread, Ni), blocks = (Nv_blocks, Nh))
+    return (; threads = (Nv_thread, Ni), blocks = (Nh, Nv_blocks))
 end
 @inline function universal_index(::Union{DataLayouts.VIFH, DataLayouts.VIHF})
     (tv, i) = CUDA.threadIdx()
-    (bv, h) = CUDA.blockIdx()
+    (h, bv) = CUDA.blockIdx()
     v = tv + (bv - 1) * CUDA.blockDim().x
     return CartesianIndex((i, 1, 1, v, h))
 end