diff --git a/src/config.jl b/src/config.jl
index 29c26f0c..fff1f7ad 100644
--- a/src/config.jl
+++ b/src/config.jl
@@ -30,7 +30,7 @@
     #= Operator =#
     operator                    # which operator to use in the inner loop
 
-    #= Is A & B stored in Column major order? This determines the iteration order of the parallellisation =#
+    #= Are A & B stored in column-major order? This determines the iteration order of the parallelisation =#
     is_a_col_major
     is_b_col_major
 end
diff --git a/src/epilogue.jl b/src/epilogue.jl
index 1ede06cd..770085fa 100644
--- a/src/epilogue.jl
+++ b/src/epilogue.jl
@@ -24,8 +24,8 @@ struct Default end
     block_tile = Tile(conf.block_shape)
 
    # Cooperatively store a block_shape.M x block_shape.N tile of D from shared to global memory within one threadblock
-    @loopinfo unroll for warp_tile = parallellise(block_tile.MN, Tile(conf.mem_cd_warp), warpId, conf.warps_per_block)
-        @loopinfo unroll for thread_tile = parallellise(warp_tile, Tile(conf.mem_cd_thread), laneId, 32)
+    @loopinfo unroll for warp_tile = parallelise(block_tile.MN, Tile(conf.mem_cd_warp), warpId, conf.warps_per_block)
+        @loopinfo unroll for thread_tile = parallelise(warp_tile, Tile(conf.mem_cd_thread), laneId, 32)
             x = @inbounds Layout.load(conf.shared_d_layout, shmem_d, thread_tile)
             x = transform(x, thread_tile)
             @inbounds Layout.store!(conf.global_d_layout, d, x, translate_base(thread_tile, (M = block_i, N = block_j)))
@@ -63,8 +63,8 @@ end
     block_tile = Tile(conf.block_shape)
 
    # Cooperatively store a block_shape.M x block_shape.N tile of D from shared to global memory within one threadblock
-    @loopinfo unroll for warp_tile = parallellise(block_tile.MN, Tile(conf.mem_cd_warp), warpId, conf.warps_per_block)
-        @loopinfo unroll for thread_tile = parallellise(warp_tile, Tile(conf.mem_cd_thread), laneId, 32)
+    @loopinfo unroll for warp_tile = parallelise(block_tile.MN, Tile(conf.mem_cd_warp), warpId, conf.warps_per_block)
+        @loopinfo unroll for thread_tile = parallelise(warp_tile, Tile(conf.mem_cd_thread), laneId, 32)
             x = @inbounds Layout.load(conf.shared_d_layout, shmem_d, thread_tile)
             x = apply_bias(x, ep.bias_pointer, translate_base(thread_tile, (M = block_i, N = block_j)))
             x = transform(x, thread_tile)
diff --git a/src/kernel.jl b/src/kernel.jl
index ef6f9266..0aabfe4b 100644
--- a/src/kernel.jl
+++ b/src/kernel.jl
@@ -28,8 +28,8 @@ function matmul_singlestage(conf::GemmKernels.Config, a, b, c, d,
     # (1) Cooperatively load a block_shape.M x block_shape.N tile of C from global to shared memory within one threadblock
     shmem_c = @inbounds CuDynamicSharedArray(Layout.eltype(conf.shared_c_layout), Layout.physical_size(conf.shared_c_layout, block_tile.MN.size))
 
-    @loopinfo unroll for warp_tile = parallellise(block_tile.MN, Tile(conf.mem_cd_warp), warpId, conf.warps_per_block)
-        @loopinfo unroll for thread_tile = parallellise(warp_tile, Tile(conf.mem_cd_thread), laneId, 32)
+    @loopinfo unroll for warp_tile = parallelise(block_tile.MN, Tile(conf.mem_cd_warp), warpId, conf.warps_per_block)
+        @loopinfo unroll for thread_tile = parallelise(warp_tile, Tile(conf.mem_cd_thread), laneId, 32)
             x = @inbounds Layout.load(conf.global_c_layout, c, translate_base(thread_tile, (M = block_i, N = block_j)))
             x = transf_gl2sh_c(x, thread_tile)
             @inbounds Layout.store!(conf.shared_c_layout, shmem_c, x, thread_tile)
@@ -60,8 +60,8 @@ function matmul_singlestage(conf::GemmKernels.Config, a, b, c, d,
     @loopinfo unroll for block_k = 0 : block_tile.size.K : gemm_sz.size.K - 1
         if Layout.threadblock_condition(conf.global_a_layout, conf.global_b_layout, block_i, block_j, block_k, block_tile)
             # (3.1) Cooperatively load a block_shape.M x block_shape.K tile of A from global to shared memory within one threadblock
-            @loopinfo unroll for warp_tile = parallellise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major)
-                @loopinfo unroll for thread_tile = parallellise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major)
+            @loopinfo unroll for warp_tile = parallelise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major)
+                @loopinfo unroll for thread_tile = parallelise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major)
                     x = @inbounds Layout.load(conf.global_a_layout, a, translate_base(thread_tile, (M = block_i, K = block_k)))
                     x = transf_gl2sh_a(x, thread_tile)
                     @inbounds Layout.store!(conf.shared_a_layout, shmem_a, x, thread_tile)
