Skip to content

Commit

Permalink
Use multiple dispatch to check operator
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasfaingnaert committed Oct 24, 2023
1 parent 530b7a7 commit 80bacb1
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions src/GemmKernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ include("array.jl")
include("utils.jl")

# framework
include("layout.jl")
include("operator.jl")
include("config.jl")
include("epilogue.jl")
include("kernel.jl")
include("layout.jl")
include("operator.jl")
include("transform.jl")

# instantiations
Expand Down
12 changes: 6 additions & 6 deletions src/config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
is_b_col_major
end

struct ConfigError <: Exception
struct ConfigError <: Exception
message::String
end

Expand Down Expand Up @@ -93,7 +93,9 @@ function adjacent_elements(num, parent_size, is_col_major)
return typeof(parent_size)(t)
end

function handle_operator_config(operator)
check_operator_config(::Type{T}) where {T} = nothing

function check_operator_config(::Type{<:Operator.GeneralFPUOp})
op_shape = Operator.base_shape(operator)

# The 32 threads in a warp must at least handle one element of the operator.
Expand Down Expand Up @@ -148,9 +150,7 @@ function get_config(; gemm_shape, operator, global_a_layout, global_c_layout, kw
throw(ConfigError("There is a mismatch between the block shape and the operator shape. Their dimensions must adhere to the following constraints: BLOCK_M ≥ 2 * OPERATOR_M, BLOCK_N ≥ 2 * OPERATOR_N."))
end

if operator <: Operator.GeneralFPUOp
handle_operator_config(operator)
end
check_operator_config(operator)

# 8 warps in a 4 x 2 arrangement usually works well
warps_per_block_default = 8
Expand All @@ -162,7 +162,7 @@ function get_config(; gemm_shape, operator, global_a_layout, global_c_layout, kw
M = block_shape.M ÷ min(block_shape.M ÷ op_shape.M, 4),
N = block_shape.N ÷ min(block_shape.N ÷ op_shape.N, 2),
K = op_shape.K
)
)
warps_per_block_default = min(block_shape.M ÷ op_shape.M, 4) * min(block_shape.N ÷ op_shape.N, 2)
end

Expand Down

0 comments on commit 80bacb1

Please sign in to comment.