diff --git a/configs/configs.jl b/configs/configs.jl
index f580cc40..56568605 100644
--- a/configs/configs.jl
+++ b/configs/configs.jl
@@ -115,19 +115,19 @@ function get_configs()
                           (Int64, Int64, Int64)],
         transpose_a = [false, true],
         transpose_b = [false, true],
-        (OP_M, OP_N, OP_K) in [(8, 16, 2)],
+        (OP_M, OP_N, OP_K, OP_MB, OP_NB, OP_KB) in [(8, 16, 2, 4, 8, 1)],
         N in GEMM_SIZES
 
         # XXX: Should we do non-square matrices as well?
         M = K = N
 
-        name = "FPU GEMM $(A_type)*$(B_type)=$(CD_type) ($(M)×$(K)) · ($(K)×$(N)) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' )) OP ($(OP_M), $(OP_N), $(OP_K))"
+        name = "FPU GEMM $(A_type)*$(B_type)=$(CD_type) ($(M)×$(K)) · ($(K)×$(N)) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' )) OP ($(OP_M), $(OP_N), $(OP_K)), base shape ($(OP_MB), $(OP_NB), $(OP_KB))"
 
         compute_type = promote_type(A_type, B_type)
 
         conf = GemmKernels.get_config(
             gemm_shape = (M = M, N = N, K = K),
             block_shape = (M = 64, N = 64, K = 32),
-            operator = Operator.FPUOp{OP_M, OP_N, OP_K, compute_type, CD_type},
+            operator = Operator.FPUOp{OP_M, OP_N, OP_K, OP_MB, OP_NB, OP_KB, compute_type, CD_type},
             global_a_layout = transpose_a ? Layout.UnsafeAlignedRowMajor{A_type} : Layout.UnsafeAlignedColMajor{A_type},
             global_b_layout = transpose_b ? Layout.UnsafeAlignedRowMajor{B_type} : Layout.UnsafeAlignedColMajor{B_type},
@@ -159,24 +159,33 @@ function get_configs()
                           (Float32, Float32, Float32)],
         transpose_a = [false, true],
         transpose_b = [false, true],
-        (OP_M, OP_N, OP_K) in [
-            (4, 8, 1),
-            (8, 8, 1),
-            (4, 16, 1),
-            (4, 8, 2),
-            (8, 16, 2)],
+        (OP_M, OP_N, OP_K, OP_MB, OP_NB, OP_KB) in vcat(
+            # First, test some shapes with the default base shape (4, 8, 1).
+            map(tup -> (tup..., 4, 8, 1),
+                [( 4, 8, 1),
+                 ( 8, 8, 1),
+                 ( 4, 16, 1),
+                 ( 4, 8, 2),
+                 ( 8, 16, 2)]),
+            # Then, test some different combinations of op shape + base shape.
+            [(4, 32, 1, 1, 32, 1),
+             (4, 32, 1, 2, 16, 1),
+             (16, 16, 1, 4, 8, 1),
+             (16, 16, 1, 8, 4, 1),
+             (32, 4, 1, 16, 2, 1),
+             (32, 4, 1, 32, 1, 1)]),
         N in [128]
 
         # We'll only test square matrices.
         M = K = N
 
-        name = "FPU GEMM $(A_type)*$(B_type)=$(CD_type) ($(M)×$(K)) · ($(K)×$(N)) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' )) OP ($(OP_M), $(OP_N), $(OP_K))"
+        name = "FPU GEMM $(A_type)*$(B_type)=$(CD_type) ($(M)×$(K)) · ($(K)×$(N)) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' )) OP ($(OP_M), $(OP_N), $(OP_K)), base shape ($(OP_MB), $(OP_NB), $(OP_KB))"
 
         compute_type = promote_type(A_type, B_type)
 
         conf = GemmKernels.get_config(
             gemm_shape = (M = M, N = N, K = K),
             block_shape = (M = 128, N = 64, K = 32),
-            operator = Operator.FPUOp{OP_M, OP_N, OP_K, compute_type, CD_type},
+            operator = Operator.FPUOp{OP_M, OP_N, OP_K, OP_MB, OP_NB, OP_KB, compute_type, CD_type},
             global_a_layout = transpose_a ? Layout.UnsafeAlignedRowMajor{A_type} : Layout.UnsafeAlignedColMajor{A_type},
             global_b_layout = transpose_b ? Layout.UnsafeAlignedRowMajor{B_type} : Layout.UnsafeAlignedColMajor{B_type},
@@ -208,20 +217,20 @@ function get_configs()
                           (Float32, Float32, Float32, 128)],
         transpose_a = [false, true],
         transpose_b = [false, true],
-        (OP_M, OP_N, OP_K) in [(8, 16, 2)],
+        (OP_M, OP_N, OP_K, OP_MB, OP_NB, OP_KB) in [(8, 16, 2, 4, 8, 1)],
         (M, N, K) in min_dimension .* [
             [1, 1, 1],
             [2, 2, 1],
             [1, 1, 2],
             [2, 2, 2]]
 
-        name = "Tropical GEMM $(A_type)*$(B_type)=$(CD_type) ($(M)×$(K)) · ($(K)×$(N)) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' )) OP ($(OP_M), $(OP_N), $(OP_K))"
+        name = "Tropical GEMM $(A_type)*$(B_type)=$(CD_type) ($(M)×$(K)) · ($(K)×$(N)) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' )) OP ($(OP_M), $(OP_N), $(OP_K)), base shape ($(OP_MB), $(OP_NB), $(OP_KB))"
 
         compute_type = promote_type(A_type, B_type)
 
         conf = GemmKernels.get_config(
             gemm_shape = (M = M, N = N, K = K),
             block_shape = (M = 64, N = 64, K = 32),
-            operator = Operator.TropicalFPUOp{OP_M, OP_N, OP_K, compute_type, CD_type},
+            operator = Operator.TropicalFPUOp{OP_M, OP_N, OP_K, OP_MB, OP_NB, OP_KB, compute_type, CD_type},
             global_a_layout = transpose_a ? Layout.UnsafeAlignedRowMajor{A_type} : Layout.UnsafeAlignedColMajor{A_type},
             global_b_layout = transpose_b ? Layout.UnsafeAlignedRowMajor{B_type} : Layout.UnsafeAlignedColMajor{B_type},
diff --git a/examples/simple_matmul.jl b/examples/simple_matmul.jl
index 43f1d497..c6a9df04 100644
--- a/examples/simple_matmul.jl
+++ b/examples/simple_matmul.jl
@@ -41,7 +41,7 @@ function main()
     conf = GemmKernels.get_config(;
         gemm_shape = (; M, N, K), block_shape,
-        operator = Operator.FPUOp{8, 8, 1, compute_type, eltype(C)},
+        operator = Operator.FPUOp{8, 8, 1, 4, 8, 1, compute_type, eltype(C)},
        global_a_layout, global_b_layout, global_c_layout, global_d_layout,
        shared_a_layout, shared_b_layout, shared_c_layout, shared_d_layout,
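
Note: the hunks above extend FPUOp's type parameters from {M, N, K, CT, AT} to {M, N, K, mb, nb, kb, CT, AT}. As a reading aid, here is a minimal, hypothetical sketch (not part of this patch) of one such configuration written out by hand. It uses the (16, 16, 1) operator with the (4, 8, 1) base shape from the test list above; the gemm_shape, block_shape, and layout values are illustrative choices, not values mandated by the patch:

    using GemmKernels
    using GemmKernels: Operator, Layout

    # Operator shape (M, N, K) = (16, 16, 1), computed by one warp whose 32
    # lanes are arranged as an (mb, nb) = (4, 8) base tile (kb must be 1).
    conf = GemmKernels.get_config(;
        gemm_shape = (M = 128, N = 128, K = 128),
        block_shape = (M = 128, N = 64, K = 32),
        operator = Operator.FPUOp{16, 16, 1, 4, 8, 1, Float32, Float32},
        global_a_layout = Layout.UnsafeAlignedColMajor{Float32},
        global_c_layout = Layout.UnsafeAlignedColMajor{Float32})
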
diff --git a/src/config.jl b/src/config.jl
index 711f3a05..b8dda0db 100644
--- a/src/config.jl
+++ b/src/config.jl
@@ -35,6 +35,12 @@
     is_b_col_major
 end
 
+struct ConfigError <: Exception
+    message::String
+end
+
+Base.showerror(io::IO, e::ConfigError) = print(io, "ConfigError: ", e.message)
+
 function heuristic_block_shape(shared_a_layout, shared_b_layout, shared_c_layout, shared_d_layout)
     # Determining the tile size of each block is a little trickier.
     # We apply the following heuristics:
@@ -87,6 +93,27 @@ function adjacent_elements(num, parent_size, is_col_major)
     return typeof(parent_size)(t)
 end
 
+function handle_operator_config(operator)
+    op_shape = Operator.base_shape(operator)
+
+    # Each of the 32 threads in a warp must handle at least one element of the operator.
+    if op_shape.M * op_shape.N < 32
+        throw(ConfigError("The operator shape is too small. The dimensions of the operator shape must adhere to the following constraint: OPERATOR_M * OPERATOR_N ≥ 32."))
+    end
+
+    if op_shape.mb * op_shape.nb != 32
+        throw(ConfigError("The base FPU operator shape should adhere to the following constraint: OPERATOR_M_BASE * OPERATOR_N_BASE = 32."))
+    end
+
+    if op_shape.kb != 1
+        throw(ConfigError("The base FPU operator shape should adhere to the following constraint: OPERATOR_K_BASE = 1."))
+    end
+
+    if any((op_shape.M, op_shape.N, op_shape.K) .% (op_shape.mb, op_shape.nb, op_shape.kb) .!= 0)
+        throw(ConfigError("The operator shape should adhere to the following constraint: OPERATOR_M, OPERATOR_N, OPERATOR_K are multiples of OPERATOR_M_BASE, OPERATOR_N_BASE, OPERATOR_K_BASE, respectively."))
+    end
+end
+
 function get_config(; gemm_shape, operator, global_a_layout, global_c_layout, kwargs...)
     params = Dict(kwargs)
@@ -110,11 +137,38 @@ function get_config(; gemm_shape, operator, global_a_layout, global_c_layout, kwargs...)
     block_shape = get(params, :block_shape,
                       heuristic_block_shape(shared_a_layout, shared_b_layout, shared_c_layout, shared_d_layout))
 
-    # 8 warps in a 4 x 2 arrangement usually works well
-    warps_per_block = get(params, :warps_per_block, 8)
+    if block_shape.M * block_shape.K < 128 || block_shape.K * block_shape.N < 128 || block_shape.K < 8
+        throw(ConfigError("The block shape is too small. The dimensions of the block shape must adhere to the following constraints: BLOCK_M * BLOCK_K ≥ 128, BLOCK_K * BLOCK_N ≥ 128, BLOCK_K ≥ 8."))
+    end
+
     op_shape = Operator.shape(operator)
-    compute_warp = get(params, :compute_warp,
-                       (M = block_shape.M ÷ 4, N = block_shape.N ÷ 2, K = op_shape.K))
+
+    if block_shape.M < 2 * op_shape.M || block_shape.N < 2 * op_shape.N
+        # TODO: Find out why this is.
+        throw(ConfigError("There is a mismatch between the block shape and the operator shape. Their dimensions must adhere to the following constraints: BLOCK_M ≥ 2 * OPERATOR_M, BLOCK_N ≥ 2 * OPERATOR_N."))
+    end
+
+    if operator <: Operator.GeneralFPUOp
+        handle_operator_config(operator)
+    end
+
+    # 8 warps in a 4 x 2 arrangement usually works well
+    warps_per_block_default = 8
+    compute_warp_default = (M = block_shape.M ÷ 4, N = block_shape.N ÷ 2, K = op_shape.K)
+
+    # Best effort to make sure that the compute warp shape is not smaller than the operator shape.
+    if (block_shape.M ÷ op_shape.M) < 4 || (block_shape.N ÷ op_shape.N) < 2
+        compute_warp_default = (
+            M = block_shape.M ÷ min(block_shape.M ÷ op_shape.M, 4),
+            N = block_shape.N ÷ min(block_shape.N ÷ op_shape.N, 2),
+            K = op_shape.K
+        )
+        warps_per_block_default = min(block_shape.M ÷ op_shape.M, 4) * min(block_shape.N ÷ op_shape.N, 2)
+    end
+
+    warps_per_block = get(params, :warps_per_block, warps_per_block_default)
+    compute_warp = get(params, :compute_warp, compute_warp_default)
+
     # Is the layout col-major or not? This is needed to find good values for mem_a_warp, mem_b_warp, etc.
     # TODO: Let the layouts handle this?
diff --git a/src/matmul.jl b/src/matmul.jl
index 83ab2906..adbf1040 100644
--- a/src/matmul.jl
+++ b/src/matmul.jl
@@ -155,7 +155,7 @@ end
     else
         get_config(;
             gemm_shape = (M = m, N = n, K = k), block_shape,
-            operator = Operator.FPUOp{8, 8, 1, compute_type, eltype(C)},
+            operator = Operator.FPUOp{8, 8, 1, 4, 8, 1, compute_type, eltype(C)},
             global_a_layout, global_b_layout, global_c_layout, global_d_layout,
             shared_a_layout, shared_b_layout, shared_c_layout, shared_d_layout,
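
Taken together, the new checks in get_config reject FPU configurations where OPERATOR_M * OPERATOR_N < 32, where the base shape violates OPERATOR_M_BASE * OPERATOR_N_BASE = 32 or OPERATOR_K_BASE = 1, where the operator shape is not a multiple of the base shape, or where the block shape is smaller than the stated minimums or than twice the operator shape. A small, hypothetical sketch of the failure mode (the call mirrors the ones above; the layouts and shapes are illustrative, and the exact error rendering assumes ConfigError is reachable from the GemmKernels module):

    # A 4×4×1 operator gives the 32 lanes of a warp only 16 elements to
    # compute, so the first check in handle_operator_config fires.
    try
        GemmKernels.get_config(;
            gemm_shape = (M = 128, N = 128, K = 128),
            block_shape = (M = 64, N = 64, K = 32),
            operator = Operator.FPUOp{4, 4, 1, 4, 8, 1, Float32, Float32},
            global_a_layout = Layout.UnsafeAlignedColMajor{Float32},
            global_c_layout = Layout.UnsafeAlignedColMajor{Float32})
    catch err
        # Prints: "ConfigError: The operator shape is too small. ..."
        showerror(stdout, err)
    end
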
diff --git a/src/operator.jl b/src/operator.jl
index a8518c64..d27a969f 100644
--- a/src/operator.jl
+++ b/src/operator.jl
@@ -22,9 +22,12 @@ end
 
 # CT is the compute type used to perform scalar operations in.
 # AT is the accumulator type used to accumulate partial results.
-abstract type GeneralFPUOp{M, N, K, CT, AT} end
+# mb, nb, kb make up the base operator shape (kb must be 1 for now).
+# M, N, K must be multiples of mb, nb, and kb, respectively.
+abstract type GeneralFPUOp{M, N, K, mb, nb, kb, CT, AT} end
 
-@inline shape(::Type{<:GeneralFPUOp{M, N, K, CT, AT}}) where {M, N, K, CT, AT} = (M = M, N = N, K = K)
+@inline shape(::Type{<:GeneralFPUOp{M, N, K, mb, nb, kb, CT, AT}}) where {M, N, K, mb, nb, kb, CT, AT} = (M = M, N = N, K = K)
+@inline base_shape(::Type{<:GeneralFPUOp{M, N, K, mb, nb, kb, CT, AT}}) where {M, N, K, mb, nb, kb, CT, AT} = (M = M, N = N, K = K, mb = mb, nb = nb, kb = kb)
 
 for (layout_type, convert_index_func) in [
     (Layout.ColMajor, identity),
@@ -33,100 +36,105 @@ for (layout_type, convert_index_func) in [
     (Layout.UnsafeAlignedRowMajor, x -> reverse(Tuple(x))),
 ]
     @eval begin
-        @inline fragtype_a(::Type{<:GeneralFPUOp{M, N, K, CT, AT}}, ::Type{$layout_type{DT}}) where {M, N, K, CT, AT, DT} = NTuple{M * K ÷ 4, CT}
-        @inline fragtype_b(::Type{<:GeneralFPUOp{M, N, K, CT, AT}}, ::Type{$layout_type{DT}}) where {M, N, K, CT, AT, DT} = NTuple{K * N ÷ 8, CT}
+        @inline function fragtype_a(::Type{<:GeneralFPUOp{M, N, K, mb, nb, kb, CT, AT}}, ::Type{$layout_type{DT}}) where {M, N, K, mb, nb, kb, CT, AT, DT}
+            return NTuple{M * K ÷ mb, CT}
+        end
+        @inline function fragtype_b(::Type{<:GeneralFPUOp{M, N, K, mb, nb, kb, CT, AT}}, ::Type{$layout_type{DT}}) where {M, N, K, mb, nb, kb, CT, AT, DT}
+            return NTuple{K * N ÷ nb, CT}
+        end
 
-        @inline function fragtype_accum(::Type{<:GeneralFPUOp{M, N, K, CT, AT}}, ::Type{$layout_type{DT}}) where {M, N, K, CT, AT, DT}
+        @inline function fragtype_accum(::Type{<:GeneralFPUOp{M, N, K, mb, nb, kb, CT, AT}}, ::Type{$layout_type{DT}}) where {M, N, K, mb, nb, kb, CT, AT, DT}
             return NTuple{M * N ÷ 32, AT}
         end
 
-        @inline function load_a(::Type{<:GeneralFPUOp{M, N, K, CT, AT}}, ::Type{$layout_type{DT}}, workspace, tile::Tile) where {M, N, K, CT, AT, DT}
+        @inline function load_a(::Type{<:GeneralFPUOp{M, N, K, mb, nb, kb, CT, AT}}, ::Type{$layout_type{DT}}, workspace, tile::Tile) where {M, N, K, mb, nb, kb, CT, AT, DT}
             laneId = (threadIdx().x - 1) % 32 + 1
 
-            op_y = (laneId - 1) % 4 + 1
+            op_y = (laneId - 1) % mb + 1
             y, x = (tile.base.M + tile.offset.M + op_y, tile.base.K + tile.offset.K + 1)
 
-            frag = LocalArray{Tuple{M ÷ 4, K}, CT}(undef)
-            @loopinfo unroll for m = 1 : M ÷ 4
+            frag = LocalArray{Tuple{M ÷ mb, K}, CT}(undef)
+            @loopinfo unroll for m = 1 : M ÷ mb
                 @loopinfo unroll for k = 1 : K
-                    y_layout, x_layout = $convert_index_func((y + 4 * (m - 1), x + (k - 1)))
+                    y_layout, x_layout = $convert_index_func((y + mb * (m - 1), x + (k - 1)))
                     @inbounds @immutable frag[m,k] = workspace[y_layout, x_layout]
                 end
             end
 
-            return NTuple{M * K ÷ 4, CT}(frag)
+            return NTuple{M * K ÷ mb, CT}(frag)
         end
 
-        @inline function load_b(::Type{<:GeneralFPUOp{M, N, K, CT, AT}}, ::Type{$layout_type{DT}}, workspace, tile::Tile) where {M, N, K, CT, AT, DT}
+        @inline function load_b(::Type{<:GeneralFPUOp{M, N, K, mb, nb, kb, CT, AT}}, ::Type{$layout_type{DT}}, workspace, tile::Tile) where {M, N, K, mb, nb, kb, CT, AT, DT}
             laneId = (threadIdx().x - 1) % 32 + 1
 
-            op_x = (laneId - 1) ÷ 4 + 1
+            op_x = (laneId - 1) ÷ mb + 1
             y, x = (tile.base.K + tile.offset.K + 1, tile.base.N + tile.offset.N + op_x)
 
-            frag = LocalArray{Tuple{K, N ÷ 8}, CT}(undef)
-            @loopinfo unroll for n = 1 : N ÷ 8
+            frag = LocalArray{Tuple{K, N ÷ nb}, CT}(undef)
+            @loopinfo unroll for n = 1 : N ÷ nb
                 @loopinfo unroll for k = 1 : K
-                    y_layout, x_layout = $convert_index_func((y + (k - 1), x + 8 * (n - 1)))
+                    y_layout, x_layout = $convert_index_func((y + (k - 1), x + nb * (n - 1)))
                     @inbounds @immutable frag[k,n] = workspace[y_layout, x_layout]
                 end
             end
 
-            return NTuple{K * N ÷ 8, CT}(frag)
+            return NTuple{K * N ÷ nb, CT}(frag)
         end
 
-        @inline function load_c(::Type{<:GeneralFPUOp{M, N, K, CT, AT}}, ::Type{$layout_type{DT}}, workspace, tile::Tile) where {M, N, K, CT, AT, DT}
+        @inline function load_c(::Type{<:GeneralFPUOp{M, N, K, mb, nb, kb, CT, AT}}, ::Type{$layout_type{DT}}, workspace, tile::Tile) where {M, N, K, mb, nb, kb, CT, AT, DT}
             laneId = (threadIdx().x - 1) % 32 + 1
 
-            op_y = (laneId - 1) % 4 + 1
-            op_x = (laneId - 1) ÷ 4 + 1
+            op_y = (laneId - 1) % mb + 1
+            op_x = (laneId - 1) ÷ mb + 1
 
             y, x = (tile.base.M + tile.offset.M + op_y, tile.base.N + tile.offset.N + op_x)
-            frag = LocalArray{Tuple{M ÷ 4, N ÷ 8}, AT}(undef)
-            @loopinfo unroll for m = 1 : M ÷ 4
-                @loopinfo unroll for n = 1 : N ÷ 8
-                    @inbounds @immutable frag[m,n] = workspace[y + 4 * (m - 1), x + 8 * (n - 1)]
+
+            frag = LocalArray{Tuple{M ÷ mb, N ÷ nb}, AT}(undef)
+            @loopinfo unroll for m = 1 : M ÷ mb
+                @loopinfo unroll for n = 1 : N ÷ nb
+                    @inbounds @immutable frag[m,n] = workspace[y + mb * (m - 1), x + nb * (n - 1)]
                 end
             end
 
             return NTuple{M * N ÷ 32, AT}(frag)
         end
 
-        @inline function store_d(::Type{<:GeneralFPUOp{M, N, K, CT, AT}}, ::Type{$layout_type{DT}}, workspace, frag, tile::Tile) where {M, N, K, CT, AT, DT}
+        @inline function store_d(::Type{<:GeneralFPUOp{M, N, K, mb, nb, kb, CT, AT}}, ::Type{$layout_type{DT}}, workspace, frag, tile::Tile) where {M, N, K, mb, nb, kb, CT, AT, DT}
             laneId = (threadIdx().x - 1) % 32 + 1
 
-            op_y = (laneId - 1) % 4 + 1
-            op_x = (laneId - 1) ÷ 4 + 1
+            op_y = (laneId - 1) % mb + 1
+            op_x = (laneId - 1) ÷ mb + 1
 
             y, x = (tile.base.M + tile.offset.M + op_y, tile.base.N + tile.offset.N + op_x)
 
-            frag = LocalArray{Tuple{M ÷ 4, N ÷ 8}, AT}(frag)
-            @loopinfo unroll for m = 1 : M ÷ 4
-                @loopinfo unroll for n = 1 : N ÷ 8
-                    @inbounds workspace[y + 4 * (m - 1), x + 8 * (n - 1)] = frag[m, n]
+            frag = LocalArray{Tuple{M ÷ mb, N ÷ nb}, AT}(frag)
+            @loopinfo unroll for m = 1 : M ÷ mb
+                @loopinfo unroll for n = 1 : N ÷ nb
+                    @inbounds workspace[y + mb * (m - 1), x + nb * (n - 1)] = frag[m,n]
                 end
             end
         end
     end
 end
 
-abstract type FPUOp{M, N, K, CT, AT} <: GeneralFPUOp{M, N, K, CT, AT} end
-function operator_fma(::Type{FPUOp{M, N, K, CT, AT}}, a::CT, b::CT, c::AT) where {M, N, K, CT, AT}
+abstract type FPUOp{M, N, K, mb, nb, kb, CT, AT} <: GeneralFPUOp{M, N, K, mb, nb, kb, CT, AT} end
+function operator_fma(::Type{FPUOp{M, N, K, mb, nb, kb, CT, AT}}, a::CT, b::CT, c::AT) where {M, N, K, mb, nb, kb, CT, AT}
     return fma(a, b, c)
 end
 
-abstract type TropicalFPUOp{M, N, K, CT, AT} <: GeneralFPUOp{M, N, K, CT, AT} end
+abstract type TropicalFPUOp{M, N, K, CT, AT} <: GeneralFPUOp{M, N, K, 4, 8, 1, CT, AT} end
 function operator_fma(::Type{TropicalFPUOp{M, N, K, CT, AT}}, a::CT, b::CT, c::AT) where {M, N, K, CT, AT}
     return max(a + b, c)
 end
 
-@inline function mma(operator_type::Type{<:GeneralFPUOp{M, N, K, CT, AT}}, a_frag, b_frag, c_frag) where {M, N, K, CT, AT}
-    a_frag = LocalArray{Tuple{M ÷ 4, K}, CT}(a_frag)
-    b_frag = LocalArray{Tuple{K, N ÷ 8}, CT}(b_frag)
-    c_frag = LocalArray{Tuple{M ÷ 4, N ÷ 8}, AT}(c_frag)
+@inline function mma(operator_type::Type{<:GeneralFPUOp{M, N, K, mb, nb, kb, CT, AT}}, a_frag, b_frag, c_frag) where {M, N, K, mb, nb, kb, CT, AT}
+    a_frag = LocalArray{Tuple{M ÷ mb, K}, CT}(a_frag)
+    b_frag = LocalArray{Tuple{K, N ÷ nb}, CT}(b_frag)
+    c_frag = LocalArray{Tuple{M ÷ mb, N ÷ nb}, AT}(c_frag)
 
-    @loopinfo unroll for m = 1 : M ÷ 4
-        @loopinfo unroll for n = 1 : N ÷ 8
+    @loopinfo unroll for m = 1 : M ÷ mb
+        @loopinfo unroll for n = 1 : N ÷ nb
             @loopinfo unroll for k = 1 : K
                 @inbounds @immutable c_frag[m,n] = operator_fma(operator_type, a_frag[m, k], b_frag[k, n], c_frag[m, n])
             end
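
To make the indexing in load_a/load_b/load_c concrete: each warp computes one (M, N, K) operator tile, and its 32 lanes are laid out column-major over an (mb, nb) base tile, so lane i covers base-tile row (i - 1) % mb + 1 and column (i - 1) ÷ mb + 1, then strides by mb and nb to cover the rest of the operator tile. A worked example (illustrative only; the values follow directly from the definitions above):

    # Per-lane fragment sizes for (M, N, K) = (16, 16, 1), (mb, nb, kb) = (4, 8, 1).
    M, N, K = 16, 16, 1
    mb, nb, kb = 4, 8, 1

    @assert mb * nb == 32 && kb == 1   # one warp covers one 4×8 base tile

    a_len = M * K ÷ mb                 # 4 elements of A per lane (fragtype_a)
    b_len = K * N ÷ nb                 # 2 elements of B per lane (fragtype_b)
    c_len = M * N ÷ 32                 # 8 accumulators per lane, a 4×2 grid

    # Lane-to-position mapping used by load_a/load_b/load_c (lanes are 1-based):
    laneId = 11
    op_y = (laneId - 1) % mb + 1       # row 3 of the base tile
    op_x = (laneId - 1) ÷ mb + 1       # column 3 of the base tile
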