Skip to content

Commit

Permalink
Check tile sizes in config
Browse files Browse the repository at this point in the history
Extracted from #179
  • Loading branch information
thomasfaingnaert committed Dec 7, 2023
1 parent c84a5ac commit 7e244b0
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions src/config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,11 @@ check_operator_config(operator::Type{<:Operator.WMMAOp}) = check_wmma_shape(oper
check_operator_config(operator::Type{<:Operator.WMMAComplexOp}) = check_wmma_shape(operator)
check_operator_config(operator::Type{<:Operator.WMMADualOp}) = check_wmma_shape(operator)

require_tile_sized_global(layout) = true
require_tile_sized_global(::Type{<:Layout.Zero{T}}) where {T} = false
require_tile_sized_global(::Type{<:Layout.ColMajor{T}}) where {T} = false
require_tile_sized_global(::Type{<:Layout.RowMajor{T}}) where {T} = false

function get_config(; gemm_shape, operator, global_a_layout, global_c_layout, kwargs...)
params = Dict(kwargs)

Expand Down Expand Up @@ -215,6 +220,15 @@ function get_config(; gemm_shape, operator, global_a_layout, global_c_layout, kw
prod(mem_b_warp) * warps_per_block block_shape.K * block_shape.N || throw(ConfigError("mem_b_warp is too big for the selected block shape: need at least one iteration in the memory copy loop!"))
prod(mem_cd_warp) * warps_per_block block_shape.M * block_shape.N || throw(ConfigError("mem_cd_warp is too big for the selected block shape: need at least one iteration in the memory copy loop!"))

# Check sizes of tiles
check_tile_multiple(num, den, dims, msg) = all([num[dim] % den[dim] == 0 for dim in dims]) || throw(ConfigError(msg))

check_tile_multiple(block_shape, compute_warp, [:M, :N, :K], "block_shape must be a multiple of compute_warp!")
require_tile_sized_global(global_a_layout) && check_tile_multiple(gemm_shape, block_shape, [:M, :K], "gemm_shape.MK must be a multiple of block_shape.MK!")
require_tile_sized_global(global_b_layout) && check_tile_multiple(gemm_shape, block_shape, [:K, :N], "gemm_shape.KN must be a multiple of block_shape.KN!")
require_tile_sized_global(global_c_layout) && check_tile_multiple(gemm_shape, block_shape, [:M, :N], "gemm_shape.MN must be a multiple of block_shape.MN!")
require_tile_sized_global(global_d_layout) && check_tile_multiple(gemm_shape, block_shape, [:M, :N], "gemm_shape.MN must be a multiple of block_shape.MN!")

return Config(
#= Params =#
gemm_shape,
Expand Down

0 comments on commit 7e244b0

Please sign in to comment.