Skip to content

Commit

Permalink
Use a zero layout for C in shared memory if beta=0 (#194)
Browse files Browse the repository at this point in the history
While we already avoid a global load from C in case beta == 0, we still
emit stores to shared memory and loads from shared memory for C.

Instead, we should also use a zero layout for C in shared memory, which
eliminates these extra loads and stores.

This does not seem to influence the performance of GEMM, even for small
matrices, or highly rectangular GEMMs with small K, but it does make a
difference for some TCs, I've noticed, so let's do this, anyway.
  • Loading branch information
thomasfaingnaert authored May 6, 2024
1 parent c98fa10 commit 4339926
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 2 deletions.
2 changes: 1 addition & 1 deletion configs/configs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ macro get_wmma_config()

shared_a_layout = Layout.Padded{transpose_a ? Layout.UnsafeAlignedRowMajor{AB_type} : Layout.UnsafeAlignedColMajor{AB_type}, 16 ÷ sizeof(AB_type)},
shared_b_layout = Layout.Padded{transpose_b ? Layout.UnsafeAlignedRowMajor{AB_type} : Layout.UnsafeAlignedColMajor{AB_type}, 16 ÷ sizeof(AB_type)},
shared_c_layout = Layout.UnsafeAlignedColMajor{CD_type},
shared_c_layout = zero_c ? Layout.Zero{CD_type} : Layout.UnsafeAlignedColMajor{CD_type},
shared_d_layout = Layout.UnsafeAlignedColMajor{CD_type},

operator = Operator.WMMAOp{OP_M, OP_N, OP_K, AB_type, CD_type},
Expand Down
7 changes: 6 additions & 1 deletion src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,12 @@ end
shared_b_layout = Layout.Padded{b_aligned_layout_base{eltype(B)}, 8}
end
## outputs are never transposed, and padding them doesn't seem worth it
shared_c_layout = shared_d_layout = Layout.UnsafeAlignedColMajor{eltype(C)}
shared_c_layout = if zeroBeta
Layout.Zero{eltype(C)}
else
Layout.UnsafeAlignedColMajor{eltype(C)}
end
shared_d_layout = Layout.UnsafeAlignedColMajor{eltype(C)}

# determine block shape
# XXX: heuristic should take much more into account (GEMM size, at least)
Expand Down
13 changes: 13 additions & 0 deletions src/operator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ for f in (:fragtype_a, :fragtype_b, :fragtype_accum, :load_a, :load_b, :load_c,
@eval @inline $f(op, ::Type{Layout.Padded{L, P}}, args...) where {L, P} = $f(op, L, args...)
end

# -----------------------------------
# Default definition for zero layouts
# -----------------------------------

for f in (:fragtype_a, :fragtype_b, :fragtype_accum)
@eval @inline $f(op, ::Type{Layout.Zero{T}}, args...) where {T} = $f(op, Layout.UnsafeAlignedColMajor{T}, args...)
end


# ---
# FPU
Expand Down Expand Up @@ -226,6 +234,11 @@ for (layout_type, wmma_layout_type, convert_index_func) in [
end
end

@inline function load_c(::Type{WMMAOp{M, N, K, CT, AT}}, ::Type{Layout.Zero{AT}}, workspace, tile::Tile) where {M, N, K, CT, AT}
conf = WMMA.Config{M, N, K, AT}
return WMMA.fill_c(zero(AT), conf)
end

function mma(::Type{WMMAOp{M, N, K, CT, AT}}, a_frag, b_frag, c_frag) where {M, N, K, CT, AT}
conf = WMMA.Config{M, N, K, AT}
return WMMA.mma(a_frag, b_frag, c_frag, conf)
Expand Down

0 comments on commit 4339926

Please sign in to comment.