Skip to content

Commit

Permalink
Add support for CTA swizzling (#195)
Browse files Browse the repository at this point in the history
Apply a swizzling function to the mapping between tiles of the D output
matrix and the CTA ID. The goal is to maximise the probability that CTAs
that access the same tile of A/B are scheduled on neighbouring SMs at
the same time, thereby increasing L2 hit rate.
  • Loading branch information
thomasfaingnaert authored May 6, 2024
1 parent 4339926 commit d3be41c
Show file tree
Hide file tree
Showing 6 changed files with 185 additions and 11 deletions.
1 change: 1 addition & 0 deletions src/GemmKernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ include("utils.jl")
# framework
include("layout.jl")
include("operator.jl")
include("cta-swizzle.jl")
include("config.jl")
include("epilogue.jl")
include("kernel.jl")
Expand Down
13 changes: 13 additions & 0 deletions src/config.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using GemmKernels: CTASwizzle

@staticdef struct Config
#= Params =#
matmul_shape # MNK, overall shape of the MATMUL operation
Expand Down Expand Up @@ -33,6 +35,9 @@
#= Is A & B stored in Column major order? This determines the iteration order of the parallelisation =#
is_a_col_major
is_b_col_major

#= CTA Swizzling function =#
cta_swizzle
end

function Base.show(io::IO, config::Config)
Expand Down Expand Up @@ -66,6 +71,8 @@ function Base.show(io::IO, config::Config)

println(io, "is_a_col_major: $(config.is_a_col_major)")
println(io, "is_b_col_major: $(config.is_b_col_major)")

println(io, "cta_swizzle: $(config.cta_swizzle)")
end

struct ConfigError <: Exception
Expand Down Expand Up @@ -267,6 +274,9 @@ function get_config(; gemm_shape, operator, global_a_layout, global_c_layout, kw
require_tile_sized_global(global_c_layout) && check_tile_multiple(gemm_shape, block_shape, [:M, :N], "gemm_shape.MN must be a multiple of block_shape.MN!")
require_tile_sized_global(global_d_layout) && check_tile_multiple(gemm_shape, block_shape, [:M, :N], "gemm_shape.MN must be a multiple of block_shape.MN!")

# CTA swizzling function.
cta_swizzle = CTASwizzle.HorizontallyTiled{8}

return Config(
#= Params =#
gemm_shape,
Expand Down Expand Up @@ -298,5 +308,8 @@ function get_config(; gemm_shape, operator, global_a_layout, global_c_layout, kw
#= Is A & B Col Major? =#
is_a_col_major,
is_b_col_major,

#= CTA Swizzle function =#
cta_swizzle,
)
end
162 changes: 162 additions & 0 deletions src/cta-swizzle.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
export CTASwizzle
module CTASwizzle

# CTA Swizzling functions to improve L2 hit rate.

# Debugging
function visualise_cta_swizzle(swizzle_type, block_shape = (M = 128, N = 128), matmul_shape = (M = 1024, N = 1024))
res = fill("", (cld(matmul_shape.M, block_shape.M),
cld(matmul_shape.N, block_shape.N)))

BX, BY = number_of_blocks(swizzle_type, block_shape, matmul_shape)

for bx = 1:BX, by = 1:BY
block_i, block_j = cta_swizzle(swizzle_type, (x = bx, y = by), block_shape)
res[block_i ÷ block_shape.M + 1, block_j ÷ block_shape.N + 1] = "($bx, $by)"
end

res
end

# ----------------------------
# Identity swizzling operation
# ----------------------------
export Identity

struct Identity end

@inline function number_of_blocks(::Type{Identity}, block_shape, matmul_shape)
(cld(matmul_shape.M, block_shape.M),
cld(matmul_shape.N, block_shape.N))
end

@inline function cta_swizzle(::Type{Identity}, blockIdx, block_shape)
bx = blockIdx.x - 1
by = blockIdx.y - 1

block_i = bx
block_j = by

block_i * block_shape.M, block_j * block_shape.N
end

# ----------------------------
# Horizontally tiled swizzling
# ----------------------------

# Example: TileSize = 4:
#
# "(1, 1)" "(2, 1)" "(3, 1)" "(4, 1)" || "(1, 2)" "(2, 2)" "(3, 2)" "(4, 2)"
# =======================================================================================
# "(5, 1)" "(6, 1)" "(7, 1)" "(8, 1)" || "(5, 2)" "(6, 2)" "(7, 2)" "(8, 2)"
# =======================================================================================
# "(9, 1)" "(10, 1)" "(11, 1)" "(12, 1)" || "(9, 2)" "(10, 2)" "(11, 2)" "(12, 2)"
# =======================================================================================
# "(13, 1)" "(14, 1)" "(15, 1)" "(16, 1)" || "(13, 2)" "(14, 2)" "(15, 2)" "(16, 2)"
# =======================================================================================
# "(17, 1)" "(18, 1)" "(19, 1)" "(20, 1)" || "(17, 2)" "(18, 2)" "(19, 2)" "(20, 2)"
# =======================================================================================
# "(21, 1)" "(22, 1)" "(23, 1)" "(24, 1)" || "(21, 2)" "(22, 2)" "(23, 2)" "(24, 2)"
# =======================================================================================
# "(25, 1)" "(26, 1)" "(27, 1)" "(28, 1)" || "(25, 2)" "(26, 2)" "(27, 2)" "(28, 2)"
# =======================================================================================
# "(29, 1)" "(30, 1)" "(31, 1)" "(32, 1)" || "(29, 2)" "(30, 2)" "(31, 2)" "(32, 2)"

export HorizontallyTiled

struct HorizontallyTiled{TileSize} end

@inline function number_of_blocks(::Type{HorizontallyTiled{TileSize}}, block_shape, matmul_shape) where {TileSize}
(cld(matmul_shape.M * TileSize, block_shape.M),
cld(matmul_shape.N, block_shape.N * TileSize))
end

@inline function cta_swizzle(::Type{HorizontallyTiled{TileSize}}, blockIdx, block_shape) where {TileSize}
bx = blockIdx.x - 1
by = blockIdx.y - 1

block_i = bx ÷ TileSize
block_j = (by * TileSize) + (bx % TileSize)

block_i * block_shape.M, block_j * block_shape.N
end

# --------------------------
# Vertically tiled swizzling
# --------------------------

# Example: TileSize = 4:
#
# "(1, 1)" || "(1, 5)" || "(1, 9)" || "(1, 13)" || "(1, 17)" || "(1, 21)" || "(1, 25)" || "(1, 29)"
# "(1, 2)" || "(1, 6)" || "(1, 10)" || "(1, 14)" || "(1, 18)" || "(1, 22)" || "(1, 26)" || "(1, 30)"
# "(1, 3)" || "(1, 7)" || "(1, 11)" || "(1, 15)" || "(1, 19)" || "(1, 23)" || "(1, 27)" || "(1, 31)"
# "(1, 4)" || "(1, 8)" || "(1, 12)" || "(1, 16)" || "(1, 20)" || "(1, 24)" || "(1, 28)" || "(1, 32)"
# ====================================================================================================
# "(2, 1)" || "(2, 5)" || "(2, 9)" || "(2, 13)" || "(2, 17)" || "(2, 21)" || "(2, 25)" || "(2, 29)"
# "(2, 2)" || "(2, 6)" || "(2, 10)" || "(2, 14)" || "(2, 18)" || "(2, 22)" || "(2, 26)" || "(2, 30)"
# "(2, 3)" || "(2, 7)" || "(2, 11)" || "(2, 15)" || "(2, 19)" || "(2, 23)" || "(2, 27)" || "(2, 31)"
# "(2, 4)" || "(2, 8)" || "(2, 12)" || "(2, 16)" || "(2, 20)" || "(2, 24)" || "(2, 28)" || "(2, 32)"

export VerticallyTiled

struct VerticallyTiled{TileSize} end

@inline function number_of_blocks(::Type{VerticallyTiled{TileSize}}, block_shape, matmul_shape) where {TileSize}
(cld(matmul_shape.M, block_shape.M * TileSize),
cld(matmul_shape.N * TileSize, block_shape.N))
end

@inline function cta_swizzle(::Type{VerticallyTiled{TileSize}}, blockIdx, block_shape) where {TileSize}
bx = blockIdx.x - 1
by = blockIdx.y - 1

block_i = (bx * TileSize) + (by % TileSize)
block_j = by ÷ TileSize

block_i * block_shape.M, block_j * block_shape.N
end

# ----------------------------
# Lebesgue space-filling curve
# ----------------------------

# Example:
# "(1, 1)" "(3, 1)" "(9, 1)" "(11, 1)" "(33, 1)" "(35, 1)" "(41, 1)" "(43, 1)"
# "(2, 1)" "(4, 1)" "(10, 1)" "(12, 1)" "(34, 1)" "(36, 1)" "(42, 1)" "(44, 1)"
# "(5, 1)" "(7, 1)" "(13, 1)" "(15, 1)" "(37, 1)" "(39, 1)" "(45, 1)" "(47, 1)"
# "(6, 1)" "(8, 1)" "(14, 1)" "(16, 1)" "(38, 1)" "(40, 1)" "(46, 1)" "(48, 1)"
# "(17, 1)" "(19, 1)" "(25, 1)" "(27, 1)" "(49, 1)" "(51, 1)" "(57, 1)" "(59, 1)"
# "(18, 1)" "(20, 1)" "(26, 1)" "(28, 1)" "(50, 1)" "(52, 1)" "(58, 1)" "(60, 1)"
# "(21, 1)" "(23, 1)" "(29, 1)" "(31, 1)" "(53, 1)" "(55, 1)" "(61, 1)" "(63, 1)"
# "(22, 1)" "(24, 1)" "(30, 1)" "(32, 1)" "(54, 1)" "(56, 1)" "(62, 1)" "(64, 1)"

@inline function extract_even_bits(x)
x = x & 0x55555555
x = (x | (x >> 1)) & 0x33333333
x = (x | (x >> 2)) & 0x0F0F0F0F
x = (x | (x >> 4)) & 0x00FF00FF
x = (x | (x >> 8)) & 0x0000FFFF

x
end

@inline extract_odd_bits(x) = extract_even_bits(x >> 1)

export LebesgueCurve

struct LebesgueCurve end

@inline function number_of_blocks(::Type{LebesgueCurve}, block_shape, matmul_shape)
(cld(matmul_shape.M, block_shape.M) * cld(matmul_shape.N, block_shape.N), 1)
end

@inline function cta_swizzle(::Type{LebesgueCurve}, blockIdx, block_shape)
bx = blockIdx.x - 1

block_i = extract_even_bits(bx)
block_j = extract_odd_bits(bx)

block_i * block_shape.M, block_j * block_shape.N
end

end
7 changes: 3 additions & 4 deletions src/epilogue.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module Epilogue

using CUDA
using GemmKernels
using GemmKernels: CTASwizzle
using GemmKernels.Tiling
using LLVMLoopInfo: @loopinfo

Expand All @@ -14,8 +15,7 @@ struct Default end

@inline function (ep::Default)(conf::GemmKernels.Config, d, shmem_d, transform)
# Constants
block_i = (blockIdx().x - 1) * conf.block_shape.M
block_j = (blockIdx().y - 1) * conf.block_shape.N
block_i, block_j = CTASwizzle.cta_swizzle(conf.cta_swizzle, blockIdx(), conf.block_shape)

warpId = (threadIdx().x - 1) ÷ 32 + 1
laneId = (threadIdx().x - 1) % 32 + 1
Expand Down Expand Up @@ -53,8 +53,7 @@ end

@inline function (ep::Bias{B})(conf::GemmKernels.Config, d, shmem_d, transform) where {B}
# Constants
block_i = (blockIdx().x - 1) * conf.block_shape.M
block_j = (blockIdx().y - 1) * conf.block_shape.N
block_i, block_j = CTASwizzle.cta_swizzle(conf.cta_swizzle, blockIdx(), conf.block_shape)

warpId = (threadIdx().x - 1) ÷ 32 + 1
laneId = (threadIdx().x - 1) % 32 + 1
Expand Down
8 changes: 3 additions & 5 deletions src/kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ module Kernel
using CUDA
using GemmKernels
using GemmKernels.Tiling
using GemmKernels: LocalArray, @immutable
using GemmKernels: LocalArray, @immutable, CTASwizzle
using LLVMLoopInfo: @loopinfo

function matmul_singlestage(conf::GemmKernels.Config, a, b, c, d,
Expand All @@ -16,8 +16,7 @@ function matmul_singlestage(conf::GemmKernels.Config, a, b, c, d,
num_fragments_n = conf.compute_warp.N ÷ conf.compute_op_shape.N

# Constants
block_i = (blockIdx().x - 1) * conf.block_shape.M
block_j = (blockIdx().y - 1) * conf.block_shape.N
block_i, block_j = CTASwizzle.cta_swizzle(conf.cta_swizzle, blockIdx(), conf.block_shape)

warpId = (threadIdx().x - 1) ÷ 32 + 1
laneId = (threadIdx().x - 1) % 32 + 1
Expand Down Expand Up @@ -154,8 +153,7 @@ function matmul_pipelined(conf::GemmKernels.Config, a, b, c, d,
num_fragments_n = conf.compute_warp.N ÷ conf.compute_op_shape.N

# Constants
block_i = (blockIdx().x - 1) * conf.block_shape.M
block_j = (blockIdx().y - 1) * conf.block_shape.N
block_i, block_j = CTASwizzle.cta_swizzle(conf.cta_swizzle, blockIdx(), conf.block_shape)

warpId = (threadIdx().x - 1) ÷ 32 + 1
laneId = (threadIdx().x - 1) % 32 + 1
Expand Down
5 changes: 3 additions & 2 deletions src/matmul.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using GemmKernels: CTASwizzle

#
# low-level
#
Expand All @@ -20,8 +22,7 @@ function matmul(conf::Config, a, b, c, d;
epilogue]

threads = conf.warps_per_block * 32
blocks = (cld(conf.matmul_shape.M, conf.block_shape.M),
cld(conf.matmul_shape.N, conf.block_shape.N))
blocks = CTASwizzle.number_of_blocks(conf.cta_swizzle, conf.block_shape, conf.matmul_shape)

shmem = Kernel.shmem_size(conf, kernel)
max_shmem = attribute(device(), CUDA.DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN)
Expand Down

0 comments on commit d3be41c

Please sign in to comment.