Skip to content

Commit

Permalink
Improve heuristic for memcopy tile sizes (#172)
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasfaingnaert authored Nov 14, 2023
1 parent 8a73476 commit 433aa68
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions src/config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,19 +192,28 @@ function get_config(; gemm_shape, operator, global_a_layout, global_c_layout, kw
# 1) The tiles should encompass 128 bits (16 bytes) to enable vectorisation.
# 2) The tiles should be as small as possible (i.e. each thread exactly 128 bits) to enable coalescing.

num_elems_per_thread_a = min(16 ÷ sizeof(Layout.eltype(global_a_layout)), (block_shape.M * block_shape.K) ÷ (32 * warps_per_block))
num_elems_per_thread_b = min(16 ÷ sizeof(Layout.eltype(global_b_layout)), (block_shape.K * block_shape.N) ÷ (32 * warps_per_block))
num_elems_per_thread_c = min(16 ÷ sizeof(Layout.eltype(global_c_layout)), (block_shape.M * block_shape.N) ÷ (32 * warps_per_block))

mem_a_warp = get(params, :mem_a_warp,
adjacent_elements(32 * 16 ÷ sizeof(Layout.eltype(global_a_layout)), (M = block_shape.M, K = block_shape.K), is_a_col_major))
adjacent_elements(32 * num_elems_per_thread_a, (M = block_shape.M, K = block_shape.K), is_a_col_major))
mem_b_warp = get(params, :mem_b_warp,
adjacent_elements(32 * 16 ÷ sizeof(Layout.eltype(global_b_layout)), (K = block_shape.K, N = block_shape.N), is_b_col_major))
adjacent_elements(32 * num_elems_per_thread_b, (K = block_shape.K, N = block_shape.N), is_b_col_major))
mem_cd_warp = get(params, :mem_cd_warp,
adjacent_elements(32 * 16 ÷ sizeof(Layout.eltype(global_c_layout)), (M = block_shape.M, N = block_shape.N), is_cd_col_major))
adjacent_elements(32 * num_elems_per_thread_c, (M = block_shape.M, N = block_shape.N), is_cd_col_major))

mem_a_thread = get(params, :mem_a_thread,
adjacent_elements(16 ÷ sizeof(Layout.eltype(global_a_layout)), (M = block_shape.M, K = block_shape.K), is_a_col_major))
adjacent_elements(num_elems_per_thread_a, (M = block_shape.M, K = block_shape.K), is_a_col_major))
mem_b_thread = get(params, :mem_b_thread,
adjacent_elements(16 ÷ sizeof(Layout.eltype(global_b_layout)), (K = block_shape.K, N = block_shape.N), is_b_col_major))
adjacent_elements(num_elems_per_thread_b, (K = block_shape.K, N = block_shape.N), is_b_col_major))
mem_cd_thread = get(params, :mem_cd_thread,
adjacent_elements(16 ÷ sizeof(Layout.eltype(global_c_layout)), (M = block_shape.M, N = block_shape.N), is_cd_col_major))
adjacent_elements(num_elems_per_thread_c, (M = block_shape.M, N = block_shape.N), is_cd_col_major))

# Make sure that we have at least one iteration in the memory copy loops.
prod(mem_a_warp) * warps_per_block block_shape.M * block_shape.K || throw(ConfigError("mem_a_warp is too big for the selected block shape: need at least one iteration in the memory copy loop!"))
prod(mem_b_warp) * warps_per_block block_shape.K * block_shape.N || throw(ConfigError("mem_b_warp is too big for the selected block shape: need at least one iteration in the memory copy loop!"))
prod(mem_cd_warp) * warps_per_block block_shape.M * block_shape.N || throw(ConfigError("mem_cd_warp is too big for the selected block shape: need at least one iteration in the memory copy loop!"))

return Config(
#= Params =#
Expand Down

0 comments on commit 433aa68

Please sign in to comment.