Skip to content

Commit

Permalink
Fix typo in parallelise function name (#178)
Browse files Browse the repository at this point in the history
Fixes #177
  • Loading branch information
thomasfaingnaert authored Nov 21, 2023
1 parent 6a8e8cb commit 8207701
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 44 deletions.
2 changes: 1 addition & 1 deletion src/config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
#= Operator =#
operator # which operator to use in the inner loop

#= Is A & B stored in Column major order? This determines the iteration order of the parallellisation =#
#= Is A & B stored in Column major order? This determines the iteration order of the parallelisation =#
is_a_col_major
is_b_col_major
end
Expand Down
8 changes: 4 additions & 4 deletions src/epilogue.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ struct Default end
block_tile = Tile(conf.block_shape)

# Cooperatively store a block_shape.M x block_shape.N tile of D from shared to global memory within one threadblock
@loopinfo unroll for warp_tile = parallellise(block_tile.MN, Tile(conf.mem_cd_warp), warpId, conf.warps_per_block)
@loopinfo unroll for thread_tile = parallellise(warp_tile, Tile(conf.mem_cd_thread), laneId, 32)
@loopinfo unroll for warp_tile = parallelise(block_tile.MN, Tile(conf.mem_cd_warp), warpId, conf.warps_per_block)
@loopinfo unroll for thread_tile = parallelise(warp_tile, Tile(conf.mem_cd_thread), laneId, 32)
x = @inbounds Layout.load(conf.shared_d_layout, shmem_d, thread_tile)
x = transform(x, thread_tile)
@inbounds Layout.store!(conf.global_d_layout, d, x, translate_base(thread_tile, (M = block_i, N = block_j)))
Expand Down Expand Up @@ -63,8 +63,8 @@ end
block_tile = Tile(conf.block_shape)

# Cooperatively store a block_shape.M x block_shape.N tile of D from shared to global memory within one threadblock
@loopinfo unroll for warp_tile = parallellise(block_tile.MN, Tile(conf.mem_cd_warp), warpId, conf.warps_per_block)
@loopinfo unroll for thread_tile = parallellise(warp_tile, Tile(conf.mem_cd_thread), laneId, 32)
@loopinfo unroll for warp_tile = parallelise(block_tile.MN, Tile(conf.mem_cd_warp), warpId, conf.warps_per_block)
@loopinfo unroll for thread_tile = parallelise(warp_tile, Tile(conf.mem_cd_thread), laneId, 32)
x = @inbounds Layout.load(conf.shared_d_layout, shmem_d, thread_tile)
x = apply_bias(x, ep.bias_pointer, translate_base(thread_tile, (M = block_i, N = block_j)))
x = transform(x, thread_tile)
Expand Down
58 changes: 29 additions & 29 deletions src/kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ function matmul_singlestage(conf::GemmKernels.Config, a, b, c, d,
# (1) Cooperatively load a block_shape.M x block_shape.N tile of C from global to shared memory within one threadblock
shmem_c = @inbounds CuDynamicSharedArray(Layout.eltype(conf.shared_c_layout), Layout.physical_size(conf.shared_c_layout, block_tile.MN.size))

@loopinfo unroll for warp_tile = parallellise(block_tile.MN, Tile(conf.mem_cd_warp), warpId, conf.warps_per_block)
@loopinfo unroll for thread_tile = parallellise(warp_tile, Tile(conf.mem_cd_thread), laneId, 32)
@loopinfo unroll for warp_tile = parallelise(block_tile.MN, Tile(conf.mem_cd_warp), warpId, conf.warps_per_block)
@loopinfo unroll for thread_tile = parallelise(warp_tile, Tile(conf.mem_cd_thread), laneId, 32)
x = @inbounds Layout.load(conf.global_c_layout, c, translate_base(thread_tile, (M = block_i, N = block_j)))
x = transf_gl2sh_c(x, thread_tile)
@inbounds Layout.store!(conf.shared_c_layout, shmem_c, x, thread_tile)
Expand Down Expand Up @@ -60,17 +60,17 @@ function matmul_singlestage(conf::GemmKernels.Config, a, b, c, d,
@loopinfo unroll for block_k = 0 : block_tile.size.K : gemm_sz.size.K - 1
if Layout.threadblock_condition(conf.global_a_layout, conf.global_b_layout, block_i, block_j, block_k, block_tile)
# (3.1) Cooperatively load a block_shape.M x block_shape.K tile of A from global to shared memory within one threadblock
@loopinfo unroll for warp_tile = parallellise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major)
@loopinfo unroll for thread_tile = parallellise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major)
@loopinfo unroll for warp_tile = parallelise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major)
@loopinfo unroll for thread_tile = parallelise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major)
x = @inbounds Layout.load(conf.global_a_layout, a, translate_base(thread_tile, (M = block_i, K = block_k)))
x = transf_gl2sh_a(x, thread_tile)
@inbounds Layout.store!(conf.shared_a_layout, shmem_a, x, thread_tile)
end
end