@@ -69,8 +69,8 @@ function matmul_singlestage(conf::GemmKernels.Config, a, b, c, d,
             end
 
             # (3.2) Cooperatively load a block_shape.K x block_shape.N tile of B from global to shared memory within one threadblock
-            @loopinfo unroll for warp_tile = parallellise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major)
-                @loopinfo unroll for thread_tile = parallellise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major)
+            @loopinfo unroll for warp_tile = parallelise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major)
+                @loopinfo unroll for thread_tile = parallelise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major)
                     x = @inbounds Layout.load(conf.global_b_layout, b, translate_base(thread_tile, (K = block_k, N = block_j)))
                     x = transf_gl2sh_b(x, thread_tile)
                     @inbounds Layout.store!(conf.shared_b_layout, shmem_b, x, thread_tile)
@@ -80,7 +80,7 @@ function matmul_singlestage(conf::GemmKernels.Config, a, b, c, d,
         sync_threads()
 
         # (3.3) Calculate a compute_warp.M x compute_warp.N tile of D, using a compute_warp.M x compute_warp.N x compute_warp.K operation
-        @loopinfo unroll for warp_tile = parallellise(block_tile, Tile(conf.compute_warp), warpId, conf.warps_per_block)
+        @loopinfo unroll for warp_tile = parallelise(block_tile, Tile(conf.compute_warp), warpId, conf.warps_per_block)
 
             # (3.3.1) Load a compute_warp.M x compute_warp.K tile of A from shared memory into registers
             a_frags = LocalArray{Tuple{num_fragments_m}, Operator.fragtype_a(conf.operator, conf.shared_a_layout)}(undef)
@@ -166,8 +166,8 @@ function matmul_pipelined(conf::GemmKernels.Config, a, b, c, d,
     # (1) Cooperatively load a block_shape.M x block_shape.N tile of C from global to shared memory within one threadblock
     shmem_c = @inbounds CuDynamicSharedArray(Layout.eltype(conf.shared_c_layout), Layout.physical_size(conf.shared_c_layout, block_tile.MN.size))
 
-    @loopinfo unroll for warp_tile = parallellise(block_tile.MN, Tile(conf.mem_cd_warp), warpId, conf.warps_per_block)
-        @loopinfo unroll for thread_tile = parallellise(warp_tile, Tile(conf.mem_cd_thread), laneId, 32)
+    @loopinfo unroll for warp_tile = parallelise(block_tile.MN, Tile(conf.mem_cd_warp), warpId, conf.warps_per_block)
+        @loopinfo unroll for thread_tile = parallelise(warp_tile, Tile(conf.mem_cd_thread), laneId, 32)
             x = @inbounds Layout.load(conf.global_c_layout, c, translate_base(thread_tile, (M = block_i, N = block_j)))
             x = transf_gl2sh_c(x, thread_tile)
             @inbounds Layout.store!(conf.shared_c_layout, shmem_c, x, thread_tile)
@@ -210,28 +210,28 @@ function matmul_pipelined(conf::GemmKernels.Config, a, b, c, d,
     warp_tile_mn = subdivide(block_tile, Tile(conf.compute_warp), warpId, conf.warps_per_block)
 
     # ld.global(0 : block_shape.K)
-    @loopinfo unroll for (i, warp_tile) = enumerate(parallellise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major))
-        @loopinfo unroll for (j, thread_tile) = enumerate(parallellise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major))
+    @loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major))
+        @loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major))
             @inbounds @immutable a_fragment[i,j] = Layout.load(conf.global_a_layout, a, translate_base(thread_tile, (M = block_i, K = 0)))
         end
     end
 
