diff --git a/src/GemmKernels.jl b/src/GemmKernels.jl index 60815e18..a6073d62 100644 --- a/src/GemmKernels.jl +++ b/src/GemmKernels.jl @@ -9,11 +9,11 @@ include("array.jl") include("utils.jl") # framework +include("layout.jl") +include("operator.jl") include("config.jl") include("epilogue.jl") include("kernel.jl") -include("layout.jl") -include("operator.jl") include("transform.jl") # instantiations diff --git a/src/config.jl b/src/config.jl index b8dda0db..1bb6ee79 100644 --- a/src/config.jl +++ b/src/config.jl @@ -35,7 +35,7 @@ is_b_col_major end -struct ConfigError <: Exception +struct ConfigError <: Exception message::String end @@ -93,7 +93,9 @@ function adjacent_elements(num, parent_size, is_col_major) return typeof(parent_size)(t) end -function handle_operator_config(operator) +check_operator_config(::Type{T}) where {T} = nothing + +function check_operator_config(::Type{<:Operator.GeneralFPUOp}) op_shape = Operator.base_shape(operator) # The 32 threads in a warp must at least handle one element of the operator. @@ -148,9 +150,7 @@ function get_config(; gemm_shape, operator, global_a_layout, global_c_layout, kw throw(ConfigError("There is a mismatch between the block shape and the operator shape. Their dimensions must adhere to the following constraints: BLOCK_M ≥ 2 * OPERATOR_M, BLOCK_N ≥ 2 * OPERATOR_N.")) end - if operator <: Operator.GeneralFPUOp - handle_operator_config(operator) - end + check_operator_config(operator) # 8 warps in a 4 x 2 arrangement usually works well warps_per_block_default = 8 @@ -162,7 +162,7 @@ function get_config(; gemm_shape, operator, global_a_layout, global_c_layout, kw M = block_shape.M ÷ min(block_shape.M ÷ op_shape.M, 4), N = block_shape.N ÷ min(block_shape.N ÷ op_shape.N, 2), K = op_shape.K - ) + ) warps_per_block_default = min(block_shape.M ÷ op_shape.M, 4) * min(block_shape.N ÷ op_shape.N, 2) end