# (3.2) Cooperatively load a block_shape.K x block_shape.N tile of B from global to shared memory within one threadblock
@loopinfo unroll for warp_tile = parallellise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major)
@loopinfo unroll for thread_tile = parallellise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major)
@loopinfo unroll for warp_tile = parallelise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major)
@loopinfo unroll for thread_tile = parallelise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major)
x = @inbounds Layout.load(conf.global_b_layout, b, translate_base(thread_tile, (K = block_k, N = block_j)))
x = transf_gl2sh_b(x, thread_tile)
@inbounds Layout.store!(conf.shared_b_layout, shmem_b, x, thread_tile)
Expand All @@ -80,7 +80,7 @@ function matmul_singlestage(conf::GemmKernels.Config, a, b, c, d,
sync_threads()

# (3.3) Calculate a compute_warp.M x compute_warp.N tile of D, using a compute_warp.M x compute_warp.N x compute_warp.K operation
@loopinfo unroll for warp_tile = parallellise(block_tile, Tile(conf.compute_warp), warpId, conf.warps_per_block)
@loopinfo unroll for warp_tile = parallelise(block_tile, Tile(conf.compute_warp), warpId, conf.warps_per_block)
# (3.3.1) Load a compute_warp.M x compute_warp.K tile of A from shared memory into registers
a_frags = LocalArray{Tuple{num_fragments_m}, Operator.fragtype_a(conf.operator, conf.shared_a_layout)}(undef)

Expand Down Expand Up @@ -166,8 +166,8 @@ function matmul_pipelined(conf::GemmKernels.Config, a, b, c, d,
# (1) Cooperatively load a block_shape.M x block_shape.N tile of C from global to shared memory within one threadblock
shmem_c = @inbounds CuDynamicSharedArray(Layout.eltype(conf.shared_c_layout), Layout.physical_size(conf.shared_c_layout, block_tile.MN.size))

@loopinfo unroll for warp_tile = parallellise(block_tile.MN, Tile(conf.mem_cd_warp), warpId, conf.warps_per_block)
@loopinfo unroll for thread_tile = parallellise(warp_tile, Tile(conf.mem_cd_thread), laneId, 32)
@loopinfo unroll for warp_tile = parallelise(block_tile.MN, Tile(conf.mem_cd_warp), warpId, conf.warps_per_block)
@loopinfo unroll for thread_tile = parallelise(warp_tile, Tile(conf.mem_cd_thread), laneId, 32)
x = @inbounds Layout.load(conf.global_c_layout, c, translate_base(thread_tile, (M = block_i, N = block_j)))
x = transf_gl2sh_c(x, thread_tile)
@inbounds Layout.store!(conf.shared_c_layout, shmem_c, x, thread_tile)
Expand Down Expand Up @@ -210,28 +210,28 @@ function matmul_pipelined(conf::GemmKernels.Config, a, b, c, d,
warp_tile_mn = subdivide(block_tile, Tile(conf.compute_warp), warpId, conf.warps_per_block)

# ld.global(0 : block_shape.K)
@loopinfo unroll for (i, warp_tile) = enumerate(parallellise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major))
@loopinfo unroll for (j, thread_tile) = enumerate(parallellise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major))
@loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major))
@loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major))
@inbounds @immutable a_fragment[i,j] = Layout.load(conf.global_a_layout, a, translate_base(thread_tile, (M = block_i, K = 0)))
end
end