-    @loopinfo unroll for (i, warp_tile) = enumerate(parallellise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major))
-        @loopinfo unroll for (j, thread_tile) = enumerate(parallellise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major))
+    @loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major))
+        @loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major))
             @inbounds @immutable b_fragment[i,j] = Layout.load(conf.global_b_layout, b, translate_base(thread_tile, (K = 0, N = block_j)))
         end
     end
 
     # st.shared()
-    @loopinfo unroll for (i, warp_tile) = enumerate(parallellise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major))
-        @loopinfo unroll for (j, thread_tile) = enumerate(parallellise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major))
+    @loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major))
+        @loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major))
             x = transf_gl2sh_a(@inbounds(a_fragment[i, j]), thread_tile)
             @inbounds Layout.store!(conf.shared_a_layout, shmem_a, x, thread_tile)
         end
     end
 
-    @loopinfo unroll for (i, warp_tile) = enumerate(parallellise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major))
-        @loopinfo unroll for (j, thread_tile) = enumerate(parallellise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major))
+    @loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major))
+        @loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major))
             x = transf_gl2sh_b(@inbounds(b_fragment[i, j]), thread_tile)
             @inbounds Layout.store!(conf.shared_b_layout, shmem_b, x, thread_tile)
         end
@@ -253,14 +253,14 @@ function matmul_pipelined(conf::GemmKernels.Config, a, b, c, d,
     end
 
     # ld.global(block_shape.K : 2 * block_shape.K)
-    @loopinfo unroll for (i, warp_tile) = enumerate(parallellise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major))
-        @loopinfo unroll for (j, thread_tile) = enumerate(parallellise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major))
+    @loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major))
+        @loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major))
             @inbounds @immutable a_fragment[i, j] = Layout.load(conf.global_a_layout, a, translate_base(thread_tile, (M = block_i, K = block_tile.size.K)))
         end
     end
 
-    @loopinfo unroll for (i, warp_tile) = enumerate(parallellise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major))
-        @loopinfo unroll for (j, thread_tile) = enumerate(parallellise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major))
+    @loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major))
+        @loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major))
             @inbounds @immutable b_fragment[i, j] = Layout.load(conf.global_b_layout, b, translate_base(thread_tile, (K = block_tile.size.K, N = block_j)))
         end
     end
@@ -274,15 +274,15 @@ function matmul_pipelined(conf::GemmKernels.Config, a, b, c, d,
         sync_threads()
 
         # st.shared()
-        @loopinfo unroll for (i, warp_tile) = enumerate(parallellise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major))
-            @loopinfo unroll for (j, thread_tile) = enumerate(parallellise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major))
+        @loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major))
+            @loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major))
                 x = transf_gl2sh_a(@inbounds(a_fragment[i, j]), thread_tile)
                 @inbounds Layout.store!(conf.shared_a_layout, shmem_a, x, thread_tile)
             end
         end
 
-        @loopinfo unroll for (i, warp_tile) = enumerate(parallellise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major))
-            @loopinfo unroll for (j, thread_tile) = enumerate(parallellise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major))
+        @loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major))
+            @loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major))
                 x = transf_gl2sh_b(@inbounds(b_fragment[i, j]), thread_tile)
                 @inbounds Layout.store!(conf.shared_b_layout, shmem_b, x, thread_tile)
             end
@@ -293,14 +293,14 @@ function matmul_pipelined(conf::GemmKernels.Config, a, b, c, d,
         # avoid out of bounds access for global memory
        if block_k < (gemm_sz.size.K - 2 * block_tile.size.K)
             # ld.global(block_k + 2 * block_shape.K : block_k + 3 * block_shape.K)
-            @loopinfo unroll for (i, warp_tile) = enumerate(parallellise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major))
-                @loopinfo unroll for (j, thread_tile) = enumerate(parallellise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major))
+            @loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major))
+                @loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major))
                     @inbounds @immutable a_fragment[i, j] = Layout.load(conf.global_a_layout, a, translate_base(thread_tile, (M = block_i, K = block_k + 2 * block_tile.size.K)))
                 end
             end
 
