add more flexible FPU operator
wardvermeulen authored and thomasfaingnaert committed Oct 24, 2023
1 parent 73061d7 commit 82bc0d2
Showing 5 changed files with 131 additions and 60 deletions.
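The gist of the change: the FPU operators gain three extra type parameters (OP_MB, OP_NB, OP_KB) describing the base shape, i.e. the sub-tile that the 32 lanes of a warp compute in one step; previously this was hard-coded to (4, 8, 1). A minimal sketch of the new call, not runnable as-is since the global/shared layout keyword arguments are elided (see configs/configs.jl below for a complete invocation):

    # Before: Operator.FPUOp{OP_M, OP_N, OP_K, compute_type, CD_type}
    # After:  Operator.FPUOp{OP_M, OP_N, OP_K, OP_MB, OP_NB, OP_KB, compute_type, CD_type}
    conf = GemmKernels.get_config(
        gemm_shape  = (M = 128, N = 128, K = 128),
        block_shape = (M = 64, N = 64, K = 32),
        # an (8, 16, 2) operator tile computed on the default (4, 8, 1) base shape:
        operator    = Operator.FPUOp{8, 16, 2, 4, 8, 1, Float32, Float32},
        # ... layout keyword arguments as in configs/configs.jl ...
    )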
37 changes: 23 additions & 14 deletions configs/configs.jl
@@ -115,19 +115,19 @@ function get_configs()
(Int64, Int64, Int64)],
transpose_a = [false, true],
transpose_b = [false, true],
-(OP_M, OP_N, OP_K) in [(8, 16, 2)],
+(OP_M, OP_N, OP_K, OP_MB, OP_NB, OP_KB) in [(8, 16, 2, 4, 8, 1)],
N in GEMM_SIZES

# XXX: Should we do non-square matrices as well?
M = K = N

name = "FPU GEMM $(A_type)*$(B_type)=$(CD_type) ($(M)×$(K)) · ($(K)×$(N)) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' )) OP ($(OP_M), $(OP_N), $(OP_K))"
name = "FPU GEMM $(A_type)*$(B_type)=$(CD_type) ($(M)×$(K)) · ($(K)×$(N)) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' )) OP ($(OP_M), $(OP_N), $(OP_K)), base shape ($(OP_MB), $(OP_NB), $(OP_KB))"
compute_type = promote_type(A_type, B_type)

conf = GemmKernels.get_config(
gemm_shape = (M = M, N = N, K = K),
block_shape = (M = 64, N = 64, K = 32),
-operator = Operator.FPUOp{OP_M, OP_N, OP_K, compute_type, CD_type},
+operator = Operator.FPUOp{OP_M, OP_N, OP_K, OP_MB, OP_NB, OP_KB, compute_type, CD_type},
global_a_layout = transpose_a ? Layout.UnsafeAlignedRowMajor{A_type} : Layout.UnsafeAlignedColMajor{A_type},
global_b_layout = transpose_b ? Layout.UnsafeAlignedRowMajor{B_type} : Layout.UnsafeAlignedColMajor{B_type},

@@ -159,24 +159,33 @@ function get_configs()
(Float32, Float32, Float32)],
transpose_a = [false, true],
transpose_b = [false, true],
-(OP_M, OP_N, OP_K) in [
-    (4, 8, 1),
-    (8, 8, 1),
-    (4, 16, 1),
-    (4, 8, 2),
-    (8, 16, 2)],
+(OP_M, OP_N, OP_K, OP_MB, OP_NB, OP_KB) in vcat(
+    # First, test some shapes with the default base shape (4, 8, 1).
+    map(tup -> (tup..., 4, 8, 1),
+        [( 4, 8, 1),
+         ( 8, 8, 1),
+         ( 4, 16, 1),
+         ( 4, 8, 2),
+         ( 8, 16, 2)]),
+    # Then, test some different combinations of op shape + base shape.
+    [(4, 32, 1, 1, 32, 1),
+     (4, 32, 1, 2, 16, 1),
+     (16, 16, 1, 4, 8, 1),
+     (16, 16, 1, 8, 4, 1),
+     (32, 4, 1, 16, 2, 1),
+     (32, 4, 1, 32, 1, 1)]),
N in [128]

# We'll only test square matrices.
M = K = N

name = "FPU GEMM $(A_type)*$(B_type)=$(CD_type) ($(M)×$(K)) · ($(K)×$(N)) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' )) OP ($(OP_M), $(OP_N), $(OP_K))"
name = "FPU GEMM $(A_type)*$(B_type)=$(CD_type) ($(M)×$(K)) · ($(K)×$(N)) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' )) OP ($(OP_M), $(OP_N), $(OP_K)), base shape ($(OP_MB), $(OP_NB), $(OP_KB))"
compute_type = promote_type(A_type, B_type)

conf = GemmKernels.get_config(
gemm_shape = (M = M, N = N, K = K),
block_shape = (M = 128, N = 64, K = 32),
-operator = Operator.FPUOp{OP_M, OP_N, OP_K, compute_type, CD_type},
+operator = Operator.FPUOp{OP_M, OP_N, OP_K, OP_MB, OP_NB, OP_KB, compute_type, CD_type},
global_a_layout = transpose_a ? Layout.UnsafeAlignedRowMajor{A_type} : Layout.UnsafeAlignedColMajor{A_type},
global_b_layout = transpose_b ? Layout.UnsafeAlignedRowMajor{B_type} : Layout.UnsafeAlignedColMajor{B_type},

@@ -208,20 +217,20 @@ function get_configs()
(Float32, Float32, Float32, 128)],
transpose_a = [false, true],
transpose_b = [false, true],
-(OP_M, OP_N, OP_K) in [(8, 16, 2)],
+(OP_M, OP_N, OP_K, OP_MB, OP_NB, OP_KB) in [(8, 16, 2, 4, 8, 1)],
(M, N, K) in min_dimension .* [
[1, 1, 1],
[2, 2, 1],
[1, 1, 2],
[2, 2, 2]]

name = "Tropical GEMM $(A_type)*$(B_type)=$(CD_type) ($(M)×$(K)) · ($(K)×$(N)) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' )) OP ($(OP_M), $(OP_N), $(OP_K))"
name = "Tropical GEMM $(A_type)*$(B_type)=$(CD_type) ($(M)×$(K)) · ($(K)×$(N)) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' )) OP ($(OP_M), $(OP_N), $(OP_K)), base shape ($(OP_MB), $(OP_NB), $(OP_KB))"
compute_type = promote_type(A_type, B_type)

conf = GemmKernels.get_config(
gemm_shape = (M = M, N = N, K = K),
block_shape = (M = 64, N = 64, K = 32),
-operator = Operator.TropicalFPUOp{OP_M, OP_N, OP_K, compute_type, CD_type},
+operator = Operator.TropicalFPUOp{OP_M, OP_N, OP_K, OP_MB, OP_NB, OP_KB, compute_type, CD_type},
global_a_layout = transpose_a ? Layout.UnsafeAlignedRowMajor{A_type} : Layout.UnsafeAlignedColMajor{A_type},
global_b_layout = transpose_b ? Layout.UnsafeAlignedRowMajor{B_type} : Layout.UnsafeAlignedColMajor{B_type},

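Every tuple in the new list above satisfies the constraints that handle_operator_config (added to src/config.jl below) enforces. A quick standalone sanity check, mirroring those checks:

    # Each tuple is (OP_M, OP_N, OP_K, OP_MB, OP_NB, OP_KB).
    for (M, N, K, mb, nb, kb) in [(4, 32, 1, 1, 32, 1), (4, 32, 1, 2, 16, 1),
                                  (16, 16, 1, 4, 8, 1), (16, 16, 1, 8, 4, 1),
                                  (32, 4, 1, 16, 2, 1), (32, 4, 1, 32, 1, 1)]
        @assert M * N >= 32                           # enough elements for all 32 lanes of a warp
        @assert mb * nb == 32                         # the base shape covers exactly one warp
        @assert kb == 1                               # only a base K of 1 is supported for now
        @assert all((M, N, K) .% (mb, nb, kb) .== 0)  # the op shape tiles evenly into base shapes
    end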
2 changes: 1 addition & 1 deletion examples/simple_matmul.jl
@@ -41,7 +41,7 @@ function main()

conf = GemmKernels.get_config(;
gemm_shape = (; M, N, K), block_shape,
-operator = Operator.FPUOp{8, 8, 1, compute_type, eltype(C)},
+operator = Operator.FPUOp{8, 8, 1, 4, 8, 1, compute_type, eltype(C)},

global_a_layout, global_b_layout, global_c_layout, global_d_layout,
shared_a_layout, shared_b_layout, shared_c_layout, shared_d_layout,
62 changes: 58 additions & 4 deletions src/config.jl
@@ -35,6 +35,12 @@
is_b_col_major
end

+struct ConfigError <: Exception
+    message::String
+end
+
+Base.showerror(io::IO, e::ConfigError) = print(io, "ConfigError: ", e.message)

function heuristic_block_shape(shared_a_layout, shared_b_layout, shared_c_layout, shared_d_layout)
# Determining the tile size of each block is a little trickier.
# We apply the following heuristics:
@@ -87,6 +93,27 @@ function adjacent_elements(num, parent_size, is_col_major)
return typeof(parent_size)(t)
end

+function handle_operator_config(operator)
+    op_shape = Operator.base_shape(operator)
+
+    # Each of the 32 threads in a warp must handle at least one element of the operator.
+    if op_shape.M * op_shape.N < 32
+        throw(ConfigError("The operator shape is too small. The dimensions of the operator shape must adhere to the following constraint: OPERATOR_M * OPERATOR_N ≥ 32."))
+    end
+
+    if op_shape.mb * op_shape.nb != 32
+        throw(ConfigError("The base FPU operator shape should adhere to the following constraint: OPERATOR_M_BASE * OPERATOR_N_BASE = 32."))
+    end
+
+    if op_shape.kb != 1
+        throw(ConfigError("The base FPU operator shape should adhere to the following constraint: OPERATOR_K_BASE = 1."))
+    end
+
+    if any((op_shape.M, op_shape.N, op_shape.K) .% (op_shape.mb, op_shape.nb, op_shape.kb) .!= 0)
+        throw(ConfigError("The operator shape should adhere to the following constraint: OPERATOR_M, OPERATOR_N, OPERATOR_K are multiples of OPERATOR_M_BASE, OPERATOR_N_BASE, OPERATOR_K_BASE, respectively."))
+    end
+end

function get_config(; gemm_shape, operator, global_a_layout, global_c_layout, kwargs...)
params = Dict(kwargs)

@@ -110,11 +137,38 @@ function get_config(; gemm_shape, operator, global_a_layout, global_c_layout, kw
block_shape = get(params, :block_shape,
heuristic_block_shape(shared_a_layout, shared_b_layout, shared_c_layout, shared_d_layout))

-# 8 warps in a 4 x 2 arrangement usually works well
-warps_per_block = get(params, :warps_per_block, 8)
+if block_shape.M * block_shape.K < 128 || block_shape.K * block_shape.N < 128 || block_shape.K < 8
+    throw(ConfigError("The block shape is too small. The dimensions of the block shape must adhere to the following constraints: BLOCK_M * BLOCK_K ≥ 128, BLOCK_K * BLOCK_N ≥ 128, BLOCK_K ≥ 8."))
+end

op_shape = Operator.shape(operator)
-compute_warp = get(params, :compute_warp,
-    (M = block_shape.M ÷ 4, N = block_shape.N ÷ 2, K = op_shape.K))

+if block_shape.M < 2 * op_shape.M || block_shape.N < 2 * op_shape.N
+    # TODO: Find out why this is.
+    throw(ConfigError("There is a mismatch between the block shape and the operator shape. Their dimensions must adhere to the following constraints: BLOCK_M ≥ 2 * OPERATOR_M, BLOCK_N ≥ 2 * OPERATOR_N."))
+end

+if operator <: Operator.GeneralFPUOp
+    handle_operator_config(operator)
+end

+# 8 warps in a 4 x 2 arrangement usually works well
+warps_per_block_default = 8
+compute_warp_default = (M = block_shape.M ÷ 4, N = block_shape.N ÷ 2, K = op_shape.K)

+# Best effort to make sure that the compute warp shape is not smaller than the operator shape.
+if (block_shape.M ÷ op_shape.M) < 4 || (block_shape.N ÷ op_shape.N) < 2
+    compute_warp_default = (
+        M = block_shape.M ÷ min(block_shape.M ÷ op_shape.M, 4),
+        N = block_shape.N ÷ min(block_shape.N ÷ op_shape.N, 2),
+        K = op_shape.K
+    )
+    warps_per_block_default = min(block_shape.M ÷ op_shape.M, 4) * min(block_shape.N ÷ op_shape.N, 2)
+end

+warps_per_block = get(params, :warps_per_block, warps_per_block_default)
+compute_warp = get(params, :compute_warp, compute_warp_default)


# Is the layout col-major or not? This is needed to find good values for mem_a_warp, mem_b_warp, etc.
# TODO: Let the layouts handle this?
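Worked through for a concrete case: with block_shape = (M = 64, N = 64, K = 32) and a (32, 4, 1) operator, the default 4 × 2 warp arrangement would give a 16 × 32 warp tile, smaller than the operator in M, so the fallback clamps the split. A sketch of the arithmetic above, with hypothetical numbers:

    block_shape = (M = 64, N = 64, K = 32)
    op_shape    = (M = 32, N = 4, K = 1)

    # 64 ÷ 32 = 2 < 4, so the M split is clamped from 4 down to 2:
    m_split = min(block_shape.M ÷ op_shape.M, 4)   # 2
    n_split = min(block_shape.N ÷ op_shape.N, 2)   # 2

    compute_warp_default = (M = block_shape.M ÷ m_split,   # 32 ≥ OPERATOR_M
                            N = block_shape.N ÷ n_split,   # 32 ≥ OPERATOR_N
                            K = op_shape.K)
    warps_per_block_default = m_split * n_split            # 4 warps instead of 8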
2 changes: 1 addition & 1 deletion src/matmul.jl
@@ -155,7 +155,7 @@ end
else
get_config(;
gemm_shape = (M = m, N = n, K = k), block_shape,
-operator = Operator.FPUOp{8, 8, 1, compute_type, eltype(C)},
+operator = Operator.FPUOp{8, 8, 1, 4, 8, 1, compute_type, eltype(C)},

global_a_layout, global_b_layout, global_c_layout, global_d_layout,
shared_a_layout, shared_b_layout, shared_c_layout, shared_d_layout,
88 changes: 48 additions & 40 deletions src/operator.jl
@@ -22,9 +22,12 @@ end

# CT is the compute type used to perform scalar operations in.
# AT is the accumulator type used to accumulate partial results.
-abstract type GeneralFPUOp{M, N, K, CT, AT} end
+# mb, nb, kb are the base operator shapes (kb must be equal to 1 for now).
+# M, N, K must be multiples of mb, nb, and kb respectively.
+abstract type GeneralFPUOp{M, N, K, mb, nb, kb, CT, AT} end

-@inline shape(::Type{<:GeneralFPUOp{M, N, K, CT, AT}}) where {M, N, K, CT, AT} = (M = M, N = N, K = K)
+@inline shape(::Type{<:GeneralFPUOp{M, N, K, mb, nb, kb, CT, AT}}) where {M, N, K, mb, nb, kb, CT, AT} = (M = M, N = N, K = K)
+@inline base_shape(::Type{<:GeneralFPUOp{M, N, K, mb, nb, kb, CT, AT}}) where {M, N, K, mb, nb, kb, CT, AT} = (M = M, N = N, K = K, mb = mb, nb = nb, kb = kb)

for (layout_type, convert_index_func) in [
(Layout.ColMajor, identity),
@@ -33,100 +36,105 @@ for (layout_type, convert_index_func) in [
(Layout.UnsafeAlignedRowMajor, x -> reverse(Tuple(x))),
]
@eval begin
-@inline fragtype_a(::Type{<:GeneralFPUOp{M, N, K, CT, AT}}, ::Type{$layout_type{DT}}) where {M, N, K, CT, AT, DT} = NTuple{M * K ÷ 4, CT}
-@inline fragtype_b(::Type{<:GeneralFPUOp{M, N, K, CT, AT}}, ::Type{$layout_type{DT}}) where {M, N, K, CT, AT, DT} = NTuple{K * N ÷ 8, CT}
+@inline function fragtype_a(::Type{<:GeneralFPUOp{M, N, K, mb, nb, kb, CT, AT}}, ::Type{$layout_type{DT}}) where {M, N, K, mb, nb, kb, CT, AT, DT}
+    return NTuple{M * K ÷ mb, CT}
+end
+@inline function fragtype_b(::Type{<:GeneralFPUOp{M, N, K, mb, nb, kb, CT, AT}}, ::Type{$layout_type{DT}}) where {M, N, K, mb, nb, kb, CT, AT, DT}
+    return NTuple{K * N ÷ nb, CT}
+end

-@inline function fragtype_accum(::Type{<:GeneralFPUOp{M, N, K, CT, AT}}, ::Type{$layout_type{DT}}) where {M, N, K, CT, AT, DT}
+@inline function fragtype_accum(::Type{<:GeneralFPUOp{M, N, K, mb, nb, kb, CT, AT}}, ::Type{$layout_type{DT}}) where {M, N, K, mb, nb, kb, CT, AT, DT}
return NTuple{M * N ÷ 32, AT}
end

-@inline function load_a(::Type{<:GeneralFPUOp{M, N, K, CT, AT}}, ::Type{$layout_type{DT}}, workspace, tile::Tile) where {M, N, K, CT, AT, DT}
+@inline function load_a(::Type{<:GeneralFPUOp{M, N, K, mb, nb, kb, CT, AT}}, ::Type{$layout_type{DT}}, workspace, tile::Tile) where {M, N, K, mb, nb, kb, CT, AT, DT}
laneId = (threadIdx().x - 1) % 32 + 1

-op_y = (laneId - 1) % 4 + 1
+op_y = (laneId - 1) % mb + 1
y, x = (tile.base.M + tile.offset.M + op_y, tile.base.K + tile.offset.K + 1)

-frag = LocalArray{Tuple{M ÷ 4, K}, CT}(undef)
-@loopinfo unroll for m = 1 : M ÷ 4
+frag = LocalArray{Tuple{M ÷ mb, K}, CT}(undef)
+@loopinfo unroll for m = 1 : M ÷ mb
@loopinfo unroll for k = 1 : K
-y_layout, x_layout = $convert_index_func((y + 4 * (m - 1), x + (k - 1)))
+y_layout, x_layout = $convert_index_func((y + mb * (m - 1), x + (k - 1)))
@inbounds @immutable frag[m,k] = workspace[y_layout, x_layout]
end
end

-return NTuple{M * K ÷ 4, CT}(frag)
+return NTuple{M * K ÷ mb, CT}(frag)
end

-@inline function load_b(::Type{<:GeneralFPUOp{M, N, K, CT, AT}}, ::Type{$layout_type{DT}}, workspace, tile::Tile) where {M, N, K, CT, AT, DT}
+@inline function load_b(::Type{<:GeneralFPUOp{M, N, K, mb, nb, kb, CT, AT}}, ::Type{$layout_type{DT}}, workspace, tile::Tile) where {M, N, K, mb, nb, kb, CT, AT, DT}
laneId = (threadIdx().x - 1) % 32 + 1

-op_x = (laneId - 1) ÷ 4 + 1
+op_x = (laneId - 1) ÷ mb + 1
y, x = (tile.base.K + tile.offset.K + 1, tile.base.N + tile.offset.N + op_x)

-frag = LocalArray{Tuple{K, N ÷ 8}, CT}(undef)
-@loopinfo unroll for n = 1 : N ÷ 8
+frag = LocalArray{Tuple{K, N ÷ nb}, CT}(undef)
+@loopinfo unroll for n = 1 : N ÷ nb
@loopinfo unroll for k = 1 : K
-y_layout, x_layout = $convert_index_func((y + (k - 1), x + 8 * (n - 1)))
+y_layout, x_layout = $convert_index_func((y + (k - 1), x + nb * (n - 1)))
@inbounds @immutable frag[k,n] = workspace[y_layout, x_layout]
end
end

-return NTuple{K * N ÷ 8, CT}(frag)
+return NTuple{K * N ÷ nb, CT}(frag)
end

-@inline function load_c(::Type{<:GeneralFPUOp{M, N, K, CT, AT}}, ::Type{$layout_type{DT}}, workspace, tile::Tile) where {M, N, K, CT, AT, DT}
+@inline function load_c(::Type{<:GeneralFPUOp{M, N, K, mb, nb, kb, CT, AT}}, ::Type{$layout_type{DT}}, workspace, tile::Tile) where {M, N, K, mb, nb, kb, CT, AT, DT}
laneId = (threadIdx().x - 1) % 32 + 1

-op_y = (laneId - 1) % 4 + 1
-op_x = (laneId - 1) ÷ 4 + 1
+op_y = (laneId - 1) % mb + 1
+op_x = (laneId - 1) ÷ mb + 1

y, x = (tile.base.M + tile.offset.M + op_y, tile.base.N + tile.offset.N + op_x)

-frag = LocalArray{Tuple{M ÷ 4, N ÷ 8}, AT}(undef)
-@loopinfo unroll for m = 1 : M ÷ 4
-@loopinfo unroll for n = 1 : N ÷ 8
-@inbounds @immutable frag[m,n] = workspace[y + 4 * (m - 1), x + 8 * (n - 1)]
+frag = LocalArray{Tuple{M ÷ mb, N ÷ nb}, AT}(undef)
+@loopinfo unroll for m = 1 : M ÷ mb
+@loopinfo unroll for n = 1 : N ÷ nb
+@inbounds @immutable frag[m,n] = workspace[y + mb * (m - 1), x + nb * (n - 1)]
end
end

return NTuple{M * N ÷ 32, AT}(frag)
end

-@inline function store_d(::Type{<:GeneralFPUOp{M, N, K, CT, AT}}, ::Type{$layout_type{DT}}, workspace, frag, tile::Tile) where {M, N, K, CT, AT, DT}
+@inline function store_d(::Type{<:GeneralFPUOp{M, N, K, mb, nb, kb, CT, AT}}, ::Type{$layout_type{DT}}, workspace, frag, tile::Tile) where {M, N, K, mb, nb, kb, CT, AT, DT}
laneId = (threadIdx().x - 1) % 32 + 1

-op_y = (laneId - 1) % 4 + 1
-op_x = (laneId - 1) ÷ 4 + 1
+op_y = (laneId - 1) % mb + 1
+op_x = (laneId - 1) ÷ mb + 1

y, x = (tile.base.M + tile.offset.M + op_y, tile.base.N + tile.offset.N + op_x)

-frag = LocalArray{Tuple{M ÷ 4, N ÷ 8}, AT}(frag)
-@loopinfo unroll for m = 1 : M ÷ 4
-@loopinfo unroll for n = 1 : N ÷ 8
-@inbounds workspace[y + 4 * (m - 1), x + 8 * (n - 1)] = frag[m, n]
+frag = LocalArray{Tuple{M ÷ mb, N ÷ nb}, AT}(frag)
+@loopinfo unroll for m = 1 : M ÷ mb
+@loopinfo unroll for n = 1 : N ÷ nb
+@inbounds workspace[y + mb * (m - 1), x + nb * (n - 1)] = frag[m,n]
end
end
end
end
end

-abstract type FPUOp{M, N, K, CT, AT} <: GeneralFPUOp{M, N, K, CT, AT} end
-function operator_fma(::Type{FPUOp{M, N, K, CT, AT}}, a::CT, b::CT, c::AT) where {M, N, K, CT, AT}
+abstract type FPUOp{M, N, K, mb, nb, kb, CT, AT} <: GeneralFPUOp{M, N, K, mb, nb, kb, CT, AT} end
+function operator_fma(::Type{FPUOp{M, N, K, mb, nb, kb, CT, AT}}, a::CT, b::CT, c::AT) where {M, N, K, mb, nb, kb, CT, AT}
return fma(a, b, c)
end

-abstract type TropicalFPUOp{M, N, K, CT, AT} <: GeneralFPUOp{M, N, K, CT, AT} end
+abstract type TropicalFPUOp{M, N, K, CT, AT} <: GeneralFPUOp{M, N, K, 4, 8, 1, CT, AT} end
function operator_fma(::Type{TropicalFPUOp{M, N, K, CT, AT}}, a::CT, b::CT, c::AT) where {M, N, K, CT, AT}
return max(a + b, c)
end

-@inline function mma(operator_type::Type{<:GeneralFPUOp{M, N, K, CT, AT}}, a_frag, b_frag, c_frag) where {M, N, K, CT, AT}
-a_frag = LocalArray{Tuple{M ÷ 4, K}, CT}(a_frag)
-b_frag = LocalArray{Tuple{K, N ÷ 8}, CT}(b_frag)
-c_frag = LocalArray{Tuple{M ÷ 4, N ÷ 8}, AT}(c_frag)
+@inline function mma(operator_type::Type{<:GeneralFPUOp{M, N, K, mb, nb, kb, CT, AT}}, a_frag, b_frag, c_frag) where {M, N, K, mb, nb, kb, CT, AT}
+a_frag = LocalArray{Tuple{M ÷ mb, K}, CT}(a_frag)
+b_frag = LocalArray{Tuple{K, N ÷ nb}, CT}(b_frag)
+c_frag = LocalArray{Tuple{M ÷ mb, N ÷ nb}, AT}(c_frag)

-@loopinfo unroll for m = 1 : M ÷ 4
-@loopinfo unroll for n = 1 : N ÷ 8
+@loopinfo unroll for m = 1 : M ÷ mb
+@loopinfo unroll for n = 1 : N ÷ nb
@loopinfo unroll for k = 1 : K
@inbounds @immutable c_frag[m,n] = operator_fma(operator_type, a_frag[m, k], b_frag[k, n], c_frag[m, n])
end
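To see how the base shape partitions a warp: each lane derives its row and column inside the mb × nb base tile from its lane ID, column-major across lanes, exactly as in the op_y / op_x computations above. A standalone sketch (lane_position is a hypothetical helper, not part of the package):

    function lane_position(laneId, mb)
        op_y = (laneId - 1) % mb + 1   # row within the mb × nb base tile
        op_x = (laneId - 1) ÷ mb + 1   # column within the base tile
        return (op_y, op_x)
    end

    lane_position(1, 4)    # (1, 1)
    lane_position(5, 4)    # (1, 2): the default (4, 8) base shape wraps every 4 lanes
    lane_position(5, 16)   # (5, 1): a (16, 2) base shape wraps every 16 lanes

Each thread then holds M ÷ mb × K elements of A, K × N ÷ nb elements of B, and (M ÷ mb) × (N ÷ nb) = M × N ÷ 32 accumulators, matching the fragtype_* definitions above.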