@loopinfo unroll for (i, warp_tile) = enumerate(parallellise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major))
@loopinfo unroll for (j, thread_tile) = enumerate(parallellise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major))
@loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major))
@loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major))
@inbounds @immutable b_fragment[i,j] = Layout.load(conf.global_b_layout, b, translate_base(thread_tile, (K = 0, N = block_j)))
end
end

# st.shared()
@loopinfo unroll for (i, warp_tile) = enumerate(parallellise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major))
@loopinfo unroll for (j, thread_tile) = enumerate(parallellise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major))
@loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major))
@loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major))
x = transf_gl2sh_a(@inbounds(a_fragment[i, j]), thread_tile)
@inbounds Layout.store!(conf.shared_a_layout, shmem_a, x, thread_tile)
end
end

@loopinfo unroll for (i, warp_tile) = enumerate(parallellise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major))
@loopinfo unroll for (j, thread_tile) = enumerate(parallellise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major))
@loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major))
@loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major))
x = transf_gl2sh_b(@inbounds(b_fragment[i, j]), thread_tile)
@inbounds Layout.store!(conf.shared_b_layout, shmem_b, x, thread_tile)
end
Expand All @@ -253,14 +253,14 @@ function matmul_pipelined(conf::GemmKernels.Config, a, b, c, d,
end

# ld.global(block_shape.K : 2 * block_shape.K)
@loopinfo unroll for (i, warp_tile) = enumerate(parallellise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major))
@loopinfo unroll for (j, thread_tile) = enumerate(parallellise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major))
@loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major))
@loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major))
@inbounds @immutable a_fragment[i, j] = Layout.load(conf.global_a_layout, a, translate_base(thread_tile, (M = block_i, K = block_tile.size.K)))
end
end

@loopinfo unroll for (i, warp_tile) = enumerate(parallellise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major))
@loopinfo unroll for (j, thread_tile) = enumerate(parallellise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major))
@loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major))
@loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major))
@inbounds @immutable b_fragment[i, j] = Layout.load(conf.global_b_layout, b, translate_base(thread_tile, (K = block_tile.size.K, N = block_j)))
end
end
Expand All @@ -274,15 +274,15 @@ function matmul_pipelined(conf::GemmKernels.Config, a, b, c, d,
sync_threads()

# st.shared()
@loopinfo unroll for (i, warp_tile) = enumerate(parallellise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major))
@loopinfo unroll for (j, thread_tile) = enumerate(parallellise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major))
@loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major))
@loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major))
x = transf_gl2sh_a(@inbounds(a_fragment[i, j]), thread_tile)
@inbounds Layout.store!(conf.shared_a_layout, shmem_a, x, thread_tile)
end
end

@loopinfo unroll for (i, warp_tile) = enumerate(parallellise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major))
@loopinfo unroll for (j, thread_tile) = enumerate(parallellise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major))
@loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major))
@loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major))
x = transf_gl2sh_b(@inbounds(b_fragment[i, j]), thread_tile)
@inbounds Layout.store!(conf.shared_b_layout, shmem_b, x, thread_tile)
end
Expand All @@ -293,14 +293,14 @@ function matmul_pipelined(conf::GemmKernels.Config, a, b, c, d,
# avoid out of bounds access for global memory
if block_k < (gemm_sz.size.K - 2 * block_tile.size.K)
# ld.global(block_k + 2 * block_shape.K : block_k + 3 * block_shape.K)
@loopinfo unroll for (i, warp_tile) = enumerate(parallellise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major))
@loopinfo unroll for (j, thread_tile) = enumerate(parallellise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major))
@loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.MK, Tile(conf.mem_a_warp), warpId, conf.warps_per_block, conf.is_a_col_major))
@loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_a_thread), laneId, 32, conf.is_a_col_major))
@inbounds @immutable a_fragment[i, j] = Layout.load(conf.global_a_layout, a, translate_base(thread_tile, (M = block_i, K = block_k + 2 * block_tile.size.K)))
end
end

