Skip to content

Commit

Permalink
Fix test bugs, improve tests, restrict loops, func calls, and if-else
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Mar 9, 2024
1 parent 10aeac4 commit 960cb7d
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 54 deletions.
20 changes: 16 additions & 4 deletions src/macro_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,27 @@ function _fused_pairs(expr::Expr)
exprs_out = []
for _expr in expr.args
# TODO: should we retain LineNumberNode?
if _expr isa Symbol # ????????
return ""
end
# if _expr isa Symbol # ????????
# error("???")
# return ""
# end
_expr isa LineNumberNode && continue
# @assert _expr isa Expr
if _expr.head == :for
error("Loops are not allowed inside fused blocks")
end
if _expr.head == :if
error("If-statements are not allowed inside fused blocks")
end
if _expr.head == :call
error("Function calls are not allowed inside fused blocks")
end
if _expr.head == :macrocall && _expr.args[1] == Symbol("@__dot__")
se = code_lowered_single_expression(_expr)
margs = materialize_args(se)
push!(exprs_out, :(Pair($(margs[1]), $(margs[2]))))
else
@show _expr
error("Uncaught edge case")

Check warning on line 37 in src/macro_utils.jl

View check run for this annotation

Codecov / codecov/patch

src/macro_utils.jl#L36-L37

Added lines #L36 - L37 were not covered by tests
end
end
if length(exprs_out) == 1
Expand Down
54 changes: 23 additions & 31 deletions test/bm_fused_reads_vs_hard_coded.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,37 +34,16 @@ end
function knl_multi_copyto_hard_coded!(X, Y, ::Val{nitems}) where {nitems}
(; x1, x2, x3, x4, x5, x6, x7, x8, x9, x10) = X
(; y1, y2, y3, y4, y5, y6, y7, y8, y9, y10) = Y
gidx = CUDA.threadIdx().x + (CUDA.blockIdx().x - 1) * CUDA.blockDim().x
if gidx < nitems
idx = gidx
y1[idx] = x1[idx] + x2[idx] + x3[idx] + x4[idx]
y2[idx] = x2[idx] + x3[idx] + x4[idx] + x5[idx]
y3[idx] = x3[idx] + x4[idx] + x5[idx] + x6[idx]
y4[idx] = x4[idx] + x5[idx] + x6[idx] + x7[idx]
y5[idx] = x5[idx] + x6[idx] + x7[idx] + x8[idx]
y6[idx] = x6[idx] + x7[idx] + x8[idx] + x9[idx]
y7[idx] = x7[idx] + x8[idx] + x9[idx] + x10[idx]
end
return nothing
end

function knl_multi_copyto_hard_coded!(X, Y, ::Val{nitems}) where {nitems}
(; x1, x2, x3, x4, x5, x6, x7, x8, x9, x10) = X
(; y1, y2, y3, y4, y5, y6, y7, y8, y9, y10) = Y
gidx = CUDA.threadIdx().x + (CUDA.blockIdx().x - 1) * CUDA.blockDim().x
if gidx < nitems
@inbounds begin
idx = gidx
y1[idx] = x1[idx] + x6[idx]
y2[idx] = x2[idx] + x7[idx]
y3[idx] = x3[idx] + x8[idx]
y4[idx] = x4[idx] + x9[idx]
y5[idx] = x5[idx] + x10[idx]
y6[idx] = y1[idx]
y7[idx] = y2[idx]
y8[idx] = y3[idx]
y9[idx] = y4[idx]
y10[idx] = y5[idx]
idx = CUDA.threadIdx().x + (CUDA.blockIdx().x - 1) * CUDA.blockDim().x
@inbounds begin
if idx ≤ nitems
y1[idx] = x1[idx] + x2[idx] + x3[idx] + x4[idx]
y2[idx] = x2[idx] + x3[idx] + x4[idx] + x5[idx]
y3[idx] = x3[idx] + x4[idx] + x5[idx] + x6[idx]
y4[idx] = x4[idx] + x5[idx] + x6[idx] + x7[idx]
y5[idx] = x5[idx] + x6[idx] + x7[idx] + x8[idx]
y6[idx] = x6[idx] + x7[idx] + x8[idx] + x9[idx]
y7[idx] = x7[idx] + x8[idx] + x9[idx] + x10[idx]
end
end
return nothing
Expand Down Expand Up @@ -108,6 +87,19 @@ function perf_kernel_fused!(X, Y)
end
end

test_kernel!(;
fused! = perf_kernel_fused!,
unfused! = perf_kernel_unfused!,
X,
Y,
)
test_kernel!(;
fused! = perf_kernel_hard_coded!,
unfused! = perf_kernel_unfused!,
X,
Y,
)

