diff --git a/src/macro_utils.jl b/src/macro_utils.jl index 733b755..356194c 100644 --- a/src/macro_utils.jl +++ b/src/macro_utils.jl @@ -14,15 +14,27 @@ function _fused_pairs(expr::Expr) exprs_out = [] for _expr in expr.args # TODO: should we retain LineNumberNode? - if _expr isa Symbol # ???????? - return "" - end + # if _expr isa Symbol # ???????? + # error("???") + # return "" + # end _expr isa LineNumberNode && continue - # @assert _expr isa Expr + if _expr.head == :for + error("Loops are not allowed inside fused blocks") + end + if _expr.head == :if + error("If-statements are not allowed inside fused blocks") + end + if _expr.head == :call + error("Function calls are not allowed inside fused blocks") + end if _expr.head == :macrocall && _expr.args[1] == Symbol("@__dot__") se = code_lowered_single_expression(_expr) margs = materialize_args(se) push!(exprs_out, :(Pair($(margs[1]), $(margs[2])))) + else + @show _expr + error("Uncaught edge case") end end if length(exprs_out) == 1 diff --git a/test/bm_fused_reads_vs_hard_coded.jl b/test/bm_fused_reads_vs_hard_coded.jl index 173190e..5e72709 100644 --- a/test/bm_fused_reads_vs_hard_coded.jl +++ b/test/bm_fused_reads_vs_hard_coded.jl @@ -34,37 +34,16 @@ end function knl_multi_copyto_hard_coded!(X, Y, ::Val{nitems}) where {nitems} (; x1, x2, x3, x4, x5, x6, x7, x8, x9, x10) = X (; y1, y2, y3, y4, y5, y6, y7, y8, y9, y10) = Y - gidx = CUDA.threadIdx().x + (CUDA.blockIdx().x - 1) * CUDA.blockDim().x - if gidx < nitems - idx = gidx - y1[idx] = x1[idx] + x2[idx] + x3[idx] + x4[idx] - y2[idx] = x2[idx] + x3[idx] + x4[idx] + x5[idx] - y3[idx] = x3[idx] + x4[idx] + x5[idx] + x6[idx] - y4[idx] = x4[idx] + x5[idx] + x6[idx] + x7[idx] - y5[idx] = x5[idx] + x6[idx] + x7[idx] + x8[idx] - y6[idx] = x6[idx] + x7[idx] + x8[idx] + x9[idx] - y7[idx] = x7[idx] + x8[idx] + x9[idx] + x10[idx] - end - return nothing -end - -function knl_multi_copyto_hard_coded!(X, Y, ::Val{nitems}) where {nitems} - (; x1, x2, x3, x4, x5, x6, x7, x8, x9, x10) = X - (; y1, y2, y3, y4, y5, y6, y7, y8, y9, y10) = Y - gidx = CUDA.threadIdx().x + (CUDA.blockIdx().x - 1) * CUDA.blockDim().x - if gidx < nitems - @inbounds begin - idx = gidx - y1[idx] = x1[idx] + x6[idx] - y2[idx] = x2[idx] + x7[idx] - y3[idx] = x3[idx] + x8[idx] - y4[idx] = x4[idx] + x9[idx] - y5[idx] = x5[idx] + x10[idx] - y6[idx] = y1[idx] - y7[idx] = y2[idx] - y8[idx] = y3[idx] - y9[idx] = y4[idx] - y10[idx] = y5[idx] + idx = CUDA.threadIdx().x + (CUDA.blockIdx().x - 1) * CUDA.blockDim().x + @inbounds begin + if idx ≤ nitems + y1[idx] = x1[idx] + x2[idx] + x3[idx] + x4[idx] + y2[idx] = x2[idx] + x3[idx] + x4[idx] + x5[idx] + y3[idx] = x3[idx] + x4[idx] + x5[idx] + x6[idx] + y4[idx] = x4[idx] + x5[idx] + x6[idx] + x7[idx] + y5[idx] = x5[idx] + x6[idx] + x7[idx] + x8[idx] + y6[idx] = x6[idx] + x7[idx] + x8[idx] + x9[idx] + y7[idx] = x7[idx] + x8[idx] + x9[idx] + x10[idx] end end return nothing @@ -108,6 +87,19 @@ function perf_kernel_fused!(X, Y) end end +test_kernel!(; + fused! = perf_kernel_fused!, + unfused! = perf_kernel_unfused!, + X, + Y, +) +test_kernel!(; + fused! = perf_kernel_hard_coded!, + unfused! = perf_kernel_unfused!, + X, + Y, +) + # Compile perf_kernel_unfused!(X, Y) perf_kernel_fused!(X, Y) diff --git a/test/expr_errors_and_edge_cases.jl b/test/expr_errors_and_edge_cases.jl new file mode 100644 index 0000000..8d8a543 --- /dev/null +++ b/test/expr_errors_and_edge_cases.jl @@ -0,0 +1,85 @@ +#= +using Revise; include(joinpath("test", "expr_errors_and_edge_cases.jl")) +=# +using Test +import MultiBroadcastFusion as MBF +@testset "loops" begin + # loops are not allowed, because + # code transformation occurs at macro + # expansion time, and we can't generally + # know how many times the loop will be + # executed at this time. + + # We could try to specialize on literal ranges, e.g., + # `for i in 1:10`, but that is likely an uncommon + # edge case. + + expr_in = quote + @. y1 = x1 + x2 + x3 + x4 + for i in 1:10 + @. y2 = x2 + x3 + x4 + x5 + end + @. y1 = x1 + x2 + x3 + x4 + end + @test_throws ErrorException("Loops are not allowed inside fused blocks") MBF.fused_pairs( + expr_in, + ) +end + +struct Foo end + +@testset "If-statements" begin + # If-statements are not allowed, because + # code transformation occurs at macro + # expansion time, and Bools, even types, + # are not known at this time. + + # We could specialize on literals, e.g., + # `if true`, but that is likely an uncommon + # edge case. + foo = Foo() + expr_in = quote + @. y1 = x1 + x2 + x3 + x4 + if foo isa Foo + @. y2 = x2 + x3 + x4 + x5 + end + @. y1 = x1 + x2 + x3 + x4 + end + @test_throws ErrorException( + "If-statements are not allowed inside fused blocks", + ) MBF.fused_pairs(expr_in) +end + +bar() = nothing +@testset "Function calls" begin + # Function calls are not allowed, because + # this could lead to subtle bugs (order of compute). + expr_in = quote + @. y1 = x1 + x2 + x3 + x4 + bar() + @. y1 = x1 + x2 + x3 + x4 + end + @test_throws ErrorException( + "Function calls are not allowed inside fused blocks", + ) MBF.fused_pairs(expr_in) +end + +@testset "Comments" begin + expr_in = quote + @. y1 = x1 + x2 + x3 + x4 + # Foo bar baz + # if i in 1:N + @. y2 = x2 + x3 + x4 + x5 + end + + expr_out = :(( + Pair(y1, Base.broadcasted(+, x1, x2, x3, x4)), + Pair(y2, Base.broadcasted(+, x2, x3, x4, x5)), + )) + @test MBF.fused_pairs(expr_in) == expr_out +end + +@testset "Empty" begin + expr_in = quote end + @test MBF.fused_pairs(expr_in) == :(()) +end diff --git a/test/runtests.jl b/test/runtests.jl index 0e2cc9f..8de2eec 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,23 +4,13 @@ using Revise; include(joinpath("test", "runtests.jl")) using Test using SafeTestsets -@safetestset "expr_code_lowered_single_expression" begin - @time include("expr_code_lowered_single_expression.jl") -end -@safetestset "expr_materialize_args" begin - @time include("expr_materialize_args.jl") -end -@safetestset "expr_fused_pairs" begin - @time include("expr_fused_pairs.jl") -end -# TODO: add assertion test for Pairs that test against loops/if-else blocks +#! format: off +@safetestset "expr_code_lowered_single_expression" begin; @time include("expr_code_lowered_single_expression.jl"); end +@safetestset "expr_materialize_args" begin; @time include("expr_materialize_args.jl"); end +@safetestset "expr_fused_pairs" begin; @time include("expr_fused_pairs.jl"); end +@safetestset "expr_errors_and_edge_cases" begin; @time include("expr_errors_and_edge_cases.jl"); end -@safetestset "fused_shared_reads" begin - @time include("bm_fused_shared_reads.jl") -end -@safetestset "fused_shared_reads_writes" begin - @time include("bm_fused_shared_reads_writes.jl") -end -@safetestset "bm_fused_reads_vs_hard_coded" begin - @time include("bm_fused_reads_vs_hard_coded.jl") -end +@safetestset "fused_shared_reads" begin; @time include("bm_fused_shared_reads.jl"); end +@safetestset "fused_shared_reads_writes" begin; @time include("bm_fused_shared_reads_writes.jl"); end +@safetestset "bm_fused_reads_vs_hard_coded" begin; @time include("bm_fused_reads_vs_hard_coded.jl"); end +#! format: on