@loopinfo unroll for (i, warp_tile) = enumerate(parallellise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major))
@loopinfo unroll for (j, thread_tile) = enumerate(parallellise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major))
@loopinfo unroll for (i, warp_tile) = enumerate(parallelise(block_tile.KN, Tile(conf.mem_b_warp), warpId, conf.warps_per_block, conf.is_b_col_major))
@loopinfo unroll for (j, thread_tile) = enumerate(parallelise(warp_tile, Tile(conf.mem_b_thread), laneId, 32, conf.is_b_col_major))
@inbounds @immutable b_fragment[i, j] = Layout.load(conf.global_b_layout, b, translate_base(thread_tile, (K = block_k + 2 * block_tile.size.K, N = block_j)))
end
end
Expand Down
16 changes: 8 additions & 8 deletions src/tiling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ export TileIterator
A [`TileIterator`](@ref) represents an iterator over a set of [`Tile`](@ref)s.
See also: [`subdivide`](@ref), [`parallellise`](@ref).
See also: [`subdivide`](@ref), [`parallelise`](@ref).
"""
struct TileIterator{tile_size, parent_size, names, T, S, idxs, col_major}
parent::Tile{parent_size, names, T}
Expand All @@ -197,31 +197,31 @@ struct TileIterator{tile_size, parent_size, names, T, S, idxs, col_major}
end

# ----------------
# Parallellisation
# Parallelisation
# ----------------

export parallellise, subdivide
export parallelise, subdivide

"""
parallellise(tile, tiling_size, idx, size)
parallelise(tile, tiling_size, idx, size)
Split the given `tile` in subtiles of size `tiling_size` across a group of
cooperating entities (e.g. warps, threads, ...).
Unlike [`subdivide`](@ref), the `tile` need not be completely covered by
`count` tiles of size `tiling_size`. If that's not the case, the subtiles
are evenly parallellised across all cooperating entities.
are evenly parallelised across all cooperating entities.
Returns a [`TileIterator`](@ref) that iterates over the [`Tile`](@ref)s of
the calling entity.
# Arguments
- `tile`: The [`Tile`](@ref) to parallellise.
- `tile`: The [`Tile`](@ref) to parallelise.
- `tiling_size`: A `NamedTuple` indicating the size of a subtile along each dimension.
- `idx`: The identity of the calling entity.
- `count`: The number of cooperating entities.
"""
@inline function parallellise(tile::Tile{size, names, T}, tiling_size::Tile{tile_sz, names, T}, idx, idxs, col_major::Bool=true) where {names, T, size, tile_sz}
@inline function parallelise(tile::Tile{size, names, T}, tiling_size::Tile{tile_sz, names, T}, idx, idxs, col_major::Bool=true) where {names, T, size, tile_sz}
# Transpose
tile = col_major ? tile : transpose(tile)
tiling_size = col_major ? tiling_size : transpose(tiling_size)
Expand Down Expand Up @@ -253,7 +253,7 @@ Returns the [`Tile`](@ref) that the calling entity is responsible for.
- `count`: The number of cooperating entities.
"""
@inline function subdivide(tile::Tile{size, names, T}, tiling_size::Tile{tile_sz, names, T}, idx, count) where {names, T, size, tile_sz}
iter = iterate(parallellise(tile, tiling_size, idx, count))::Tuple{Tile,Any}
iter = iterate(parallelise(tile, tiling_size, idx, count))::Tuple{Tile,Any}
@boundscheck begin
iter === nothing && throw(BoundsError())
end
Expand Down
4 changes: 2 additions & 2 deletions test/tiling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,13 @@ using GemmKernels.Tiling
end
end

@testcase "Parallellise" begin
@testcase "Parallelise" begin
tile_size = (M = 8, N = 4)
num_tiles = (M = 2, N = 8)
tile = Tile(M = num_tiles.M * tile_size.M, N = num_tiles.N * tile_size.N)

for i = 1 : (num_tiles.M * num_tiles.N) ÷ 2
t1, t2 = parallellise(tile, Tile(tile_size), i, (num_tiles.M * num_tiles.N) ÷ 2)
t1, t2 = parallelise(tile, Tile(tile_size), i, (num_tiles.M * num_tiles.N) ÷ 2)

@test t1.offset == (M = 0, N = 0)
@test t2.offset == (M = 0, N = 4 * tile_size.N)
Expand Down

0 comments on commit 8207701

Please sign in to comment.