-            @loopinfo unroll for (i, warp_tile) = enumerate(parallellise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major))
-                @loopinfo unroll for (j, thread_tile) = enumerate(parallellise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major))
+            @loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major))
+                @loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major))
                     @inbounds @immutable b_fragment[i, j] = Layout.load(conf.global_b_layout, b, translate_base(thread_tile, (K = block_k + 2 * block_tile.size.K, N = block_j)))
                 end
             end
diff --git a/src/tiling.jl b/src/tiling.jl
index 3bc679a5..f47f1ee2 100644
--- a/src/tiling.jl
+++ b/src/tiling.jl
@@ -188,7 +188,7 @@ export TileIterator
 
 A [`TileIterator`](@ref) represents an iterator over a set of [`Tile`](@ref)s.
 
-See also: [`subdivide`](@ref), [`parallellise`](@ref).
+See also: [`subdivide`](@ref), [`parallelise`](@ref).
 """
 struct TileIterator{tile_size, parent_size, names, T, S, idxs, col_major}
     parent::Tile{parent_size, names, T}
@@ -197,31 +197,31 @@ struct TileIterator{tile_size, parent_size, names, T, S, idxs, col_major}
     parent::Tile{parent_size, names, T}
 end
 
 # ----------------
-# Parallellisation
+# Parallelisation
 # ----------------
 
-export parallellise, subdivide
+export parallelise, subdivide
 
 """
-    parallellise(tile, tiling_size, idx, size)
+    parallelise(tile, tiling_size, idx, count)
 
 Split the given `tile` in subtiles of size `tiling_size` across a group of
 cooperating entities (e.g. warps, threads, ...).
 
 Unlike [`subdivide`](@ref), the `tile` need not be completely covered by
 `count` tiles of size `tiling_size`. If that's not the case, the subtiles
-are evenly parallellised across all cooperating entities.
+are evenly parallelised across all cooperating entities.
 
 Returns a [`TileIterator`](@ref) that iterates over the [`Tile`](@ref)s of
 the calling entity.
 
 # Arguments
-- `tile`: The [`Tile`](@ref) to parallellise.
+- `tile`: The [`Tile`](@ref) to parallelise.
 - `tiling_size`: A `NamedTuple` indicating the size of a subtile along each dimension.
 - `idx`: The identity of the calling entity.
 - `count`: The number of cooperating entities.
 """
-@inline function parallellise(tile::Tile{size, names, T}, tiling_size::Tile{tile_sz, names, T}, idx, idxs, col_major::Bool=true) where {names, T, size, tile_sz}
+@inline function parallelise(tile::Tile{size, names, T}, tiling_size::Tile{tile_sz, names, T}, idx, idxs, col_major::Bool=true) where {names, T, size, tile_sz}
     # Transpose
     tile = col_major ? tile : transpose(tile)
     tiling_size = col_major ? tiling_size : transpose(tiling_size)
@@ -253,7 +253,7 @@ Returns the [`Tile`](@ref) that the calling entity is responsible for.
 - `count`: The number of cooperating entities.
 """
 @inline function subdivide(tile::Tile{size, names, T}, tiling_size::Tile{tile_sz, names, T}, idx, count) where {names, T, size, tile_sz}
-    iter = iterate(parallellise(tile, tiling_size, idx, count))::Tuple{Tile,Any}
+    iter = iterate(parallelise(tile, tiling_size, idx, count))::Tuple{Tile,Any}
     @boundscheck begin
         iter === nothing && throw(BoundsError())
     end
diff --git a/test/tiling.jl b/test/tiling.jl
index 0b101f6c..c48b96e2 100644
--- a/test/tiling.jl
+++ b/test/tiling.jl
@@ -59,13 +59,13 @@ using GemmKernels.Tiling
         end
     end
 
-    @testcase "Parallellise" begin
+    @testcase "Parallelise" begin
         tile_size = (M = 8, N = 4)
         num_tiles = (M = 2, N = 8)
         tile = Tile(M = num_tiles.M * tile_size.M, N = num_tiles.N * tile_size.N)
 
         for i = 1 : (num_tiles.M * num_tiles.N) ÷ 2
-            t1, t2 = parallellise(tile, Tile(tile_size), i, (num_tiles.M * num_tiles.N) ÷ 2)
+            t1, t2 = parallelise(tile, Tile(tile_size), i, (num_tiles.M * num_tiles.N) ÷ 2)
             @test t1.offset == (M = 0, N = 0)
             @test t2.offset == (M = 0, N = 4 * tile_size.N)