Skip to content

Commit

Permalink
Swap CUDA grid dimensions for some partitions
Browse files Browse the repository at this point in the history
  • Loading branch information
sriharshakandala committed Dec 18, 2024
1 parent 6ad8294 commit d751a7b
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions ext/cuda/data_layouts_threadblock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@ function is_valid_index end
Nv_thread = min(Int(fld(n_max_threads, Nij * Nij)), Nv)
Nv_blocks = cld(Nv, Nv_thread)
@assert prod((Nv_thread, Nij, Nij)) n_max_threads "threads,n_max_threads=($(prod((Nv_thread, Nij, Nij))),$n_max_threads)"
return (; threads = (Nv_thread, Nij, Nij), blocks = (Nv_blocks, Nh))
return (; threads = (Nv_thread, Nij, Nij), blocks = (Nh, Nv_blocks))
end
@inline function universal_index(::Union{DataLayouts.VIJFH, DataLayouts.VIJHF})
(tv, i, j) = CUDA.threadIdx()
(bv, h) = CUDA.blockIdx()
(h, bv) = CUDA.blockIdx()
v = tv + (bv - 1) * CUDA.blockDim().x
return CartesianIndex((i, j, 1, v, h))
end
Expand Down Expand Up @@ -152,11 +152,11 @@ end
Nv_thread = min(Int(fld(n_max_threads, Ni)), Nv)
Nv_blocks = cld(Nv, Nv_thread)
@assert prod((Nv_thread, Ni)) n_max_threads "threads,n_max_threads=($(prod((Nv_thread, Ni))),$n_max_threads)"
return (; threads = (Nv_thread, Ni), blocks = (Nv_blocks, Nh))
return (; threads = (Nv_thread, Ni), blocks = (Nh, Nv_blocks))
end
@inline function universal_index(::Union{DataLayouts.VIFH, DataLayouts.VIHF})
(tv, i) = CUDA.threadIdx()
(bv, h) = CUDA.blockIdx()
(h, bv) = CUDA.blockIdx()
v = tv + (bv - 1) * CUDA.blockDim().x
return CartesianIndex((i, 1, 1, v, h))
end
Expand Down

0 comments on commit d751a7b

Please sign in to comment.