diff --git a/src/GemmKernels.jl b/src/GemmKernels.jl index a6073d62..abea36cb 100644 --- a/src/GemmKernels.jl +++ b/src/GemmKernels.jl @@ -11,6 +11,7 @@ include("utils.jl") # framework include("layout.jl") include("operator.jl") +include("cta-swizzle.jl") include("config.jl") include("epilogue.jl") include("kernel.jl") diff --git a/src/config.jl b/src/config.jl index 85c1b690..722964f2 100644 --- a/src/config.jl +++ b/src/config.jl @@ -1,3 +1,5 @@ +using GemmKernels: CTASwizzle + @staticdef struct Config #= Params =# matmul_shape # MNK, overall shape of the MATMUL operation @@ -33,6 +35,9 @@ #= Is A & B stored in Column major order? This determines the iteration order of the parallelisation =# is_a_col_major is_b_col_major + + #= CTA Swizzling function =# + cta_swizzle end function Base.show(io::IO, config::Config) @@ -66,6 +71,8 @@ function Base.show(io::IO, config::Config) println(io, "is_a_col_major: $(config.is_a_col_major)") println(io, "is_b_col_major: $(config.is_b_col_major)") + + println(io, "cta_swizzle: $(config.cta_swizzle)") end struct ConfigError <: Exception @@ -267,6 +274,9 @@ function get_config(; gemm_shape, operator, global_a_layout, global_c_layout, kw require_tile_sized_global(global_c_layout) && check_tile_multiple(gemm_shape, block_shape, [:M, :N], "gemm_shape.MN must be a multiple of block_shape.MN!") require_tile_sized_global(global_d_layout) && check_tile_multiple(gemm_shape, block_shape, [:M, :N], "gemm_shape.MN must be a multiple of block_shape.MN!") + # CTA swizzling function. + cta_swizzle = CTASwizzle.HorizontallyTiled{8} + return Config( #= Params =# gemm_shape, @@ -298,5 +308,8 @@ function get_config(; gemm_shape, operator, global_a_layout, global_c_layout, kw #= Is A & B Col Major? 
=# is_a_col_major, is_b_col_major, + + #= CTA Swizzle function =# + cta_swizzle, ) end
diff --git a/src/cta-swizzle.jl b/src/cta-swizzle.jl
new file mode 100644
--- /dev/null
+++ b/src/cta-swizzle.jl
@@ -0,0 +1,213 @@
+export CTASwizzle
+module CTASwizzle
+
+# CTA Swizzling functions to improve L2 hit rate.
+#
+# Every swizzle type `S` implements two methods:
+#
+#   number_of_blocks(S, block_shape, matmul_shape) -> (BX, BY)
+#       The launch grid. Only `.M` and `.N` of both shapes are read.
+#
+#   cta_swizzle(S, blockIdx, block_shape) -> (block_i, block_j)
+#       The 0-based (row, column) element offset of the tile handled by the CTA
+#       with 1-based index `blockIdx`.
+
+# Debugging
+#
+# Returns a matrix with one entry per tile, holding the "(bx, by)" index of the
+# CTA that is mapped onto that tile ("" if no CTA is mapped onto it).
+function visualise_cta_swizzle(swizzle_type, block_shape = (M = 128, N = 128), matmul_shape = (M = 1024, N = 1024))
+    res = fill("", (cld(matmul_shape.M, block_shape.M),
+                    cld(matmul_shape.N, block_shape.N)))
+
+    BX, BY = number_of_blocks(swizzle_type, block_shape, matmul_shape)
+
+    for bx = 1:BX, by = 1:BY
+        block_i, block_j = cta_swizzle(swizzle_type, (x = bx, y = by), block_shape)
+        tile_i = block_i ÷ block_shape.M + 1
+        tile_j = block_j ÷ block_shape.N + 1
+
+        # The launch grid can be larger than the tile grid (e.g. when the number
+        # of tiles is not a multiple of TileSize). Skip CTAs that are mapped
+        # outside of the matrix instead of throwing a BoundsError.
+        checkbounds(Bool, res, tile_i, tile_j) || continue
+
+        res[tile_i, tile_j] = "($bx, $by)"
+    end
+
+    res
+end
+
+# ----------------------------
+# Identity swizzling operation
+# ----------------------------
+export Identity
+
+# No swizzling: the CTA with index (x, y) handles tile (x, y).
+struct Identity end
+
+@inline function number_of_blocks(::Type{Identity}, block_shape, matmul_shape)
+    (cld(matmul_shape.M, block_shape.M),
+     cld(matmul_shape.N, block_shape.N))
+end
+
+@inline function cta_swizzle(::Type{Identity}, blockIdx, block_shape)
+    bx = blockIdx.x - 1
+    by = blockIdx.y - 1
+
+    block_i = bx
+    block_j = by
+
+    block_i * block_shape.M, block_j * block_shape.N
+end
+
+# ----------------------------
+# Horizontally tiled swizzling
+# ----------------------------
+
+# Example: TileSize = 4:
+#
+# "(1, 1)" "(2, 1)" "(3, 1)" "(4, 1)" || "(1, 2)" "(2, 2)" "(3, 2)" "(4, 2)"
+# =======================================================================================
+# "(5, 1)" "(6, 1)" "(7, 1)" "(8, 1)" || "(5, 2)" "(6, 2)" "(7, 2)" "(8, 2)"
+# =======================================================================================
+# "(9, 1)" "(10, 1)" "(11, 1)" "(12, 1)" || "(9, 2)" "(10, 2)" "(11, 2)" "(12, 2)"
+# =======================================================================================
+# "(13, 1)" "(14, 1)" "(15, 1)" "(16, 1)" || "(13, 2)" "(14, 2)" "(15, 2)" "(16, 2)"
+# =======================================================================================
+# "(17, 1)" "(18, 1)" "(19, 1)" "(20, 1)" || "(17, 2)" "(18, 2)" "(19, 2)" "(20, 2)"
+# =======================================================================================
+# "(21, 1)" "(22, 1)" "(23, 1)" "(24, 1)" || "(21, 2)" "(22, 2)" "(23, 2)" "(24, 2)"
+# =======================================================================================
+# "(25, 1)" "(26, 1)" "(27, 1)" "(28, 1)" || "(25, 2)" "(26, 2)" "(27, 2)" "(28, 2)"
+# =======================================================================================
+# "(29, 1)" "(30, 1)" "(31, 1)" "(32, 1)" || "(29, 2)" "(30, 2)" "(31, 2)" "(32, 2)"
+
+export HorizontallyTiled
+
+# CTAs with consecutive x-indices fill a strip of 1 x TileSize tiles from left to
+# right before moving on to the next row of tiles; the y-index selects the strip.
+#
+# NOTE(review): if the number of tile columns is not a multiple of TileSize, the
+# last strip extends past the matrix, i.e. `cta_swizzle` returns a block_j that
+# is >= matmul_shape.N for some CTAs of the launch grid. `cta_swizzle` does not
+# receive matmul_shape, so this has to be handled by the caller -- TODO confirm.
+struct HorizontallyTiled{TileSize} end
+
+@inline function number_of_blocks(::Type{HorizontallyTiled{TileSize}}, block_shape, matmul_shape) where {TileSize}
+    # Round M up to whole tiles *before* multiplying by TileSize. Multiplying
+    # first, i.e. cld(M * TileSize, block_shape.M), yields too few CTAs when M is
+    # not a multiple of block_shape.M, so that not every tile is visited.
+    (cld(matmul_shape.M, block_shape.M) * TileSize,
+     cld(matmul_shape.N, block_shape.N * TileSize))
+end
+
+@inline function cta_swizzle(::Type{HorizontallyTiled{TileSize}}, blockIdx, block_shape) where {TileSize}
+    bx = blockIdx.x - 1
+    by = blockIdx.y - 1
+
+    block_i = bx ÷ TileSize
+    block_j = (by * TileSize) + (bx % TileSize)
+
+    block_i * block_shape.M, block_j * block_shape.N
+end
+
+# --------------------------
+# Vertically tiled swizzling
+# --------------------------
+
+# Example: TileSize = 4:
+#
+# "(1, 1)" || "(1, 5)" || "(1, 9)" || "(1, 13)" || "(1, 17)" || "(1, 21)" || "(1, 25)" || "(1, 29)"
+# "(1, 2)" || "(1, 6)" || "(1, 10)" || "(1, 14)" || "(1, 18)" || "(1, 22)" || "(1, 26)" || "(1, 30)"
+# "(1, 3)" || "(1, 7)" || "(1, 11)" || "(1, 15)" || "(1, 19)" || "(1, 23)" || "(1, 27)" || "(1, 31)"
+# "(1, 4)" || "(1, 8)" || "(1, 12)" || "(1, 16)" || "(1, 20)" || "(1, 24)" || "(1, 28)" || "(1, 32)"
+# ====================================================================================================
+# "(2, 1)" || "(2, 5)" || "(2, 9)" || "(2, 13)" || "(2, 17)" || "(2, 21)" || "(2, 25)" || "(2, 29)"
+# "(2, 2)" || "(2, 6)" || "(2, 10)" || "(2, 14)" || "(2, 18)" || "(2, 22)" || "(2, 26)" || "(2, 30)"
+# "(2, 3)" || "(2, 7)" || "(2, 11)" || "(2, 15)" || "(2, 19)" || "(2, 23)" || "(2, 27)" || "(2, 31)"
+# "(2, 4)" || "(2, 8)" || "(2, 12)" || "(2, 16)" || "(2, 20)" || "(2, 24)" || "(2, 28)" || "(2, 32)"
+
+export VerticallyTiled
+
+# CTAs with consecutive y-indices fill a strip of TileSize x 1 tiles from top to
+# bottom before moving on to the next column of tiles; the x-index selects the strip.
+#
+# NOTE(review): if the number of tile rows is not a multiple of TileSize, the
+# last strip extends past the matrix, i.e. `cta_swizzle` returns a block_i that
+# is >= matmul_shape.M for some CTAs of the launch grid. `cta_swizzle` does not
+# receive matmul_shape, so this has to be handled by the caller -- TODO confirm.
+struct VerticallyTiled{TileSize} end
+
+@inline function number_of_blocks(::Type{VerticallyTiled{TileSize}}, block_shape, matmul_shape) where {TileSize}
+    # Round N up to whole tiles *before* multiplying by TileSize. Multiplying
+    # first, i.e. cld(N * TileSize, block_shape.N), yields too few CTAs when N is
+    # not a multiple of block_shape.N, so that not every tile is visited.
+    (cld(matmul_shape.M, block_shape.M * TileSize),
+     cld(matmul_shape.N, block_shape.N) * TileSize)
+end
+
+@inline function cta_swizzle(::Type{VerticallyTiled{TileSize}}, blockIdx, block_shape) where {TileSize}
+    bx = blockIdx.x - 1
+    by = blockIdx.y - 1
+
+    block_i = (bx * TileSize) + (by % TileSize)
+    block_j = by ÷ TileSize
+
+    block_i * block_shape.M, block_j * block_shape.N
+end
+
+# ----------------------------
+# Lebesgue space-filling curve
+# ----------------------------
+
+# Example:
+# "(1, 1)" "(3, 1)" "(9, 1)" "(11, 1)" "(33, 1)" "(35, 1)" "(41, 1)" "(43, 1)"
+# "(2, 1)" "(4, 1)" "(10, 1)" "(12, 1)" "(34, 1)" "(36, 1)" "(42, 1)" "(44, 1)"
+# "(5, 1)" "(7, 1)" "(13, 1)" "(15, 1)" "(37, 1)" "(39, 1)" "(45, 1)" "(47, 1)"
+# "(6, 1)" "(8, 1)" "(14, 1)" "(16, 1)" "(38, 1)" "(40, 1)" "(46, 1)" "(48, 1)"
+# "(17, 1)" "(19, 1)" "(25, 1)" "(27, 1)" "(49, 1)" "(51, 1)" "(57, 1)" "(59, 1)"
+# "(18, 1)" "(20, 1)" "(26, 1)" "(28, 1)" "(50, 1)" "(52, 1)" "(58, 1)" "(60, 1)"
+# "(21, 1)" "(23, 1)" "(29, 1)" "(31, 1)" "(53, 1)" "(55, 1)" "(61, 1)" "(63, 1)"
+# "(22, 1)" "(24, 1)" "(30, 1)" "(32, 1)" "(54, 1)" "(56, 1)" "(62, 1)" "(64, 1)"
+
+# Compacts the bits at the even positions (0, 2, 4, ...) of the lower 32 bits of
+# `x` into the lower 16 bits of the result.
+@inline function extract_even_bits(x)
+    x = x & 0x55555555
+    x = (x | (x >> 1)) & 0x33333333
+    x = (x | (x >> 2)) & 0x0F0F0F0F
+    x = (x | (x >> 4)) & 0x00FF00FF
+    x = (x | (x >> 8)) & 0x0000FFFF
+
+    x
+end
+
+# Same as `extract_even_bits`, but for the bits at the odd positions (1, 3, 5, ...).
+@inline extract_odd_bits(x) = extract_even_bits(x >> 1)
+
+export LebesgueCurve
+
+# The x-index of the CTA is interpreted as a Z-order (Morton) code: its even
+# bits give the tile row, its odd bits the tile column. The y-index is unused.
+#
+# NOTE(review): the cld(M, ...) * cld(N, ...) CTAs of the launch grid are only
+# mapped one-to-one onto the tile grid if the number of tiles is 2^k x 2^k or
+# 2^(k+1) x 2^k (rows x columns); other shapes do not seem to be supported --
+# TODO confirm.
+struct LebesgueCurve end
+
+@inline function number_of_blocks(::Type{LebesgueCurve}, block_shape, matmul_shape)
+    (cld(matmul_shape.M, block_shape.M) * cld(matmul_shape.N, block_shape.N), 1)
+end
+
+@inline function cta_swizzle(::Type{LebesgueCurve}, blockIdx, block_shape)
+    bx = blockIdx.x - 1
+
+    block_i = extract_even_bits(bx)
+    block_j = extract_odd_bits(bx)
+
+    block_i * block_shape.M, block_j * block_shape.N
+end
+
+end
diff --git a/src/epilogue.jl b/src/epilogue.jl index 770085fa..3efbfade 100644 --- a/src/epilogue.jl +++ b/src/epilogue.jl @@ -3,6 +3,7 @@ module Epilogue using CUDA using GemmKernels +using GemmKernels: CTASwizzle using GemmKernels.Tiling using LLVMLoopInfo: @loopinfo @@ -14,8 +15,7 @@ struct Default end @inline function (ep::Default)(conf::GemmKernels.Config, d, shmem_d, transform) # Constants - block_i = (blockIdx().x - 1) * conf.block_shape.M - block_j = (blockIdx().y - 1) * conf.block_shape.N + block_i, block_j = CTASwizzle.cta_swizzle(conf.cta_swizzle, blockIdx(), conf.block_shape) warpId = (threadIdx().x - 1) ÷ 32 + 1 laneId = (threadIdx().x - 1) % 32 + 1 @@ -53,8 +53,7 @@ end @inline function (ep::Bias{B})(conf::GemmKernels.Config, d, shmem_d, transform) where {B} # Constants - block_i = (blockIdx().x - 1) * conf.block_shape.M - block_j = (blockIdx().y - 1) * conf.block_shape.N + block_i, block_j = CTASwizzle.cta_swizzle(conf.cta_swizzle, blockIdx(), conf.block_shape) warpId = (threadIdx().x - 1) ÷ 32 + 1 laneId = (threadIdx().x - 1) % 32 + 1 diff --git a/src/kernel.jl b/src/kernel.jl index 0aabfe4b..c27d506f 100644 --- a/src/kernel.jl +++ b/src/kernel.jl @@ -4,7 +4,7 @@ module Kernel using CUDA using GemmKernels using GemmKernels.Tiling -using GemmKernels: LocalArray, @immutable +using GemmKernels: LocalArray, @immutable, CTASwizzle using LLVMLoopInfo: @loopinfo function matmul_singlestage(conf::GemmKernels.Config, a, b, 
c, d, num_fragments_n = conf.compute_warp.N ÷ conf.compute_op_shape.N # Constants - block_i = (blockIdx().x - 1) * conf.block_shape.M - block_j = (blockIdx().y - 1) * conf.block_shape.N + block_i, block_j = CTASwizzle.cta_swizzle(conf.cta_swizzle, blockIdx(), conf.block_shape) warpId = (threadIdx().x - 1) ÷ 32 + 1 laneId = (threadIdx().x - 1) % 32 + 1 @@ -154,8 +153,7 @@ function matmul_pipelined(conf::GemmKernels.Config, a, b, c, d, num_fragments_n = conf.compute_warp.N ÷ conf.compute_op_shape.N # Constants - block_i = (blockIdx().x - 1) * conf.block_shape.M - block_j = (blockIdx().y - 1) * conf.block_shape.N + block_i, block_j = CTASwizzle.cta_swizzle(conf.cta_swizzle, blockIdx(), conf.block_shape) warpId = (threadIdx().x - 1) ÷ 32 + 1 laneId = (threadIdx().x - 1) % 32 + 1 diff --git a/src/matmul.jl b/src/matmul.jl index 667a2b7e..cbf77035 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -1,3 +1,5 @@ +using GemmKernels: CTASwizzle + # # low-level # @@ -20,8 +22,7 @@ function matmul(conf::Config, a, b, c, d; epilogue] threads = conf.warps_per_block * 32 - blocks = (cld(conf.matmul_shape.M, conf.block_shape.M), - cld(conf.matmul_shape.N, conf.block_shape.N)) + blocks = CTASwizzle.number_of_blocks(conf.cta_swizzle, conf.block_shape, conf.matmul_shape) shmem = Kernel.shmem_size(conf, kernel) max_shmem = attribute(device(), CUDA.DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN)