Skip to content

Commit

Permalink
Apply formatter
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Mar 8, 2024
1 parent 4852fe5 commit 77dbc0b
Show file tree
Hide file tree
Showing 11 changed files with 256 additions and 235 deletions.
20 changes: 10 additions & 10 deletions perf/flame.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,21 @@ include(joinpath(pkgdir(MBF), "test", "utils.jl"))
has_cuda = CUDA.has_cuda()
AType = has_cuda ? CUDA.CuArray : Array
# arr_size = (prod((50,5,5,6,50)),)
arr_size = (50,5,5,6,50)
arr_size = (50, 5, 5, 6, 50)
X = get_arrays(:x, arr_size, AType)
Y = get_arrays(:y, arr_size, AType)

"""
    perf_kernel_fused!(X, Y)

Write the seven sliding four-term sums `x[k] + x[k+1] + x[k+2] + x[k+3]`
(`k = 1…7`) of the fields `x1`…`x10` of `X` into the fields `y1`…`y7` of `Y`,
submitting all seven broadcast assignments through one `MBF.@fused` block.
"""
function perf_kernel_fused!(X, Y)
    # Every field is read up front; `y8`–`y10` are unpacked but never written.
    x1, x2, x3, x4, x5 = X.x1, X.x2, X.x3, X.x4, X.x5
    x6, x7, x8, x9, x10 = X.x6, X.x7, X.x8, X.x9, X.x10
    y1, y2, y3, y4, y5 = Y.y1, Y.y2, Y.y3, Y.y4, Y.y5
    y6, y7, y8, y9, y10 = Y.y6, Y.y7, Y.y8, Y.y9, Y.y10
    # The statements must stay in `@.` form inside the fused block.
    return MBF.@fused begin
        @. y1 = x1 + x2 + x3 + x4
        @. y2 = x2 + x3 + x4 + x5
        @. y3 = x3 + x4 + x5 + x6
        @. y4 = x4 + x5 + x6 + x7
        @. y5 = x5 + x6 + x7 + x8
        @. y6 = x6 + x7 + x8 + x9
        @. y7 = x7 + x8 + x9 + x10
    end
end

Expand Down
4 changes: 2 additions & 2 deletions src/MultiBroadcastFusion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ abstract type AbstractFusedMultiBroadcast end
A multi-broadcast fusion object
"""
struct FusedMultiBroadcast{T} <: AbstractFusedMultiBroadcast
    # Destination/broadcast pairs to be copied together; `@fused` passes the
    # tuple of `Pair`s it assembles from the `@.` statements of its block.
    pairs::T
end

# Base.@propagate_inbounds function rcopyto_at!(pair::Pair, i::CartesianIndex)
Expand All @@ -25,7 +25,7 @@ end
# return nothing
# end
"""
    rcopyto_at!(pair::Pair, i...)

Copy the element at index `i...` of `pair.second` into the same index of
`pair.first`. The access is not bounds-checked. Returns `nothing`.
"""
Base.@propagate_inbounds function rcopyto_at!(pair::Pair, i...)
    # `first` is the destination, `second` is the source.
    @inbounds pair.first[i...] = pair.second[i...]
    return nothing
end
Expand Down
105 changes: 52 additions & 53 deletions src/macro_utils.jl
Original file line number Diff line number Diff line change
@@ -1,83 +1,82 @@

"""
    materialize_args(expr::Expr)

Return the two arguments `(dest, bc)` of a `Base.materialize!(dest, bc)` call
expression. Both the `:call` head and the callee are checked with `@assert`.
"""
function materialize_args(expr::Expr)
    @assert expr.head == :call
    @assert expr.args[1] == :(Base.materialize!)
    dest, bc = expr.args[2], expr.args[3]
    return (dest, bc)
end

# Placeholder method: accepts any argument, does nothing, and returns `nothing`.
fused(expr) = nothing

"""
    @fused begin ... end

Gather the `@.` statements of the block `expr` into a tuple of
`Pair(dest, broadcast)` (built by `fused_pairs`), then call `Base.copyto!` on a
`FusedMultiBroadcast` wrapping that tuple.
"""
macro fused(expr)
    pairs_var = gensym()
    # The pair tuple is escaped so it is evaluated in the caller's scope.
    pairs_expr = esc(fused_pairs(expr))
    return quote
        $pairs_var = $pairs_expr
        Base.copyto!(FusedMultiBroadcast($pairs_var))
    end
end

"""
    @fused_pairs begin ... end

Expand to the tuple expression that `fused_pairs` builds from the block `expr`.
The result is escaped, so it is evaluated in the caller's scope.
"""
macro fused_pairs(expr)
    pairs_tuple = fused_pairs(expr)
    return esc(pairs_tuple)
end

"""
    _fused_pairs(expr::Expr)

Return, as a `String`, the source text of a tuple holding one
`Pair(dest, broadcast)` entry per `@.` statement found in the `:block`
expression `expr`. `LineNumberNode`s and statements that are not `@.` macro
calls are skipped. A block with no `@.` statement yields `"()"`.
"""
function _fused_pairs(expr::Expr)
    @assert expr.head == :block
    pair_exprs = []
    for _expr in expr.args
        # TODO: should we retain LineNumberNode?
        _expr isa LineNumberNode && continue
        @assert _expr isa Expr
        is_dot_call =
            _expr.head == :macrocall && _expr.args[1] == Symbol("@__dot__")
        is_dot_call || continue
        # Lower the statement to one `Base.materialize!(dest, bc)` call and split it.
        dest, bc = materialize_args(code_lowered_single_expression(_expr))
        push!(pair_exprs, :(Pair($dest, $bc)))
    end
    joined = join(pair_exprs, ",")
    # A one-element tuple needs a trailing comma to parse as a tuple.
    return length(pair_exprs) == 1 ? "(" * joined * ",)" : "(" * joined * ")"
end

# Parse the tuple source text produced by `_fused_pairs` back into an expression.
function fused_pairs(expr::Expr)
    tuple_src = _fused_pairs(expr)
    return Meta.parse(tuple_src)
end

# One-argument form: wrap the pairs of `expr` in the default
# `MultiBroadcastFusion.FusedMultiBroadcast` constructor.
macro fused_multibroadcast(expr)
    default_type = "MultiBroadcastFusion.FusedMultiBroadcast"
    return esc(fused_multibroadcast(default_type, expr))
end

# Two-argument form: `fmb` names the constructor that receives the pairs of `expr`.
macro fused_multibroadcast(fmb, expr)
    call_expr = fused_multibroadcast(fmb, expr)
    return esc(call_expr)
end
# Build the call expression `fmb(<pairs tuple>)` by printing `fmb` in front of
# the tuple text from `_fused_pairs` and parsing the combined string.
function fused_multibroadcast(fmb, expr::Expr)
    return Meta.parse(string(fmb, "(", _fused_pairs(expr), ")"))
end

"""
    build_expr(s::String, code_remain)
    build_expr(code::Vector)

Inline SSA references in `s`: each `%N` (with `N` one or more digits) is
replaced by `string(code_remain[N])`, repeatedly, until no `%N` remains.
A `%` that is not followed by a digit is left untouched.

The one-argument method starts from the printed form of the last entry of
`code` and resolves references against `code` itself.
"""
function build_expr(s::String, code_remain)
    ssa_ref = r"%([0-9]+)"
    m = match(ssa_ref, s)
    while m !== nothing
        j = parse(Int, m.captures[1])
        # Replace this exact reference only: the lookahead stops `%1` from
        # also rewriting the front of `%10`, `%11`, ...
        exact_ref = Regex(m.match * "(?![0-9])")
        s = replace(s, exact_ref => string(code_remain[j]))
        m = match(ssa_ref, s)
    end
    return s
end

build_expr(code::Vector) = build_expr(string(code[end]), code)

"""
    code_lowered_single_expression(expr)

Lower `expr` in `Main`, collapse the lowered statements into a single
expression string with `build_expr`, drop a leading `return `, and parse the
result back into an expression.
"""
function code_lowered_single_expression(expr)
    lowered = Base.Meta.lower(Main, expr)
    code_info = lowered.args[1]
    s = build_expr(code_info.code) # `code` is the vector of lowered statements
    if startswith(s, "return ")
        # Strip the prefix only; a later "return " belongs to the expression
        # itself (for example inside a string literal) and must survive.
        s = replace(s, "return " => ""; count = 1)
    end
    return Base.Meta.parse(s)
end
80 changes: 42 additions & 38 deletions test/bm_fused_reads_vs_hard_coded.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@ include("utils.jl")
# Entry point: forward to the method selected by `device(X.x1)`.
function perf_kernel_hard_coded!(X, Y)
    backend = device(X.x1)
    return perf_kernel_hard_coded!(X, Y, backend)
end

# CPU method: one hand-written loop over `eachindex(x1)` that writes the seven
# sliding four-term sums of `x1`…`x10` into `y1`…`y7`. Bounds checks are disabled.
function perf_kernel_hard_coded!(X, Y, ::CPU)
    x1, x2, x3, x4, x5 = X.x1, X.x2, X.x3, X.x4, X.x5
    x6, x7, x8, x9, x10 = X.x6, X.x7, X.x8, X.x9, X.x10
    y1, y2, y3, y4, y5 = Y.y1, Y.y2, Y.y3, Y.y4, Y.y5
    y6, y7, y8, y9, y10 = Y.y6, Y.y7, Y.y8, Y.y9, Y.y10
    @inbounds for i in eachindex(x1)
        y1[i] = x1[i] + x2[i] + x3[i] + x4[i]
        y2[i] = x2[i] + x3[i] + x4[i] + x5[i]
        y3[i] = x3[i] + x4[i] + x5[i] + x6[i]
        y4[i] = x4[i] + x5[i] + x6[i] + x7[i]
        y5[i] = x5[i] + x6[i] + x7[i] + x8[i]
        y6[i] = x6[i] + x7[i] + x8[i] + x9[i]
        y7[i] = x7[i] + x8[i] + x9[i] + x10[i]
    end
    return nothing
end
function perf_kernel_hard_coded!(X, Y, ::GPU)
Expand All @@ -25,21 +25,25 @@ function perf_kernel_hard_coded!(X, Y, ::GPU)
max_threads = 256 # can be higher if conditions permit
nthreads = min(max_threads, nitems)
nblocks = cld(nitems, nthreads)
CUDA.@cuda threads = (nthreads) blocks = (nblocks) knl_multi_copyto_hard_coded!(X, Y, Val(nitems))
CUDA.@cuda threads = (nthreads) blocks = (nblocks) knl_multi_copyto_hard_coded!(
X,
Y,
Val(nitems),
)
end
# CUDA kernel: each thread computes one linear index `gidx` and writes the
# seven sliding four-term sums of `x1`…`x10` into `y1`…`y7` at that index.
# `nitems` is the number of elements to process, passed as a `Val`.
function knl_multi_copyto_hard_coded!(X, Y, ::Val{nitems}) where {nitems}
    (; x1, x2, x3, x4, x5, x6, x7, x8, x9, x10) = X
    (; y1, y2, y3, y4, y5, y6, y7, y8, y9, y10) = Y
    gidx = CUDA.threadIdx().x + (CUDA.blockIdx().x - 1) * CUDA.blockDim().x
    # CUDA.jl thread and block indices are 1-based, so `gidx` runs over
    # 1:nitems. The bound is inclusive; `<` would skip the last element.
    if gidx <= nitems
        y1[gidx] = x1[gidx] + x2[gidx] + x3[gidx] + x4[gidx]
        y2[gidx] = x2[gidx] + x3[gidx] + x4[gidx] + x5[gidx]
        y3[gidx] = x3[gidx] + x4[gidx] + x5[gidx] + x6[gidx]
        y4[gidx] = x4[gidx] + x5[gidx] + x6[gidx] + x7[gidx]
        y5[gidx] = x5[gidx] + x6[gidx] + x7[gidx] + x8[gidx]
        y6[gidx] = x6[gidx] + x7[gidx] + x8[gidx] + x9[gidx]
        y7[gidx] = x7[gidx] + x8[gidx] + x9[gidx] + x10[gidx]
    end
    return nothing
end
Expand All @@ -48,37 +52,37 @@ end

has_cuda = CUDA.has_cuda()
AType = has_cuda ? CUDA.CuArray : Array
arr_size = (prod((50,5,5,6,50)),)
arr_size = (prod((50, 5, 5, 6, 50)),)
# arr_size = (50,5,5,6,50)
X = get_arrays(:x, arr_size, AType);
Y = get_arrays(:y, arr_size, AType);

"""
    perf_kernel_unfused!(X, Y)

Write the seven sliding four-term sums of the fields `x1`…`x10` of `X` into
the fields `y1`…`y7` of `Y`, one separate broadcast assignment per output.
Returns `nothing`.
"""
function perf_kernel_unfused!(X, Y)
    (; x1, x2, x3, x4, x5, x6, x7, x8, x9, x10) = X
    (; y1, y2, y3, y4, y5, y6, y7, y8, y9, y10) = Y
    # 7 writes; 10 unique reads
    # 7 writes; 28 reads including redundant ones
    y1 .= x1 .+ x2 .+ x3 .+ x4
    y2 .= x2 .+ x3 .+ x4 .+ x5
    y3 .= x3 .+ x4 .+ x5 .+ x6
    y4 .= x4 .+ x5 .+ x6 .+ x7
    y5 .= x5 .+ x6 .+ x7 .+ x8
    y6 .= x6 .+ x7 .+ x8 .+ x9
    y7 .= x7 .+ x8 .+ x9 .+ x10
    return nothing
end

"""
    perf_kernel_fused!(X, Y)

Write the seven sliding four-term sums of the fields `x1`…`x10` of `X` into
the fields `y1`…`y7` of `Y`, submitting all seven broadcast assignments
through one `MBF.@fused` block.
"""
function perf_kernel_fused!(X, Y)
    # Every field is read up front; `y8`–`y10` are unpacked but never written.
    x1, x2, x3, x4, x5 = X.x1, X.x2, X.x3, X.x4, X.x5
    x6, x7, x8, x9, x10 = X.x6, X.x7, X.x8, X.x9, X.x10
    y1, y2, y3, y4, y5 = Y.y1, Y.y2, Y.y3, Y.y4, Y.y5
    y6, y7, y8, y9, y10 = Y.y6, Y.y7, Y.y8, Y.y9, Y.y10
    # The statements must stay in `@.` form inside the fused block.
    return MBF.@fused begin
        @. y1 = x1 + x2 + x3 + x4
        @. y2 = x2 + x3 + x4 + x5
        @. y3 = x3 + x4 + x5 + x6
        @. y4 = x4 + x5 + x6 + x7
        @. y5 = x5 + x6 + x7 + x8
        @. y6 = x6 + x7 + x8 + x9
        @. y7 = x7 + x8 + x9 + x10
    end
end

Expand Down
42 changes: 22 additions & 20 deletions test/bm_fused_shared_reads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,41 +5,43 @@ using Revise; include(joinpath("test", "fused.jl"))
include("utils.jl")

"""
    perf_kernel_shared_reads_unfused!(X, Y)

Write the seven sliding four-term sums of the fields `x1`…`x10` of `X` into
the fields `y1`…`y7` of `Y`, one separate broadcast assignment per output.
Returns `y7`, the destination of the last assignment.
"""
function perf_kernel_shared_reads_unfused!(X, Y)
    (; x1, x2, x3, x4, x5, x6, x7, x8, x9, x10) = X
    (; y1, y2, y3, y4, y5, y6, y7, y8, y9, y10) = Y
    y1 .= x1 .+ x2 .+ x3 .+ x4
    y2 .= x2 .+ x3 .+ x4 .+ x5
    y3 .= x3 .+ x4 .+ x5 .+ x6
    y4 .= x4 .+ x5 .+ x6 .+ x7
    y5 .= x5 .+ x6 .+ x7 .+ x8
    y6 .= x6 .+ x7 .+ x8 .+ x9
    y7 .= x7 .+ x8 .+ x9 .+ x10
    return y7
end

"""
    perf_kernel_shared_reads_fused!(X, Y)

Write the seven sliding four-term sums of the fields `x1`…`x10` of `X` into
the fields `y1`…`y7` of `Y`, submitting all seven broadcast assignments
through one `MBF.@fused` block.
"""
function perf_kernel_shared_reads_fused!(X, Y)
    # Every field is read up front; `y8`–`y10` are unpacked but never written.
    x1, x2, x3, x4, x5 = X.x1, X.x2, X.x3, X.x4, X.x5
    x6, x7, x8, x9, x10 = X.x6, X.x7, X.x8, X.x9, X.x10
    y1, y2, y3, y4, y5 = Y.y1, Y.y2, Y.y3, Y.y4, Y.y5
    y6, y7, y8, y9, y10 = Y.y6, Y.y7, Y.y8, Y.y9, Y.y10
    # The statements must stay in `@.` form inside the fused block.
    return MBF.@fused begin
        @. y1 = x1 + x2 + x3 + x4
        @. y2 = x2 + x3 + x4 + x5
        @. y3 = x3 + x4 + x5 + x6
        @. y4 = x4 + x5 + x6 + x7
        @. y5 = x5 + x6 + x7 + x8
        @. y6 = x6 + x7 + x8 + x9
        @. y7 = x7 + x8 + x9 + x10
    end
end

has_cuda = CUDA.has_cuda()
AType = has_cuda ? CUDA.CuArray : Array
arr_size = (prod((50,5,5,6,50)),)
arr_size = (prod((50, 5, 5, 6, 50)),)
X = get_arrays(:x, arr_size, AType)
Y = get_arrays(:y, arr_size, AType)

test_kernel!(;
fused! = perf_kernel_shared_reads_fused!,
unfused! = perf_kernel_shared_reads_unfused!,
X, Y)
X,
Y,
)
# Compile
perf_kernel_shared_reads_unfused!(X, Y)
perf_kernel_shared_reads_fused!(X, Y)
Expand Down
Loading

0 comments on commit 77dbc0b

Please sign in to comment.