# Compile
perf_kernel_unfused!(X, Y)
perf_kernel_fused!(X, Y)
Expand Down
85 changes: 85 additions & 0 deletions test/expr_errors_and_edge_cases.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#=
using Revise; include(joinpath("test", "expr_errors_and_edge_cases.jl"))
=#
using Test
import MultiBroadcastFusion as MBF
@testset "loops" begin
    # `for` loops must be rejected: the fused block is rewritten at macro
    # expansion time, when the number of loop iterations is generally unknown.
    #
    # Literal ranges such as `for i in 1:10` could be special-cased, but that
    # is likely an uncommon edge case.
    expected = ErrorException("Loops are not allowed inside fused blocks")

    block_with_loop = quote
        @. y1 = x1 + x2 + x3 + x4
        for i in 1:10
            @. y2 = x2 + x3 + x4 + x5
        end
        @. y1 = x1 + x2 + x3 + x4
    end

    @test_throws expected MBF.fused_pairs(block_with_loop)
end

# Empty marker type used by the `isa` condition in the quoted block below.
struct Foo end

@testset "If-statements" begin
    # `if` blocks must be rejected: the fused block is rewritten at macro
    # expansion time, when neither Bool values nor types are known.
    #
    # Literal conditions such as `if true` could be special-cased, but that
    # is likely an uncommon edge case.
    foo = Foo()
    expected = ErrorException("If-statements are not allowed inside fused blocks")

    block_with_branch = quote
        @. y1 = x1 + x2 + x3 + x4
        if foo isa Foo
            @. y2 = x2 + x3 + x4 + x5
        end
        @. y1 = x1 + x2 + x3 + x4
    end

    @test_throws expected MBF.fused_pairs(block_with_branch)
end

# No-op function whose call appears in the quoted block below.
function bar()
    return nothing
end

@testset "Function calls" begin
    # Plain function calls must be rejected inside a fused block, because
    # they could lead to subtle bugs (order of compute).
    expected = ErrorException("Function calls are not allowed inside fused blocks")

    block_with_call = quote
        @. y1 = x1 + x2 + x3 + x4
        bar()
        @. y1 = x1 + x2 + x3 + x4
    end

    @test_throws expected MBF.fused_pairs(block_with_call)
end

@testset "Comments" begin
    # Comments between the broadcast assignments are not part of the parsed
    # expression; the result is expected to hold exactly the two pairs.
    block_with_comments = quote
        @. y1 = x1 + x2 + x3 + x4
        # Foo bar baz
        # if i in 1:N
        @. y2 = x2 + x3 + x4 + x5
    end

    # Expected output: a tuple expression with one `Pair` per assignment.
    first_pair = :(Pair(y1, Base.broadcasted(+, x1, x2, x3, x4)))
    second_pair = :(Pair(y2, Base.broadcasted(+, x2, x3, x4, x5)))
    expected = Expr(:tuple, first_pair, second_pair)

    @test MBF.fused_pairs(block_with_comments) == expected
end

@testset "Empty" begin
    # A block without any statements is expected to give an empty tuple
    # expression.
    empty_block = quote
    end
    expected = Expr(:tuple)

    @test MBF.fused_pairs(empty_block) == expected
end
28 changes: 9 additions & 19 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,13 @@ using Revise; include(joinpath("test", "runtests.jl"))
using Test
using SafeTestsets

@safetestset "expr_code_lowered_single_expression" begin
@time include("expr_code_lowered_single_expression.jl")
end
@safetestset "expr_materialize_args" begin
@time include("expr_materialize_args.jl")
end
@safetestset "expr_fused_pairs" begin
@time include("expr_fused_pairs.jl")
end
# TODO: add assertion test for Pairs that test against loops/if-else blocks
#! format: off
@safetestset "expr_code_lowered_single_expression" begin; @time include("expr_code_lowered_single_expression.jl"); end
@safetestset "expr_materialize_args" begin; @time include("expr_materialize_args.jl"); end
@safetestset "expr_fused_pairs" begin; @time include("expr_fused_pairs.jl"); end
@safetestset "expr_errors_and_edge_cases" begin; @time include("expr_errors_and_edge_cases.jl"); end

@safetestset "fused_shared_reads" begin
@time include("bm_fused_shared_reads.jl")
end
@safetestset "fused_shared_reads_writes" begin
@time include("bm_fused_shared_reads_writes.jl")
end
@safetestset "bm_fused_reads_vs_hard_coded" begin
@time include("bm_fused_reads_vs_hard_coded.jl")
end
@safetestset "fused_shared_reads" begin; @time include("bm_fused_shared_reads.jl"); end
@safetestset "fused_shared_reads_writes" begin; @time include("bm_fused_shared_reads_writes.jl"); end
@safetestset "bm_fused_reads_vs_hard_coded" begin; @time include("bm_fused_reads_vs_hard_coded.jl"); end
#! format: on

0 comments on commit 960cb7d

Please sign in to comment.